赞
踩
StableSR模型出自论文《Exploiting Diffusion Prior for Real-World Image Super-Resolution》,使用扩散模型做自然界真实影像的超分辨率。其数据增强部分参考Real-ESRGAN工程,因此该模型也可以算是盲超分领域。文章的具体原理可以看论文详细了解,本篇主要介绍模型的训练过程。
我是用anaconda新创建了一个虚拟环境,然后根据作者的环境要求进行了配置,过程挺顺利的,没遇到啥问题,具体如下:
environment.yaml
- # git clone this repository
- git clone https://github.com/IceClear/StableSR.git
- cd StableSR
-
- # Create a conda environment and activate it
- conda env create --file environment.yaml
- conda activate stablesr
-
- # Install xformers
- conda install xformers -c xformers/label/dev
-
- # Install taming & clip
- pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
- pip install -e .
需要从HuggingFace上下载预训练的Stable Diffusion模型,下载路径如下:
https://huggingface.co/stabilityai/stable-diffusion-2-1-base
点击画圈部分下载,模型比较大,有5个多G
我做的是遥感影像的超分,因此就是将高分辨率的遥感影像切分成512x512的patch,放到一个文件夹里。修改配置文件v2-finetune_text_T_512.yaml的gt_path,设置为文件夹路径。
修改配置文件v2-finetune_text_T_512.yaml,主要配置ckpt_path的路径,修改为下载的Stable Diffusion预训练模型路径。
其他参数基本不用修改,我是batch_size默认用的6, 3090的卡显存不够,可以修改batch_size和queue_size,适当调整改小一些。
主要训练的事Time-aware encoder模型,训练脚本如下:
python main.py --train --base configs/stableSRNew/v2-finetune_text_T_512.yaml --gpus GPU_ID, --name NAME --scale_lr False
如果有多块GPU,可以设置GPU_ID(0,1),--name参数为训练时产生文件的存储文件夹名字,自己设置,可以加一个参数--no-test,这样在训练时不会进行验证,因为我没有准备验证集。
整个训练过程是比较耗时的,配置文件设置的迭代次数是80w次,训练过程中会在logs文件夹下存储可视化结果,自己可根据samples生成的结果与gt对比,判断模型的训练程度。作者给的工程也配置了wandb,可以在线看训练过程统计,包含loss之类的,我因为是后台训练,就直接把这个关掉了,在main.py文件中设置offline参数为True。
阶段二主要是VQGAN模型的训练,用使用的配置文件是autoencoder_kl_64x64x4_resi.yaml。用第一阶段训练模型的gts数据,用Real-ESRGAN的数据增强方式生成inputs数据,即降质的数据。需要将原来工程中数据增强的代码摘出来单独写一个脚本。如有需要,可以找我提供。
然后需要使用第一阶段训练得到的模型last.ckpt,生成latent。需要用到Stable-SR工程中scripts文件夹中的sr_val_ddpm_text_T_vqganfin_old.py脚本,修改脚本中部分内容,
作者用的还是最初的预训练模型,我用的第一阶段训练的last.ckpt
之前脚本时一次性读入测试图片缓存,训练图像数据量大内存会爆,所以最好改成单张图像读取测试。修改部分如下所示:
配置文件路径
注释掉原来的一次性图像读取模式
改成一次读取一张图像测试
latnet保存位置
保存成npy文件
python scripts/sr_val_ddpm_text_T_vqganfin_old.py --config configs/stableSRNew/v2-finetune_text_T_512.yaml --ckpt pretrain/last.ckpt --vqgan_ckpt pretrain/last.ckpt --init-img CFW_trainingdata/inputs/ --outdir CFW_trainingdata/samples --ddpm_steps 200 --dec_w 0.0 --colorfix_type adain
注意设置参数--dec_w 0.0,因为这时候VQGAN模型还没训练好。生成好之后依次将gts,inputs,samples,latents放到一个文件夹里。
第二阶段模型训练脚本
python main.py --train --base configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml --gpus GPU_ID, --name NAME --scale_lr False --no-test
迭代合适的次数,看可视化结果判断,得到训练好的VQGAN模型,结合第一阶段模型last.ckpt,使用sr_val_ddpm_text_T_vqganfin_old.py脚本进行推理,在视觉效果上看来,比Real-ESRGAN模型的细节还原确实更好。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。