
Using DDPM (a Diffusion Model) to Train on Your Own Dataset and Expand It (PyTorch)


I. Introduction

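DDPM (Denoising Diffusion Probabilistic Model) consists of two processes: the forward process, also called the diffusion process, and the reverse process. Both are Markov chains. The forward process is fixed: it gradually corrupts the training data by adding Gaussian noise over many steps. The reverse process is parameterized by a neural network that learns to undo this noise and recover the data; it is trained with variational inference and is the part used to generate data. At inference time, randomly sampled Gaussian noise is fed into the trained model, and its learned denoising process turns that noise into new data.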
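The forward process has a closed form that the training script relies on: with a noise schedule beta_1 ... beta_T and abar_t = (1 - beta_1)(1 - beta_2)...(1 - beta_t), a noisy sample at step t is x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps with eps ~ N(0, I), and the network is trained to predict eps with a mean-squared error. The NumPy sketch below illustrates this, assuming the linear schedule with beta_start=1e-4 and beta_end=0.02 (the DDPMScheduler defaults for beta_schedule="linear"). It mirrors what `noise_scheduler.add_noise` and the "epsilon" loss do in the script, but it is an illustration, not the diffusers implementation.

```python
import numpy as np


def linear_alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """abar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)
    return np.cumprod(1.0 - betas)


def q_sample(x0, t, noise, alphas_cumprod):
    """Closed-form forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    abar_t = alphas_cumprod[t]
    return np.sqrt(abar_t) * x0 + np.sqrt(1.0 - abar_t) * noise


def epsilon_loss(predicted_noise, true_noise):
    """Simplified DDPM objective: mean squared error between predicted and true noise."""
    return float(np.mean((predicted_noise - true_noise) ** 2))


rng = np.random.default_rng(0)
abar = linear_alphas_cumprod()
snr = abar / (1.0 - abar)  # signal-to-noise ratio per timestep; falls as t grows

x0 = rng.uniform(-1.0, 1.0, size=(3, 8, 8))  # toy "image" already scaled to [-1, 1]
noise = rng.standard_normal(x0.shape)
x_early = q_sample(x0, 10, noise, abar)   # still dominated by x0
x_late = q_sample(x0, 999, noise, abar)   # almost pure noise
```

At t = 10 the sample is still almost the clean image, while at t = 999 practically nothing of x0 remains, which is why generation can start from pure Gaussian noise. The `snr` array is the same quantity the script uses to weight the loss when `--prediction_type sample` is chosen.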

II. Hands-on

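First, install the required Python package:

pip install diffusers

The training script below also imports accelerate, datasets, torchvision, huggingface_hub and tqdm and logs to TensorBoard, so install those as well, e.g. pip install diffusers accelerate datasets torchvision tensorboard. The script was written against the diffusers 0.10 API (it calls check_min_version("0.10.0.dev0") and uses EMAModel(model, ...) and ema_model.averaged_model); later diffusers releases changed EMAModel, so if you hit errors there, install a 0.10.x release such as diffusers==0.10.2.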

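The dataset I use is a citrus diseased-leaf dataset (not publicly released) with three classes: Huanglong disease (citrus greening), magnesium deficiency, and normal.

Dataset layout (there is no need to split it into train/val/test):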

```
dataset_orgin
├── Huanglong_disease
│   ├── 0.jpg
│   ├── 1.jpg
│   └── ...
├── Magnesium_deficiency
│   ├── 0.jpg
│   ├── 1.jpg
│   └── ...
└── Normal
    ├── 0.jpg
    ├── 1.jpg
    └── ...
```
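Because each class folder is trained on separately (the script's default path points at Huanglong_disease alone), it helps to check how many images each class actually has before training. A small sketch, assuming images sit directly inside each class folder and use one of the listed suffixes; `count_images_per_class` and `IMAGE_SUFFIXES` are hypothetical names, not part of the original workflow.

```python
from pathlib import Path

# Assumed set of image extensions; extend it if your data uses other formats
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_per_class(root):
    """Map each class sub-folder of `root` to the number of image files directly inside it."""
    counts = {}
    for class_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts
```

For example, `count_images_per_class(r"D:\pyCharmdata\datasets_orgin")` returns a dict such as `{"Huanglong_disease": ..., "Magnesium_deficiency": ..., "Normal": ...}`; a class with very few images is a hint that it will need more epochs or will overfit.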
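Create a new Python file for training. The script is adapted from the unconditional image generation example (train_unconditional.py) in the diffusers repository:

train_unconditional.py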
```python
import argparse
import inspect
import math
import os
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, whoami
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# This script uses the diffusers 0.10-era EMAModel API (EMAModel(model, ...), ema_model.averaged_model).
check_min_version("0.10.0.dev0")

logger = get_logger(__name__)


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.
    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    if not isinstance(arr, torch.Tensor):
        arr = torch.from_numpy(arr)
    # Move the lookup table to the indices' device first: indexing a CPU tensor with CUDA indices fails
    res = arr.to(timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=r"D:\pyCharmdata\datasets_orgin\Huanglong_disease",  # upstream default: None
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that HF Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=r"D:\pyCharmdata\datasets_orgin\Huanglong_disease",  # upstream default: None
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder."
            " Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="ddpm-model-128",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=128,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--eval_batch_size", type=int, default=1, help="The number of images to generate for evaluation."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
            " process."
        ),
    )
    parser.add_argument("--num_epochs", type=int, default=300)
    parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
    parser.add_argument(
        "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="cosine",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
    )
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument(
        "--use_ema",
        action="store_true",
        default=True,  # with store_true and default=True, EMA is always on
        help="Whether to use Exponential Moving Average for the final model weights.",
    )
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
    parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
    parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",  # upstream default: "no"; fp16 needs a CUDA GPU, use "no" when training on CPU
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--prediction_type",
        type=str,
        default="epsilon",
        choices=["epsilon", "sample"],
        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
    )
    parser.add_argument("--ddpm_num_steps", type=int, default=1000)
    parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")

    return args


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def main(args):
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

    # UNet that predicts the noise (or the clean image) from a noisy image and its timestep
    model = UNet2DModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )

    # Older diffusers versions do not accept `prediction_type` in the scheduler constructor
    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
    if accepts_prediction_type:
        noise_scheduler = DDPMScheduler(
            num_train_timesteps=args.ddpm_num_steps,
            beta_schedule=args.ddpm_beta_schedule,
            prediction_type=args.prediction_type,
        )
    else:
        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Resize/crop to the training resolution and scale pixels to [-1, 1]
    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )

    # A local folder passed as --dataset_name is also loaded by load_dataset
    if args.dataset_name is not None:
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            split="train",
        )
    else:
        dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        return {"input": images}

    logger.info(f"Dataset size: {len(dataset)}")

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

    # Exponential moving average of the weights, used for sampling and for the saved model
    ema_model = EMAModel(
        accelerator.unwrap_model(model),
        inv_gamma=args.ema_inv_gamma,
        power=args.ema_power,
        max_value=args.ema_max_decay,
    )

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    if accelerator.is_main_process:
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run)

    global_step = 0
    for epoch in range(args.num_epochs):
        model.train()
        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            clean_images = batch["input"]
            # Sample noise that we'll add to the images
            noise = torch.randn_like(clean_images)
            bsz = clean_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
            ).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                model_output = model(noisy_images, timesteps).sample

                if args.prediction_type == "epsilon":
                    loss = F.mse_loss(model_output, noise)  # this could have different weights!
                elif args.prediction_type == "sample":
                    alpha_t = _extract_into_tensor(
                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
                    )
                    snr_weights = alpha_t / (1 - alpha_t)
                    loss = snr_weights * F.mse_loss(
                        model_output, clean_images, reduction="none"
                    )  # use SNR weighting from distillation paper
                    loss = loss.mean()
                else:
                    raise ValueError(f"Unsupported prediction type: {args.prediction_type}")

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                if args.use_ema:
                    ema_model.step(model)
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            if args.use_ema:
                logs["ema_decay"] = ema_model.decay
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
        progress_bar.close()

        accelerator.wait_for_everyone()

        # Generate sample images for visual inspection and save the model (main process only)
        if accelerator.is_main_process:
            save_images = epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1
            save_model = epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1

            if save_images or save_model:
                # Build the pipeline once so both the image and the model branch can use it
                pipeline = DDPMPipeline(
                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
                    scheduler=noise_scheduler,
                )

            if save_images:
                generator = torch.Generator(device=pipeline.device).manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                images = pipeline(
                    generator=generator,
                    batch_size=args.eval_batch_size,
                    output_type="numpy",
                ).images

                # denormalize the images and save to tensorboard
                images_processed = (images * 255).round().astype("uint8")
                accelerator.trackers[0].writer.add_images(
                    "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                )

            if save_model:
                # save the model
                pipeline.save_pretrained(args.output_dir)
                if args.push_to_hub:
                    repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
        accelerator.wait_for_everyone()

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)
```
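The script keeps an exponential moving average (EMA) of the UNet weights and samples with `ema_model.averaged_model`; `--ema_inv_gamma`, `--ema_power` and `--ema_max_decay` control how the EMA decay ramps up. A minimal pure-Python sketch, assuming the warmup formula decay = 1 - (1 + step / inv_gamma) ** (-power), capped at the maximum decay, which is how diffusers' 0.10-era `EMAModel` computes its decay as far as I know (the exact step offset may differ between versions). `ema_decay` and `ema_update` are illustrative names, not diffusers functions.

```python
def ema_decay(step, inv_gamma=1.0, power=0.75, max_decay=0.9999, min_decay=0.0):
    """Assumed warmup-style EMA decay: 1 - (1 + step / inv_gamma) ** -power, clipped to [min_decay, max_decay]."""
    if step <= 0:
        return 0.0
    value = 1.0 - (1.0 + step / inv_gamma) ** -power
    return max(min_decay, min(value, max_decay))


def ema_update(ema_params, params, decay):
    """ema <- decay * ema + (1 - decay) * param, applied element-wise."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]


# How the decay grows with the number of optimizer steps under the script's defaults
for step in (1, 10, 100, 1_000, 10_000, 100_000, 1_000_000):
    print(step, round(ema_decay(step), 5))
```

With the defaults (inv_gamma=1, power=3/4) the decay only reaches the 0.9999 cap after roughly 200,000 optimizer steps, so on a small dataset the EMA weights follow the raw weights fairly closely for most of training.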

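Before training, adjust the argparse defaults:

Set "--dataset_name" to the path of your dataset folder. When it is set, it takes precedence: the script calls load_dataset(args.dataset_name, ...) and ignores "--train_data_dir".

Set "--train_data_dir" to the same path; it is only used when "--dataset_name" is None, in which case the folder is loaded with the imagefolder builder.

Set "--output_dir" to the name of the folder the trained model is written to.

Set "--resolution" to the size of the synthetic images you want to generate later. Higher resolutions demand far more compute, so start at 64×64 and work upward, e.g. 64×64, 128×128, 256×256, 512×512.

"--train_batch_size", "--eval_batch_size" and "--dataloader_num_workers" largely decide whether training can run at all. DDPM models are large and compute-hungry; if you run out of GPU memory, lower the two batch sizes (the number of data-loader workers affects CPU processes and RAM rather than GPU memory).

Set "--num_epochs" to the number of training epochs. The right value depends on the dataset: about 100 epochs may be enough for a simple one, while for a complex one you will have to experiment.

Instead of editing the defaults, you can also pass any of these as command-line flags, e.g. python train_unconditional.py --resolution 64 --num_epochs 100.

In short, work within your means: with ample compute you can tune freely; with limited compute, adjust the parameters step by step until training runs.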
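To see how the batch size, gradient accumulation and epoch count translate into optimizer steps, the sketch below reproduces the script's arithmetic for a single-process run: `len(train_dataloader)`, `num_update_steps_per_epoch` and the `num_training_steps` passed to the LR scheduler. `training_plan` is a hypothetical helper and `num_images` is whatever size your class folder has.

```python
import math


def training_plan(num_images, train_batch_size, num_epochs,
                  gradient_accumulation_steps=1, lr_warmup_steps=500):
    """Rough step counts for a single-process run of train_unconditional.py."""
    batches_per_epoch = math.ceil(num_images / train_batch_size)  # len(train_dataloader)
    updates_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    scheduler_steps = (batches_per_epoch * num_epochs) // gradient_accumulation_steps
    return {
        "effective_batch_size": train_batch_size * gradient_accumulation_steps,
        "batches_per_epoch": batches_per_epoch,
        "optimizer_updates_per_epoch": updates_per_epoch,
        "total_optimizer_updates": updates_per_epoch * num_epochs,
        "lr_scheduler_training_steps": scheduler_steps,
        # If True, the learning rate never finishes warming up
        "warmup_longer_than_training": lr_warmup_steps >= scheduler_steps,
    }


print(training_plan(num_images=300, train_batch_size=1, num_epochs=300))
```

Raising `--gradient_accumulation_steps` lets a small `--train_batch_size` behave like a larger effective batch without extra GPU memory, at the cost of fewer optimizer updates per epoch; on a tiny dataset, also check that `--lr_warmup_steps` (500 by default) is well below the total number of steps.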

Create another Python file to generate the "fake" (synthetic) images:

generate.py

```python
# !pip install diffusers
import os

import torch
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline

# Name of the trained model folder in the project root (the --output_dir used for training)
model_id = "ddpm-model-512-Huanglong_disease"
# Folder the generated images are written to
img_path = os.path.join("results", model_id + "-img")
os.makedirs(img_path, exist_ok=True)  # also creates "results" if it does not exist yet

device = "cuda" if torch.cuda.is_available() else "cpu"

# load model and scheduler
# you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference
ddpm = DDPMPipeline.from_pretrained(model_id)
ddpm.to(device)

for i in range(1000):
    # run pipeline in inference (sample random noise and denoise)
    image = ddpm().images[0]
    # save image
    # keep the original (RGB) format
    image.save(os.path.join(img_path, f"{i}.png"))
    # or convert to a single channel (grayscale) instead
    # image.convert('L').save(os.path.join(img_path, f'{i}.png'))
    # report progress
    if i % 10 == 0:
        print(f"i={i}")
```

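model_id is the name of the trained model folder in the project root, i.e. the --output_dir used for training.

Change the count in for i in range() to the number of synthetic images you want to generate.

To save colour images, use

`image.save(os.path.join(img_path, f'{i}.png'))`

To save grayscale images, use

`image.convert('L').save(os.path.join(img_path, f'{i}.png'))`

Use one or the other and comment out the line you are not using.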
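generate.py calls the pipeline once per image. The pipeline also accepts a `batch_size` argument (the training script uses it for its preview images), so on a GPU with spare memory several images can be denoised per call. A minimal sketch, assuming `pipeline(batch_size=n).images` returns a list of n PIL images, as `ddpm().images[0]` does in generate.py; `batch_sizes` and `generate_images` are hypothetical helpers, and the `grayscale` flag replaces commenting lines in and out.

```python
import os


def batch_sizes(total, batch_size):
    """Split `total` images into chunks of at most `batch_size`, e.g. 10, 4 -> [4, 4, 2]."""
    full, rest = divmod(total, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


def generate_images(pipeline, total, batch_size, out_dir, grayscale=False):
    """Run a DDPMPipeline-like callable in batches and save numbered PNG files; returns the count."""
    os.makedirs(out_dir, exist_ok=True)
    index = 0
    for size in batch_sizes(total, batch_size):
        # One reverse-diffusion run produces `size` images at once
        for image in pipeline(batch_size=size).images:
            if grayscale:
                image = image.convert("L")  # single channel, like the commented-out line in generate.py
            image.save(os.path.join(out_dir, f"{index}.png"))
            index += 1
    return index
```

For example, with `ddpm` and `img_path` from generate.py: `generate_images(ddpm, 1000, 8, img_path)`. The batch size of 8 is only a starting guess; lower it if you run out of GPU memory.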

Some of the generated images (figures not reproduced here): Huanglong disease, magnesium deficiency, normal.

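Note also that training is unconditional (train_unconditional). If you train on several classes together, you cannot control which class a generated image belongs to, i.e. you cannot ask for images of a specific class, so train one model per class. Moreover, a diffusion model only learns denoising and has no built-in discriminator, so the synthetic colour images may show some colour shift. My suggestion is to first train a classification network such as ResNet50 on the original dataset, giving a model that plays a role similar to the discriminator in a GAN, and then use it to classify the synthetic images and filter out the poor-quality ones.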
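The filtering step suggested above can be sketched as follows, assuming you already have a classifier (for example a ResNet50 fine-tuned on the three original classes) and have collected its raw logits for the synthetic images as an array of shape (num_images, num_classes). `select_confident`, the 0.9 threshold and the class-index order are hypothetical choices, not part of the original workflow.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def select_confident(logits, target_class, threshold=0.9):
    """Indices of images the classifier assigns to `target_class` with probability >= threshold."""
    probs = softmax(np.asarray(logits, dtype=np.float64))
    predicted = probs.argmax(axis=1)
    confidence = probs[:, target_class]
    keep = (predicted == target_class) & (confidence >= threshold)
    return np.flatnonzero(keep).tolist()


# Hypothetical class order: 0 = Huanglong_disease, 1 = Magnesium_deficiency, 2 = Normal
logits = np.array([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [2.0, 1.9, 0.0]])
print(select_confident(logits, target_class=0))  # → [0]
```

Images from a model trained on Huanglong_disease that the classifier does not confidently label as Huanglong_disease are discarded; the threshold trades the number of kept images against their quality.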

III. Summary

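Overall, consider the other main family of algorithms for synthesizing images to expand a dataset, GANs. A GAN trains a generator and a discriminator simultaneously and lets them compete until the generator fits the data, but in most cases a GAN is hard to train to a good fit and often suffers from instabilities such as vanishing or exploding gradients.

DDPM instead trains the model to denoise images, which is comparatively easy to fit. The trained model takes unstructured Gaussian noise as input and repeatedly applies its denoising ability to "produce" synthetic images. The advantage is that the generated images are of high quality and closely resemble the originals; the drawbacks are that training needs a lot of compute and that generation is slow, because sampling runs the UNet once per denoising step (1000 steps by default with the settings used here).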


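Disclaimer (from the reposting site): this article was contributed by a user and does not represent the position of the wpsshop blog; copyright belongs to the original author. Source: https://www.wpsshop.cn/w/凡人多烦事01/article/detail/321867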