当前位置:   article > 正文

初学者体验扩散模型_扩散模型跑一次多久

扩散模型跑一次多久

可以在这个网址下载代码,里面有很多现有的,比如文字生成图像,图像生成图像

https://github.com/huggingface/diffusers

因为扩散模型训练起来很慢,不一定每个人都可以训练出来,所以他们提供了现成的模型,可以直接调用,就很爽。下面这个网址就是所有的模型汇总的,不仅仅局限于扩散模型。下面我来演示在服务器上用自己的数据训练模型。

Models - Hugging Face

1.下载模型源码

可以直接进入第一个链接去下载,也可以在服务器上输入如下命令:

  1. git clone https://github.com/huggingface/diffusers
  2. cd diffusers
  3. pip install .

下包之前,最后自己手动下载torch,指定版本,不然就是最新版。

  1. #这两个是不同版本的torch,对应不同版本的cuda
  2. pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
  3. pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 -f https://download.pytorch.org/whl/torch_stable.html

这里还不能下载其他包,按照你要做的扩散模型来下。

2.下载安装包

以图像生成图像为例:

进入example,unconditional_image_generation,里面只有三个文件,就是图像中的后三个。前面那两个是我自己建的。

下载这里的requirements中的包,并在unconditional_image_generation中导入自己的数据集。如果没有自己的数据集,可以用该网站自带的。

数据集要求格式如下:

  1. data_dir/xxx.png
  2. data_dir/xxy.png
  3. data_dir/[...]/xxz.png

3.修改参数

进入train_unconditional.py,找到main函数,这些看自己情况修改。参数有很多。

  1. --train_data_dir="imgs" \#数据集
  2. --resolution=64 \数据集的size大小,代码会把你的数据集里所有的图像压缩成这个大小,
  3. #而且也是生成图像的大小
  4. --output_dir="ddpm-ema-flowers-64" \模型位置
  5. --train_batch_size=16 \
  6. --num_epochs=100 \

4.配置训练环境

这个扩散模型还需要额外的修改下环境配置,如下所示

accelerate config

你可以照我这么来弄,也可以按照选项来。

5.训练 

这个起码训练10h+,弄个nohup。

nohup accelerate launch train_unconditional.py > ./output.log 2>&1 &

6.使用model

新建一个generate.py,改一下model_id,就可以用了

  1. # !pip install diffusers
  2. from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
  3. import os
  4. model_id = "ddpm-model-64"
  5. #生成的图像放的位置
  6. img_path = 'results'+'/'+model_id+'-img'
  7. if not os.path.exists(img_path): os.mkdir(img_path)
  8. device = "cuda"
  9. # load model and scheduler
  10. ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference
  11. ddpm.to(device)
  12. for i in range(100):
  13. # run pipeline in inference (sample random noise and denoise)
  14. image = ddpm().images[0]
  15. # save image
  16. #不修改格式
  17. #image.save(os.path.join(img_path,f'{i}.png'))
  18. #改成单通道
  19. image.convert('L').save(os.path.join(img_path,f'{i}.png'))
  20. #看看跑到哪里了
  21. if i%10==0:print(f"i={i}")

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号