当前位置:   article > 正文

基于深度学习神经网络的AI图片上色DDColor系统源码

基于深度学习神经网络的AI图片上色DDColor系统源码

第一步:DDColor介绍

        DDColor 是最新的 SOTA 图像上色算法,能够对输入的黑白图像生成自然生动的彩色结果,使用 UNet 结构的骨干网络和图像解码器分别实现图像特征提取和特征图上采样,并利用 Transformer 结构的颜色解码器完成基于视觉语义的颜色查询,最终聚合输出彩色通道预测结果。

        它甚至可以对动漫游戏中的风景进行着色/重新着色,将您的动画风景转变为逼真的现实生活风格!(图片来源:原神)

第二步:DDColor网络结构

        算法整体流程如下图,使用 UNet 结构的骨干网络和图像解码器分别实现图像特征提取和特征图上采样,并利用 Transformer 结构的颜色解码器完成基于视觉语义的颜色查询,最终聚合输出彩色通道预测结果。

第三步:模型代码展示

  1. import os
  2. import torch
  3. from collections import OrderedDict
  4. from os import path as osp
  5. from tqdm import tqdm
  6. import numpy as np
  7. from basicsr.archs import build_network
  8. from basicsr.losses import build_loss
  9. from basicsr.metrics import calculate_metric
  10. from basicsr.utils import get_root_logger, imwrite, tensor2img
  11. from basicsr.utils.img_util import tensor_lab2rgb
  12. from basicsr.utils.dist_util import master_only
  13. from basicsr.utils.registry import MODEL_REGISTRY
  14. from .base_model import BaseModel
  15. from basicsr.metrics.custom_fid import INCEPTION_V3_FID, get_activations, calculate_activation_statistics, calculate_frechet_distance
  16. from basicsr.utils.color_enhance import color_enhacne_blend
