扩散模型入门(DDPM论文复现)
记录一些复现扩散模型的一个小demo,供参考学习。框架是pytorch,没有用很多的包,写动画用了一个imageio的,相信安装他们并不会很困难
前置需求:
有显卡的机子
包需求:
pytorch(装好了cuda)
imageio
数据集准备:
celebA
说明,仅是学习流程需求,注意力模块没写,有需求可以自行改,模型没有弄的很大,太大了小机子吃不消,勉强看看结果就行。
下面介绍怎么用,如果你不需要训练模型,只想看看结果,可以作者本人训练的模型参数,放到checkpoints目录下。
0.下载项目代码
https://github.com/minatoyukinaa/smallDDPM-PyTorch-
由于git使用不熟练,空白文件夹似乎无法正确创建,要手动创建checkpoints,data这两个文件夹,文件目录如图。
1.下载数据集
通过网盘分享的文件:CelebA_nocrop.zip
链接: https://pan.baidu.com/s/1o4UMLYbsV2kw1WejPq8Uyw?pwd=iay6 提取码: iay6
然后把这个数据集解压到项目的data目录下,解压完毕目录这样。
2.训练模型/下载模型
找到task1Train文件,执行里面的代码,跑一个epoch以上,在checkpoints目录下生成checkpoint.pth文件.
如果你训练不了,这里提供一个运行了大约13个epoch的模型参数文件,
通过网盘分享的文件:checkpoint.pth
链接: https://pan.baidu.com/s/1fK2ms3D3Wxr7bF8j3xzaug?pwd=mgfj 提取码: mgfj
3.生成一些图片
找到valid文件,运行它,并且生成一个mp4文件,里面记录了怎么从高斯噪声恢复的视频。
以下是作者自己生成的视频