赞
踩
可以在这个网址下载代码,里面有很多现有的,比如文字生成图像,图像生成图像
https://github.com/huggingface/diffusers
因为扩散模型训练起来很慢,不一定每个人都可以训练出来,所以他们提供了现成的模型,可以直接调用,就很爽。下面这个网址就是所有的模型汇总的,不仅仅局限于扩散模型。下面我来演示在服务器上用自己的数据训练模型。
1.下载模型源码
可以直接进入第一个链接去下载,也可以在服务器上输入如下命令:
- git clone https://github.com/huggingface/diffusers
- cd diffusers
- pip install .
下包之前,最后自己手动下载torch,指定版本,不然就是最新版。
- #这两个是不同版本的torch,对应不同版本的cuda
- pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
-
- pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 -f https://download.pytorch.org/whl/torch_stable.html
这里还不能下载其他包,按照你要做的扩散模型来下。
2.下载安装包
以图像生成图像为例:
进入example,unconditional_image_generation,里面只有三个文件,就是图像中的后三个。前面那两个是我自己建的。
下载这里的requirements中的包,并在unconditional_image_generation中导入自己的数据集。如果没有自己的数据集,可以用该网站自带的。
数据集要求格式如下:
- data_dir/xxx.png
- data_dir/xxy.png
- data_dir/[...]/xxz.png
3.修改参数
进入train_unconditional.py,找到main函数,这些看自己情况修改。参数有很多。
- --train_data_dir="imgs" \#数据集
- --resolution=64 \数据集的size大小,代码会把你的数据集里所有的图像压缩成这个大小,
- #而且也是生成图像的大小
- --output_dir="ddpm-ema-flowers-64" \模型位置
- --train_batch_size=16 \
- --num_epochs=100 \
4.配置训练环境
这个扩散模型还需要额外的修改下环境配置,如下所示
accelerate config
你可以照我这么来弄,也可以按照选项来。
5.训练
这个起码训练10h+,弄个nohup。
nohup accelerate launch train_unconditional.py > ./output.log 2>&1 &
6.使用model
新建一个generate.py,改一下model_id,就可以用了
- # !pip install diffusers
- from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
- import os
- model_id = "ddpm-model-64"
- #生成的图像放的位置
- img_path = 'results'+'/'+model_id+'-img'
- if not os.path.exists(img_path): os.mkdir(img_path)
- device = "cuda"
-
- # load model and scheduler
- ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference
- ddpm.to(device)
- for i in range(100):
- # run pipeline in inference (sample random noise and denoise)
- image = ddpm().images[0]
- # save image
- #不修改格式
- #image.save(os.path.join(img_path,f'{i}.png'))
- #改成单通道
- image.convert('L').save(os.path.join(img_path,f'{i}.png'))
- #看看跑到哪里了
- if i%10==0:print(f"i={i}")

Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。