@MODEL_REGISTRY.register()
class ColorModel(BaseModel):
    """Colorization model for single image colorization.

    Trains a generator (net_g) that predicts the ab chroma channels of a
    Lab image from its grayscale L-channel input; training additionally
    uses a discriminator, optional EMA weights, and several losses
    (configured in init_training_settings).
    """

    def __init__(self, opt):
        """Build net_g, optionally load pretrained weights, and prepare training.

        Args:
            opt (dict): Full configuration. Reads 'network_g' and the 'path'
                sub-dict ('pretrain_network_g', 'param_key_g', 'strict_load_g').
        """
        super(ColorModel, self).__init__(opt)

        # define network net_g
        self.net_g = build_network(opt['network_g'])
        self.net_g = self.model_to_device(self.net_g)
        self.print_network(self.net_g)

        # load pretrained model for net_g
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_g', 'params')
            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)

        # training-only setup: EMA, discriminator, losses, optimizers, schedulers
        if self.is_train:
            self.init_training_settings()
    def init_training_settings(self):
        """Set up everything training needs.

        Creates the optional EMA copy of net_g, the discriminator net_d,
        the loss criteria, optimizers/schedulers, and the cached
        real-image statistics slot used by the FID validation metric.

        Raises:
            ValueError: If neither a pixel loss nor a perceptual loss is
                configured.
        """
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA);
            # net_g_ema is used only for testing on one GPU and saving,
            # so there is no need to wrap it with DistributedDataParallel.
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model (EMA weights), else start from net_g's weights
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # define network net_d (discriminator used by the GAN loss)
        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)
        # load pretrained model for net_d
        load_path = self.opt['path'].get('pretrain_network_d', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_d', 'params')
            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)

        self.net_g.train()
        self.net_d.train()

        # define losses; each criterion is None when its option block is absent
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None
        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None
        if train_opt.get('gan_opt'):
            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
        else:
            self.cri_gan = None
        if self.cri_pix is None and self.cri_perceptual is None:
            raise ValueError('Both pixel and perceptual losses are None.')
        if train_opt.get('colorfulness_opt'):
            self.cri_colorfulness = build_loss(train_opt['colorfulness_opt']).to(self.device)
        else:
            self.cri_colorfulness = None

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

        # set real dataset cache for fid metric computing; filled lazily on
        # the first validation pass (see nondist_validation)
        self.real_mu, self.real_sigma = None, None
        if self.opt['val'].get('metrics') is not None and self.opt['val']['metrics'].get('fid') is not None:
            self._prepare_inception_model_fid()
  87. def setup_optimizers(self):
  88. train_opt = self.opt['train']
  89. # optim_params_g = []
  90. # for k, v in self.net_g.named_parameters():
  91. # if v.requires_grad:
  92. # optim_params_g.append(v)
  93. # else:
  94. # logger = get_root_logger()
  95. # logger.warning(f'Params {k} will not be optimized.')
  96. optim_params_g = self.net_g.parameters()
  97. # optimizer g
  98. optim_type = train_opt['optim_g'].pop('type')
  99. self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
  100. self.optimizers.append(self.optimizer_g)
  101. # optimizer d
  102. optim_type = train_opt['optim_d'].pop('type')
  103. self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
  104. self.optimizers.append(self.optimizer_d)
    def feed_data(self, data):
        """Receive one batch from the dataloader and move it to the device.

        Args:
            data (dict): Expects 'lq' (the grayscale L channel, since it is
                rendered via tensor_lab2rgb with zero ab below) and optionally
                'gt' (the target ab chroma channels, concatenated with L to
                form a full Lab image).
        """
        self.lq = data['lq'].to(self.device)
        # Render the grayscale input as RGB by pairing L with zeroed ab channels.
        self.lq_rgb = tensor_lab2rgb(torch.cat([self.lq, torch.zeros_like(self.lq), torch.zeros_like(self.lq)], dim=1))

        if 'gt' in data:
            self.gt = data['gt'].to(self.device)
            # Full Lab ground truth = input L channel + target ab channels.
            self.gt_lab = torch.cat([self.lq, self.gt], dim=1)
            self.gt_rgb = tensor_lab2rgb(self.gt_lab)
            if self.opt['train'].get('color_enhance', False):
                # Optionally enhance ground-truth color sample by sample.
                # NOTE(review): 'color_enhacne_blend' is the (misspelled) name
                # exported by basicsr.utils.color_enhance — do not "fix" it here.
                # 'color_enhance_factor' is read without a default and may be
                # None if not set in the config — confirm against the helper.
                for i in range(self.gt_rgb.shape[0]):
                    self.gt_rgb[i] = color_enhacne_blend(self.gt_rgb[i], factor=self.opt['train'].get('color_enhance_factor'))
  115. def optimize_parameters(self, current_iter):
  116. # optimize net_g
  117. for p in self.net_d.parameters():
  118. p.requires_grad = False
  119. self.optimizer_g.zero_grad()
  120. self.output_ab = self.net_g(self.lq_rgb)
  121. self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)
  122. self.output_rgb = tensor_lab2rgb(self.output_lab)
  123. l_g_total = 0
  124. loss_dict = OrderedDict()
  125. # pixel loss
  126. if self.cri_pix:
  127. l_g_pix = self.cri_pix(self.output_ab, self.gt)
  128. l_g_total += l_g_pix
  129. loss_dict['l_g_pix'] = l_g_pix
  130. # perceptual loss
  131. if self.cri_perceptual:
  132. l_g_percep, l_g_style = self.cri_perceptual(self.output_rgb, self.gt_rgb)
  133. if l_g_percep is not None:
  134. l_g_total += l_g_percep
  135. loss_dict['l_g_percep'] = l_g_percep
  136. if l_g_style is not None:
  137. l_g_total += l_g_style
  138. loss_dict['l_g_style'] = l_g_style
  139. # gan loss
  140. if self.cri_gan:
  141. fake_g_pred = self.net_d(self.output_rgb)
  142. l_g_gan = self.cri_gan(fake_g_pred, target_is_real=True, is_disc=False)
  143. l_g_total += l_g_gan
  144. loss_dict['l_g_gan'] = l_g_gan
  145. # colorfulness loss
  146. if self.cri_colorfulness:
  147. l_g_color = self.cri_colorfulness(self.output_rgb)
  148. l_g_total += l_g_color
  149. loss_dict['l_g_color'] = l_g_color
  150. l_g_total.backward()
  151. self.optimizer_g.step()
  152. # optimize net_d
  153. for p in self.net_d.parameters():
  154. p.requires_grad = True
  155. self.optimizer_d.zero_grad()
  156. real_d_pred = self.net_d(self.gt_rgb)
  157. fake_d_pred = self.net_d(self.output_rgb.detach())
  158. l_d = self.cri_gan(real_d_pred, target_is_real=True, is_disc=True) + self.cri_gan(fake_d_pred, target_is_real=False, is_disc=True)
  159. loss_dict['l_d'] = l_d
  160. loss_dict['real_score'] = real_d_pred.detach().mean()
  161. loss_dict['fake_score'] = fake_d_pred.detach().mean()
  162. l_d.backward()
  163. self.optimizer_d.step()
  164. self.log_dict = self.reduce_loss_dict(loss_dict)
  165. if self.ema_decay > 0:
  166. self.model_ema(decay=self.ema_decay)
    def get_current_visuals(self):
        """Collect current input/output/GT images (moved to CPU) for logging.

        Returns:
            OrderedDict: 'lq', 'result', optionally 'gt', and — when
            logger.save_snapshot_verbose is set — chroma-only renderings
            ('result_chroma', 'gt_chroma') with the L channel fixed at 50.
        """
        out_dict = OrderedDict()
        out_dict['lq'] = self.lq_rgb.detach().cpu()
        out_dict['result'] = self.output_rgb.detach().cpu()
        if self.opt['logger'].get('save_snapshot_verbose', False):  # only for verbose
            # Visualize predicted chroma alone: constant mid-gray L (=50) + predicted ab.
            self.output_lab_chroma = torch.cat([torch.ones_like(self.lq) * 50, self.output_ab], dim=1)
            self.output_rgb_chroma = tensor_lab2rgb(self.output_lab_chroma)
            out_dict['result_chroma'] = self.output_rgb_chroma.detach().cpu()
        if hasattr(self, 'gt'):
            out_dict['gt'] = self.gt_rgb.detach().cpu()
            if self.opt['logger'].get('save_snapshot_verbose', False):  # only for verbose
                # Same chroma-only visualization for the ground truth.
                self.gt_lab_chroma = torch.cat([torch.ones_like(self.lq) * 50, self.gt], dim=1)
                self.gt_rgb_chroma = tensor_lab2rgb(self.gt_lab_chroma)
                out_dict['gt_chroma'] = self.gt_rgb_chroma.detach().cpu()
        return out_dict
  182. def test(self):
  183. if hasattr(self, 'net_g_ema'):
  184. self.net_g_ema.eval()
  185. with torch.no_grad():
  186. self.output_ab = self.net_g_ema(self.lq_rgb)
  187. self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)
  188. self.output_rgb = tensor_lab2rgb(self.output_lab)
  189. else:
  190. self.net_g.eval()
  191. with torch.no_grad():
  192. self.output_ab = self.net_g(self.lq_rgb)
  193. self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)
  194. self.output_rgb = tensor_lab2rgb(self.output_lab)
  195. self.net_g.train()
  196. def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
  197. if self.opt['rank'] == 0:
  198. self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
  199. def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
  200. dataset_name = dataloader.dataset.opt['name']
  201. with_metrics = self.opt['val'].get('metrics') is not None
  202. use_pbar = self.opt['val'].get('pbar', False)
  203. if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run
  204. self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
  205. # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
  206. if with_metrics:
  207. self._initialize_best_metric_results(dataset_name)
  208. # zero self.metric_results
  209. if with_metrics:
  210. self.metric_results = {metric: 0 for metric in self.metric_results}
  211. metric_data = dict()
  212. if use_pbar:
  213. pbar = tqdm(total=len(dataloader), unit='image')
  214. if self.opt['val']['metrics'].get('fid') is not None:
  215. fake_acts_set, acts_set = [], []
  216. for idx, val_data in enumerate(dataloader):
  217. # if idx == 100:
  218. # break
  219. img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
  220. if hasattr(self, 'gt'):
  221. del self.gt
  222. self.feed_data(val_data)
  223. self.test()
  224. visuals = self.get_current_visuals()
  225. sr_img = tensor2img([visuals['result']])
  226. metric_data['img'] = sr_img
  227. if 'gt' in visuals:
  228. gt_img = tensor2img([visuals['gt']])
  229. metric_data['img2'] = gt_img
  230. torch.cuda.empty_cache()
  231. if save_img:
  232. if self.opt['is_train']:
  233. save_dir = osp.join(self.opt['path']['visualization'], img_name)
  234. for key in visuals:
  235. save_path = os.path.join(save_dir, '{}_{}.png'.format(current_iter, key))
  236. img = tensor2img(visuals[key])
  237. imwrite(img, save_path)
  238. else:
  239. if self.opt['val']['suffix']:
  240. save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
  241. f'{img_name}_{self.opt["val"]["suffix"]}.png')
  242. else:
  243. save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
  244. f'{img_name}_{self.opt["name"]}.png')
  245. imwrite(sr_img, save_img_path)
  246. if with_metrics:
  247. # calculate metrics
  248. for name, opt_ in self.opt['val']['metrics'].items():
  249. if name == 'fid':
  250. pred, gt = visuals['result'].cuda(), visuals['gt'].cuda()
  251. fake_act = get_activations(pred, self.inception_model_fid, 1)
  252. fake_acts_set.append(fake_act)
  253. if self.real_mu is None:
  254. real_act = get_activations(gt, self.inception_model_fid, 1)
  255. acts_set.append(real_act)
  256. else:
  257. self.metric_results[name] += calculate_metric(metric_data, opt_)
  258. if use_pbar:
  259. pbar.update(1)
  260. pbar.set_description(f'Test {img_name}')
  261. if use_pbar:
  262. pbar.close()
  263. if with_metrics:
  264. if self.opt['val']['metrics'].get('fid') is not None:
  265. if self.real_mu is None:
  266. acts_set = np.concatenate(acts_set, 0)
  267. self.real_mu, self.real_sigma = calculate_activation_statistics(acts_set)
  268. fake_acts_set = np.concatenate(fake_acts_set, 0)
  269. fake_mu, fake_sigma = calculate_activation_statistics(fake_acts_set)
  270. fid_score = calculate_frechet_distance(self.real_mu, self.real_sigma, fake_mu, fake_sigma)
  271. self.metric_results['fid'] = fid_score
  272. for metric in self.metric_results.keys():
  273. if metric != 'fid':
  274. self.metric_results[metric] /= (idx + 1)
  275. # update the best metric result
  276. self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
  277. self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
  278. def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
  279. log_str = f'Validation {dataset_name}\n'
  280. for metric, value in self.metric_results.items():
  281. log_str += f'\t # {metric}: {value:.4f}'
  282. if hasattr(self, 'best_metric_results'):
  283. log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
  284. f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
  285. log_str += '\n'
  286. logger = get_root_logger()
  287. logger.info(log_str)
  288. if tb_logger:
  289. for metric, value in self.metric_results.items():
  290. tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
    def _prepare_inception_model_fid(self, path='pretrain/inception_v3_google-1a9a5a14.pth'):
        """Load the InceptionV3 network used to compute the FID metric.

        Args:
            path (str): Checkpoint file with the pretrained InceptionV3
                weights (loaded on CPU, then moved to GPU).
        """
        incep_state_dict = torch.load(path, map_location='cpu')
        # 2048-dim pool3 activations — the standard feature block for FID
        block_idx = INCEPTION_V3_FID.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model_fid = INCEPTION_V3_FID(incep_state_dict, [block_idx])
        self.inception_model_fid.cuda()
        self.inception_model_fid.eval()
  297. @master_only
  298. def save_training_images(self, current_iter):
  299. visuals = self.get_current_visuals()
  300. save_dir = osp.join(self.opt['root_path'], 'experiments', self.opt['name'], 'training_images_snapshot')
  301. os.makedirs(save_dir, exist_ok=True)
  302. for key in visuals:
  303. save_path = os.path.join(save_dir, '{}_{}.png'.format(current_iter, key))
  304. img = tensor2img(visuals[key])
  305. imwrite(img, save_path)
  306. def save(self, epoch, current_iter):
  307. if hasattr(self, 'net_g_ema'):
  308. self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
  309. else:
  310. self.save_network(self.net_g, 'net_g', current_iter)
  311. self.save_network(self.net_d, 'net_d', current_iter)
  312. self.save_training_state(epoch, current_iter)

第四步:运行

第五步:整个工程的内容

代码的下载路径(新窗口打开链接)基于深度学习神经网络的AI图片上色DDColor系统源码

有问题可以私信或者留言,有问必答

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/519569
推荐阅读
  

闽ICP备14008679号