
Super-Resolution (3): Image Super-Resolution Reconstruction Based on RCAN


Table of Contents

1. Project Introduction

2. Project Workflow in Detail

2.1 Data Processing Module

2.2 Loss Function Setup

2.3 Network Model Construction

3. Testing the Network


1. Project Introduction

RCAN: Residual Channel Attention Network

The depth of convolutional neural networks (CNNs) is a crucial factor for image super-resolution (SR). However, the authors observe that deeper image SR networks are more difficult to train. The low-resolution inputs and features contain abundant low-frequency information, which is treated equally across channels and therefore hinders the representational ability of CNNs. To solve these problems, they propose the very deep residual channel attention network (RCAN). Specifically, a residual in residual (RIR) structure is proposed to form very deep networks, consisting of several residual groups with long skip connections; each residual group contains a number of residual blocks with short skip connections. Meanwhile, RIR allows abundant low-frequency information to be bypassed through the multiple skip connections, so the main network can focus on learning high-frequency information. On top of this, a channel attention mechanism is proposed to adaptively rescale channel-wise features by modeling interdependencies among channels. Extensive experiments show that RCAN achieves better accuracy and visual quality than previous state-of-the-art methods.

Background:

  • The depth of convolutional neural networks (CNNs) is a crucial factor for image super-resolution (SR). However, the authors observe that deeper image SR networks are more difficult to train.
  • The low-resolution (LR) inputs and features contain abundant low-frequency information, which is treated equally across channels and therefore hinders the representational ability of CNNs.

Solutions:

  • For the first problem (deeper networks being harder to train), the authors note that residual blocks have made it possible to train networks of more than 1,000 layers, but simply stacking residual blocks to build a deeper network hardly brings further gains. They therefore propose the residual in residual (RIR) structure to construct very deep trainable networks; the long and short skip connections in RIR help bypass abundant low-frequency information, so the main network learns more informative features.
  • For the second problem (low- and high-frequency information in the LR input being treated equally across channels), the authors observe that attention can bias the allocation of available processing resources towards the most informative parts of the input, and therefore introduce a channel attention (CA) mechanism.

Network architecture:

RCAN consists of four main parts: shallow feature extraction, residual in residual (RIR) deep feature extraction, an upscaling module, and a reconstruction part. (A minimal sketch of the CA and RCAB blocks follows the list below.)

  • RIR: G residual groups (RG), wrapped in a long skip connection
  • Each RG: B residual channel attention blocks (RCAB), wrapped in a short skip connection
  • Each RCAB: Conv + ReLU + Conv + CA
  • CA: Global pooling + Conv + ReLU + Conv + Sigmoid
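
As a reading aid, here is a minimal PyTorch sketch of the CA and RCAB blocks described above. It follows the paper's description rather than the official repository's code; the names CALayer and RCAB and the defaults n_feats=64, reduction=16 are chosen here for illustration (they correspond to the --n_feats and --reduction options in option.py later in this post).

import torch.nn as nn

class CALayer(nn.Module):
    """Channel attention (CA): global pooling -> Conv -> ReLU -> Conv -> Sigmoid."""
    def __init__(self, n_feats=64, reduction=16):
        super(CALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)           # squeeze each channel to 1x1
        self.body = nn.Sequential(
            nn.Conv2d(n_feats, n_feats // reduction, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats // reduction, n_feats, 1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        w = self.body(self.avg_pool(x))                   # per-channel weights in (0, 1)
        return x * w                                      # adaptively rescale channel features

class RCAB(nn.Module):
    """Residual channel attention block: Conv -> ReLU -> Conv -> CA, plus a short skip."""
    def __init__(self, n_feats=64, reduction=16):
        super(RCAB, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
            CALayer(n_feats, reduction)
        )

    def forward(self, x):
        return x + self.body(x)                           # short skip connection (SSC)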

Glossary:

  • RCAN: Residual Channel Attention Network
  • RIR: residual in residual
  • RG: residual group
  • RCAB: residual channel attention block
  • CA: channel attention
  • LSC: long skip connection
  • SSC: short skip connection

Paper:

Image Super-Resolution Using Very Deep Residual Channel Attention Networks (arXiv:1807.02758): https://arxiv.org/abs/1807.02758

Reference article:

RCAN paper notes (CSDN blog): https://blog.csdn.net/weixin_46773169/article/details/105600346

Source code:

yulunzhang/RCAN, PyTorch code for the ECCV 2018 paper "Image Super-Resolution Using Very Deep Residual Channel Attention Networks": https://github.com/yulunzhang/RCAN

2. Project Workflow in Detail

2.1 Data Processing Module

data/__init__.py

from importlib import import_module

from dataloader import MSDataLoader
from torch.utils.data.dataloader import default_collate

class Data:
    def __init__(self, args):
        kwargs = {}
        # not training on the CPU
        if not args.cpu:
            kwargs['collate_fn'] = default_collate
            kwargs['pin_memory'] = True
        # training on the CPU
        else:
            kwargs['collate_fn'] = default_collate
            kwargs['pin_memory'] = False

        self.loader_train = None
        if not args.test_only:
            # .lower() converts upper-case letters to lower-case
            module_train = import_module('data.' + args.data_train.lower())
            # getattr() returns the value of an object's attribute
            trainset = getattr(module_train, args.data_train)(args)
            self.loader_train = MSDataLoader(
                args,
                trainset,
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs
            )

        # the benchmark datasets are handled specially
        if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']:
            if not args.benchmark_noise:
                module_test = import_module('data.benchmark')
                testset = getattr(module_test, 'Benchmark')(args, train=False)
            else:
                module_test = import_module('data.benchmark_noise')
                testset = getattr(module_test, 'BenchmarkNoise')(
                    args,
                    train=False
                )
        else:
            module_test = import_module('data.' + args.data_test.lower())
            testset = getattr(module_test, args.data_test)(args, train=False)

        # the custom MSDataLoader mainly needs args and the dataset
        self.loader_test = MSDataLoader(
            args,
            testset,
            batch_size=1,
            shuffle=False,
            **kwargs
        )

'''
class MSDataLoader(DataLoader):
    def __init__(
        self, args, dataset, batch_size=1, shuffle=False,
        sampler=None, batch_sampler=None,
        collate_fn=default_collate, pin_memory=False, drop_last=False,
        timeout=0, worker_init_fn=None):

        super(MSDataLoader, self).__init__(
            dataset, batch_size=batch_size, shuffle=shuffle,
            sampler=sampler, batch_sampler=batch_sampler,
            num_workers=args.n_threads, collate_fn=collate_fn,
            pin_memory=pin_memory, drop_last=drop_last,
            timeout=timeout, worker_init_fn=worker_init_fn)

        self.scale = args.scale

    def __iter__(self):
        return _MSDataLoaderIter(self)
'''

benchmark.py

import os

from data import common
from data import srdata

import numpy as np
import scipy.misc as misc

import torch
import torch.utils.data as data

class Benchmark(srdata.SRData):
    def __init__(self, args, train=True):
        super(Benchmark, self).__init__(args, train, benchmark=True)

    # scan the disk for the data files
    def _scan(self):
        list_hr = []
        list_lr = [[] for _ in self.scale]
        for entry in os.scandir(self.dir_hr):
            # os.path.splitext separates the file name from its extension,
            # e.g. os.path.splitext('abc.txt') returns ('abc', '.txt'),
            # so filename is the bare file name
            filename = os.path.splitext(entry.name)[0]
            # filename + self.ext is the complete file name;
            # os.path.join concatenates path components, so the appended
            # path is self.dir_hr + (filename + self.ext)
            list_hr.append(os.path.join(self.dir_hr, filename + self.ext))
            for si, s in enumerate(self.scale):
                list_lr[si].append(os.path.join(
                    self.dir_lr,
                    'X{}/{}x{}{}'.format(s, filename, s, self.ext)
                ))

        # sort the collected paths in ascending order
        list_hr.sort()
        for l in list_lr:
            l.sort()

        return list_hr, list_lr

    # set the data paths and the file extension
    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, 'benchmark', self.args.data_test)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = '.png'

 common.py

import random

import numpy as np
import skimage.io as sio
import skimage.color as sc
import skimage.transform as st

import torch
from torchvision import transforms

def get_patch(img_in, img_tar, patch_size, scale, multi_scale=False):
    # shape gives height, width and the colour channels,
    # so shape[:2] is the image height and width
    ih, iw = img_in.shape[:2]

    p = scale if multi_scale else 1
    tp = p * patch_size
    ip = tp // scale

    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)
    tx, ty = scale * ix, scale * iy

    img_in = img_in[iy:iy + ip, ix:ix + ip, :]
    img_tar = img_tar[ty:ty + tp, tx:tx + tp, :]

    return img_in, img_tar

# set the number of channels
def set_channel(l, n_channel):
    def _set_channel(img):
        if img.ndim == 2:
            # np.expand_dims(a, axis) adds a dimension to the array a
            # along the given axis
            img = np.expand_dims(img, axis=2)

        c = img.shape[2]
        if n_channel == 1 and c == 3:
            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
        elif n_channel == 3 and c == 1:
            # np.concatenate((a1, a2, ...), axis) joins several arrays at once;
            # here the single channel is repeated n_channel times
            img = np.concatenate([img] * n_channel, 2)

        return img

    return [_set_channel(_l) for _l in l]

# convert np.array data to tensors
def np2Tensor(l, rgb_range):
    def _np2Tensor(img):
        # ascontiguousarray turns a non-contiguously stored array into a
        # contiguously stored one, which is faster to work with;
        # img.transpose((2, 0, 1)) reorders the axes from (H, W, C) to (C, H, W)
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = torch.from_numpy(np_transpose).float()
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(_l) for _l in l]

def add_noise(x, noise='.'):
    if noise != '.':
        noise_type = noise[0]
        noise_value = int(noise[1:])
        if noise_type == 'G':
            noises = np.random.normal(scale=noise_value, size=x.shape)
            noises = noises.round()
        elif noise_type == 'S':
            noises = np.random.poisson(x * noise_value) / noise_value
            noises = noises - noises.mean(axis=0).mean(axis=0)

        x_noise = x.astype(np.int16) + noises.astype(np.int16)
        x_noise = x_noise.clip(0, 255).astype(np.uint8)
        return x_noise
    else:
        return x

def augment(l, hflip=True, rot=True):
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip: img = img[:, ::-1, :]
        if vflip: img = img[::-1, :, :]
        if rot90: img = img.transpose(1, 0, 2)

        return img

    return [_augment(_l) for _l in l]

demo.py 

import os

from data import common

import numpy as np
import scipy.misc as misc

import torch
import torch.utils.data as data

class Demo(data.Dataset):
    def __init__(self, args, train=False):
        self.args = args
        self.name = 'Demo'
        self.scale = args.scale
        self.idx_scale = 0
        self.train = False
        self.benchmark = False

        self.filelist = []
        for f in os.listdir(args.dir_demo):
            if f.find('.png') >= 0 or f.find('.jp') >= 0:
                self.filelist.append(os.path.join(args.dir_demo, f))
        self.filelist.sort()

    def __getitem__(self, idx):
        filename = os.path.split(self.filelist[idx])[-1]
        filename, _ = os.path.splitext(filename)
        lr = misc.imread(self.filelist[idx])
        lr = common.set_channel([lr], self.args.n_colors)[0]

        return common.np2Tensor([lr], self.args.rgb_range)[0], -1, filename

    def __len__(self):
        return len(self.filelist)

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale

srdata.py

import os

from data import common

import numpy as np
import scipy.misc as misc

import torch
import torch.utils.data as data

class SRData(data.Dataset):
    def __init__(self, args, train=True, benchmark=False):
        self.args = args
        self.train = train
        self.split = 'train' if train else 'test'
        self.benchmark = benchmark
        self.scale = args.scale
        self.idx_scale = 0

        self._set_filesystem(args.dir_data)

        def _load_bin():
            self.images_hr = np.load(self._name_hrbin())
            self.images_lr = [
                np.load(self._name_lrbin(s)) for s in self.scale
            ]

        if args.ext == 'img' or benchmark:
            self.images_hr, self.images_lr = self._scan()
        elif args.ext.find('sep') >= 0:
            self.images_hr, self.images_lr = self._scan()
            if args.ext.find('reset') >= 0:
                print('Preparing seperated binary files')
                for v in self.images_hr:
                    hr = misc.imread(v)
                    name_sep = v.replace(self.ext, '.npy')
                    np.save(name_sep, hr)
                for si, s in enumerate(self.scale):
                    for v in self.images_lr[si]:
                        lr = misc.imread(v)
                        name_sep = v.replace(self.ext, '.npy')
                        np.save(name_sep, lr)

            self.images_hr = [
                v.replace(self.ext, '.npy') for v in self.images_hr
            ]
            self.images_lr = [
                [v.replace(self.ext, '.npy') for v in self.images_lr[i]]
                for i in range(len(self.scale))
            ]
        elif args.ext.find('bin') >= 0:
            try:
                if args.ext.find('reset') >= 0:
                    raise IOError
                print('Loading a binary file')
                _load_bin()
            except:
                print('Preparing a binary file')
                bin_path = os.path.join(self.apath, 'bin')
                if not os.path.isdir(bin_path):
                    os.mkdir(bin_path)

                list_hr, list_lr = self._scan()
                hr = [misc.imread(f) for f in list_hr]
                np.save(self._name_hrbin(), hr)
                del hr
                for si, s in enumerate(self.scale):
                    lr_scale = [misc.imread(f) for f in list_lr[si]]
                    np.save(self._name_lrbin(s), lr_scale)
                    del lr_scale
                _load_bin()
        else:
            print('Please define data type')

    def _scan(self):
        raise NotImplementedError

    def _set_filesystem(self, dir_data):
        raise NotImplementedError

    def _name_hrbin(self):
        raise NotImplementedError

    def _name_lrbin(self, scale):
        raise NotImplementedError

    def __getitem__(self, idx):
        lr, hr, filename = self._load_file(idx)
        lr, hr = self._get_patch(lr, hr)
        lr, hr = common.set_channel([lr, hr], self.args.n_colors)
        lr_tensor, hr_tensor = common.np2Tensor([lr, hr], self.args.rgb_range)
        return lr_tensor, hr_tensor, filename

    def __len__(self):
        return len(self.images_hr)

    def _get_index(self, idx):
        return idx

    def _load_file(self, idx):
        idx = self._get_index(idx)
        lr = self.images_lr[self.idx_scale][idx]
        hr = self.images_hr[idx]
        if self.args.ext == 'img' or self.benchmark:
            filename = hr
            lr = misc.imread(lr)
            hr = misc.imread(hr)
        elif self.args.ext.find('sep') >= 0:
            filename = hr
            lr = np.load(lr)
            hr = np.load(hr)
        else:
            filename = str(idx + 1)

        filename = os.path.splitext(os.path.split(filename)[-1])[0]

        return lr, hr, filename

    def _get_patch(self, lr, hr):
        patch_size = self.args.patch_size
        scale = self.scale[self.idx_scale]
        multi_scale = len(self.scale) > 1
        if self.train:
            lr, hr = common.get_patch(
                lr, hr, patch_size, scale, multi_scale=multi_scale
            )
            lr, hr = common.augment([lr, hr])
            lr = common.add_noise(lr, self.args.noise)
        else:
            ih, iw = lr.shape[0:2]
            hr = hr[0:ih * scale, 0:iw * scale]

        return lr, hr

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale

div2k.py 

import os

from data import common
from data import srdata

import numpy as np
import scipy.misc as misc

import torch
import torch.utils.data as data

class DIV2K(srdata.SRData):
    def __init__(self, args, train=True):
        super(DIV2K, self).__init__(args, train)
        self.repeat = args.test_every // (args.n_train // args.batch_size)

    def _scan(self):
        list_hr = []
        list_lr = [[] for _ in self.scale]
        if self.train:
            idx_begin = 0
            idx_end = self.args.n_train
        else:
            idx_begin = self.args.n_train
            idx_end = self.args.offset_val + self.args.n_val

        for i in range(idx_begin + 1, idx_end + 1):
            filename = '{:0>4}'.format(i)
            list_hr.append(os.path.join(self.dir_hr, filename + self.ext))
            for si, s in enumerate(self.scale):
                list_lr[si].append(os.path.join(
                    self.dir_lr,
                    'X{}/{}x{}{}'.format(s, filename, s, self.ext)
                ))

        return list_hr, list_lr

    def _set_filesystem(self, dir_data):
        self.apath = dir_data + '/DIV2K'
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
        self.ext = '.png'

    def _name_hrbin(self):
        return os.path.join(
            self.apath,
            'bin',
            '{}_bin_HR.npy'.format(self.split)
        )

    def _name_lrbin(self, scale):
        return os.path.join(
            self.apath,
            'bin',
            '{}_bin_LR_X{}.npy'.format(self.split, scale)
        )

    def __len__(self):
        if self.train:
            return len(self.images_hr) * self.repeat
        else:
            return len(self.images_hr)

    def _get_index(self, idx):
        if self.train:
            return idx % len(self.images_hr)
        else:
            return idx

2.2 Loss Function Setup

loss/__init__.py

import os
from importlib import import_module

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class Loss(nn.modules.loss._Loss):
    def __init__(self, args, ckp):
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.n_GPUs = args.n_GPUs
        self.loss = []
        # nn.ModuleList behaves like a normal Python list (append, extend, ...),
        # but any nn.Module subclass added to it (nn.Conv2d, nn.Linear, ...) is
        # automatically registered on the network, so its parameters are added
        # to the network's parameters as well.
        self.loss_module = nn.ModuleList()
        # the loss specification is split on '+'; each term has the form weight*type
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type.find('VGG') >= 0:
                module = import_module('loss.vgg')
                loss_function = getattr(module, 'VGG')(
                    loss_type[3:],
                    rgb_range=args.rgb_range
                )
            elif loss_type.find('GAN') >= 0:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(
                    args,
                    loss_type
                )

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function}
            )
            if loss_type.find('GAN') >= 0:
                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for l in self.loss:
            if l['function'] is not None:
                print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])

        self.log = torch.Tensor()

        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half': self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(
                self.loss_module, range(args.n_GPUs)
            )

        if args.load != '.': self.load(ckp.dir, cpu=args.cpu)

    def forward(self, sr, hr):
        losses = []
        for i, l in enumerate(self.loss):
            if l['function'] is not None:
                loss = l['function'](sr, hr)
                effective_loss = l['weight'] * loss
                losses.append(effective_loss)
                self.log[-1, i] += effective_loss.item()
            elif l['type'] == 'DIS':
                self.log[-1, i] += self.loss[i - 1]['function'].loss

        loss_sum = sum(losses)
        if len(self.loss) > 1:
            self.log[-1, -1] += loss_sum.item()

        return loss_sum

    def step(self):
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                l.scheduler.step()

    def start_log(self):
        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))

    def end_log(self, n_batches):
        self.log[-1].div_(n_batches)

    def display_loss(self, batch):
        n_samples = batch + 1
        log = []
        for l, c in zip(self.loss, self.log[-1]):
            log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))

        return ''.join(log)

    def plot_loss(self, apath, epoch):
        axis = np.linspace(1, epoch, epoch)
        for i, l in enumerate(self.loss):
            label = '{} Loss'.format(l['type'])
            fig = plt.figure()
            plt.title(label)
            plt.plot(axis, self.log[:, i].numpy(), label=label)
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig('{}/loss_{}.pdf'.format(apath, l['type']))
            plt.close(fig)

    def get_loss_module(self):
        if self.n_GPUs == 1:
            return self.loss_module
        else:
            return self.loss_module.module

    def save(self, apath):
        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))

    def load(self, apath, cpu=False):
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        self.load_state_dict(torch.load(
            os.path.join(apath, 'loss.pt'),
            **kwargs
        ))
        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                for _ in range(len(self.log)): l.scheduler.step()
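
As a quick note on the configuration format parsed above: the --loss option is a '+'-separated list of weight*type terms (the default in option.py is '1*L1'). The short sketch below only illustrates the string parsing; the combination '1*L1+0.1*VGG54' is an arbitrary example, not a recommended setting.

# Illustrative only: mirrors the split('+') / split('*') logic in Loss.__init__
for term in '1*L1+0.1*VGG54'.split('+'):
    weight, loss_type = term.split('*')
    # 'L1'    -> nn.L1Loss()
    # 'VGG54' -> the VGG perceptual loss in vgg.py, with conv_index '54'
    print(float(weight), loss_type)   # prints 1.0 L1, then 0.1 VGG54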

 adversarial.py

import utility
from model import common
from loss import discriminator

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

class Adversarial(nn.Module):
    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.discriminator = discriminator.Discriminator(args, gan_type)
        if gan_type != 'WGAN_GP':
            self.optimizer = utility.make_optimizer(args, self.discriminator)
        else:
            self.optimizer = optim.Adam(
                self.discriminator.parameters(),
                betas=(0, 0.9), eps=1e-8, lr=1e-5
            )
        self.scheduler = utility.make_scheduler(args, self.optimizer)

    def forward(self, fake, real):
        fake_detach = fake.detach()

        self.loss = 0
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            d_fake = self.discriminator(fake_detach)
            d_real = self.discriminator(real)
            if self.gan_type == 'GAN':
                label_fake = torch.zeros_like(d_fake)
                label_real = torch.ones_like(d_real)
                loss_d \
                    = F.binary_cross_entropy_with_logits(d_fake, label_fake) \
                    + F.binary_cross_entropy_with_logits(d_real, label_real)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.discriminator(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward()
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                for p in self.discriminator.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        d_fake_for_g = self.discriminator(fake)
        if self.gan_type == 'GAN':
            loss_g = F.binary_cross_entropy_with_logits(
                d_fake_for_g, label_real
            )
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_for_g.mean()

        # Generator loss
        return loss_g

    def state_dict(self, *args, **kwargs):
        state_discriminator = self.discriminator.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)

# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py

discriminator.py 

from model import common

import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, args, gan_type='GAN'):
        super(Discriminator, self).__init__()

        in_channels = 3
        out_channels = 64
        depth = 7
        #bn = not gan_type == 'WGAN_GP'
        bn = True
        act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        m_features = [
            common.BasicBlock(args.n_colors, out_channels, 3, bn=bn, act=act)
        ]
        for i in range(depth):
            in_channels = out_channels
            if i % 2 == 1:
                stride = 1
                out_channels *= 2
            else:
                stride = 2
            m_features.append(common.BasicBlock(
                in_channels, out_channels, 3, stride=stride, bn=bn, act=act
            ))
        self.features = nn.Sequential(*m_features)

        patch_size = args.patch_size // (2**((depth + 1) // 2))
        m_classifier = [
            nn.Linear(out_channels * patch_size**2, 1024),
            act,
            nn.Linear(1024, 1)
        ]
        self.classifier = nn.Sequential(*m_classifier)

    def forward(self, x):
        features = self.features(x)
        output = self.classifier(features.view(features.size(0), -1))

        return output

vgg.py 

from model import common

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable

class VGG(nn.Module):
    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()
        # pretrained=True loads the pretrained ImageNet weights
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        if conv_index == '22':
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index == '54':
            self.vgg = nn.Sequential(*modules[:35])

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        self.vgg.requires_grad = False

    def forward(self, sr, hr):
        def _forward(x):
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x

        vgg_sr = _forward(sr)
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())

        loss = F.mse_loss(vgg_sr, vgg_hr)

        return loss

2.3 Network Model Construction

dataloader.py

import sys
import threading
import queue
import random
import collections

import torch
import torch.multiprocessing as multiprocessing

from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
    _remove_worker_pids, _error_if_any_worker_fails
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataloader import _DataLoaderIter
from torch.utils.data.dataloader import ExceptionWrapper
from torch.utils.data.dataloader import _use_shared_memory
from torch.utils.data.dataloader import _worker_manager_loop
from torch.utils.data.dataloader import numpy_type_map
from torch.utils.data.dataloader import default_collate
from torch.utils.data.dataloader import pin_memory_batch
from torch.utils.data.dataloader import _SIGCHLD_handler_set
from torch.utils.data.dataloader import _set_SIGCHLD_handler

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue

def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    torch.manual_seed(seed)
    while True:
        r = index_queue.get()
        if r is None:
            break
        idx, batch_indices = r
        try:
            idx_scale = 0
            if len(scale) > 1 and dataset.train:
                idx_scale = random.randrange(0, len(scale))
                dataset.set_scale(idx_scale)

            samples = collate_fn([dataset[i] for i in batch_indices])
            samples.append(idx_scale)
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))

class _MSDataLoaderIter(_DataLoaderIter):
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queues = [
                multiprocessing.Queue() for _ in range(self.num_workers)
            ]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        self.index_queues[i],
                        self.worker_result_queue,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

class MSDataLoader(DataLoader):
    def __init__(
        self, args, dataset, batch_size=1, shuffle=False,
        sampler=None, batch_sampler=None,
        collate_fn=default_collate, pin_memory=False, drop_last=False,
        timeout=0, worker_init_fn=None):

        super(MSDataLoader, self).__init__(
            dataset, batch_size=batch_size, shuffle=shuffle,
            sampler=sampler, batch_sampler=batch_sampler,
            num_workers=args.n_threads, collate_fn=collate_fn,
            pin_memory=pin_memory, drop_last=drop_last,
            timeout=timeout, worker_init_fn=worker_init_fn)

        self.scale = args.scale

    def __iter__(self):
        return _MSDataLoaderIter(self)

 main.py

import torch

import utility
import data
import model
import loss
from option import args
from trainer import Trainer

torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args)

if checkpoint.ok:
    loader = data.Data(args)
    model = model.Model(args, checkpoint)
    loss = loss.Loss(args, checkpoint) if not args.test_only else None
    t = Trainer(args, loader, model, loss, checkpoint)
    while not t.terminate():
        t.train()
        t.test()

    checkpoint.done()

 option.py

import argparse
import template

parser = argparse.ArgumentParser(description='EDSR and MDSR')

parser.add_argument('--debug', action='store_true',
                    help='Enables debug mode')
parser.add_argument('--template', default='.',
                    help='You can set various templates in option.py')

# Hardware specifications
parser.add_argument('--n_threads', type=int, default=3,
                    help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true',
                    help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=1,
                    help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')

# Data specifications
parser.add_argument('--dir_data', type=str, default='/home/yulun/data/SR/traindata/DIV2K/bicubic',
                    help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../test',
                    help='demo image directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
                    help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
                    help='test dataset name')
parser.add_argument('--benchmark_noise', action='store_true',
                    help='use noisy benchmark sets')
parser.add_argument('--n_train', type=int, default=800,
                    help='number of training set')
parser.add_argument('--n_val', type=int, default=5,
                    help='number of validation set')
parser.add_argument('--offset_val', type=int, default=800,
                    help='validation index offest')
parser.add_argument('--ext', type=str, default='sep_reset',
                    help='dataset file extension')
parser.add_argument('--scale', default='4',
                    help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=192,
                    help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
                    help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
                    help='number of color channels to use')
parser.add_argument('--noise', type=str, default='.',
                    help='Gaussian noise std.')
parser.add_argument('--chop', action='store_true',
                    help='enable memory-efficient forward')

# Model specifications
parser.add_argument('--model', default='RCAN',
                    help='model name')
parser.add_argument('--act', type=str, default='relu',
                    help='activation function')
parser.add_argument('--pre_train', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--n_resblocks', type=int, default=20,
                    help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
                    help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
                    help='residual scaling')
parser.add_argument('--shift_mean', default=True,
                    help='subtract pixel mean from the input')
parser.add_argument('--precision', type=str, default='single',
                    choices=('single', 'half'),
                    help='FP precision for test (single | half)')

# Training specifications
parser.add_argument('--reset', action='store_true',
                    help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
                    help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=1000,
                    help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=16,
                    help='input batch size for training')
parser.add_argument('--split_batch', type=int, default=1,
                    help='split the batch into smaller chunks')
parser.add_argument('--self_ensemble', action='store_true',
                    help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
                    help='set this option to test the model')
parser.add_argument('--gan_k', type=int, default=1,
                    help='k value for adversarial loss')

# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
                    help='learning rate')
parser.add_argument('--lr_decay', type=int, default=200,
                    help='learning rate decay per N epochs')
parser.add_argument('--decay_type', type=str, default='step',
                    help='learning rate decay type')
parser.add_argument('--gamma', type=float, default=0.5,
                    help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
                    choices=('SGD', 'ADAM', 'RMSprop'),
                    help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
parser.add_argument('--beta1', type=float, default=0.9,
                    help='ADAM beta1')
parser.add_argument('--beta2', type=float, default=0.999,
                    help='ADAM beta2')
parser.add_argument('--epsilon', type=float, default=1e-8,
                    help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')

# Loss specifications
parser.add_argument('--loss', type=str, default='1*L1',
                    help='loss function configuration')
parser.add_argument('--skip_threshold', type=float, default='1e6',
                    help='skipping batch that has large error')

# Log specifications
parser.add_argument('--save', type=str, default='test',
                    help='file name to save')
parser.add_argument('--load', type=str, default='.',
                    help='file name to load')
parser.add_argument('--resume', type=int, default=0,
                    help='resume from specific checkpoint')
parser.add_argument('--print_model', action='store_true',
                    help='print model')
parser.add_argument('--save_models', action='store_true',
                    help='save all intermediate models')
parser.add_argument('--print_every', type=int, default=100,
                    help='how many batches to wait before logging training status')
parser.add_argument('--save_results', action='store_true',
                    help='save output results')

# options for residual group and feature channel reduction
parser.add_argument('--n_resgroups', type=int, default=10,
                    help='number of residual groups')
parser.add_argument('--reduction', type=int, default=16,
                    help='number of feature maps reduction')

# options for test
parser.add_argument('--testpath', type=str, default='../test/DIV2K_val_LR_our',
                    help='dataset directory for testing')
parser.add_argument('--testset', type=str, default='Set5',
                    help='dataset name for testing')

args = parser.parse_args()
template.set_template(args)

args.scale = list(map(lambda x: int(x), args.scale.split('+')))

if args.epochs == 0:
    args.epochs = 1e8

for arg in vars(args):
    if vars(args)[arg] == 'True':
        vars(args)[arg] = True
    elif vars(args)[arg] == 'False':
        vars(args)[arg] = False

template.py 

def set_template(args):
    # Set the templates here
    if args.template.find('jpeg') >= 0:
        args.data_train = 'DIV2K_jpeg'
        args.data_test = 'DIV2K_jpeg'
        args.epochs = 200
        args.lr_decay = 100

    if args.template.find('EDSR_paper') >= 0:
        args.model = 'EDSR'
        args.n_resblocks = 32
        args.n_feats = 256
        args.res_scale = 0.1

    if args.template.find('MDSR') >= 0:
        args.model = 'MDSR'
        args.patch_size = 48
        args.epochs = 650

    if args.template.find('DDBPN') >= 0:
        args.model = 'DDBPN'
        args.patch_size = 128
        args.scale = '4'

        args.data_test = 'Set5'

        args.batch_size = 20
        args.epochs = 1000
        args.lr_decay = 500
        args.gamma = 0.1
        args.weight_decay = 1e-4

        args.loss = '1*MSE'

    if args.template.find('GAN') >= 0:
        args.epochs = 200
        args.lr = 5e-5
        args.lr_decay = 150

 utility.py

import os
import math
import time
import datetime
from functools import reduce

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import numpy as np
import scipy.misc as misc

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs

class timer():
    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        self.t0 = time.time()

    def toc(self):
        return time.time() - self.t0

    def hold(self):
        self.acc += self.toc()

    def release(self):
        ret = self.acc
        self.acc = 0

        return ret

    def reset(self):
        self.acc = 0

class checkpoint():
    def __init__(self, args):
        self.args = args
        self.ok = True
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        if args.load == '.':
            if args.save == '.': args.save = now
            self.dir = '../experiment/' + args.save
        else:
            self.dir = '../experiment/' + args.load
            if not os.path.exists(self.dir):
                args.load = '.'
            else:
                self.log = torch.load(self.dir + '/psnr_log.pt')
                print('Continue from epoch {}...'.format(len(self.log)))

        if args.reset:
            os.system('rm -rf ' + self.dir)
            args.load = '.'

        def _make_dir(path):
            if not os.path.exists(path): os.makedirs(path)

        _make_dir(self.dir)
        _make_dir(self.dir + '/model')
        _make_dir(self.dir + '/results')

        open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
        self.log_file = open(self.dir + '/log.txt', open_type)
        with open(self.dir + '/config.txt', open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

    def save(self, trainer, epoch, is_best=False):
        trainer.model.save(self.dir, epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
        torch.save(
            trainer.optimizer.state_dict(),
            os.path.join(self.dir, 'optimizer.pt')
        )

    def add_log(self, log):
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.dir + '/log.txt', 'a')

    def done(self):
        self.log_file.close()

    def plot_psnr(self, epoch):
        axis = np.linspace(1, epoch, epoch)
        label = 'SR on {}'.format(self.args.data_test)
        fig = plt.figure()
        plt.title(label)
        for idx_scale, scale in enumerate(self.args.scale):
            plt.plot(
                axis,
                self.log[:, idx_scale].numpy(),
                label='Scale {}'.format(scale)
            )
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('PSNR')
        plt.grid(True)
        plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
        plt.close(fig)

    def save_results(self, filename, save_list, scale):
        filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)
        postfix = ('SR', 'LR', 'HR')
        for v, p in zip(save_list, postfix):
            normalized = v[0].data.mul(255 / self.args.rgb_range)
            ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
            misc.imsave('{}{}.png'.format(filename, p), ndarr)

def quantize(img, rgb_range):
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)

def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    diff = (sr - hr).data.div(rgb_range)
    shave = scale
    if diff.size(1) > 1:
        convert = diff.new(1, 3, 1, 1)
        convert[0, 0, 0, 0] = 65.738
        convert[0, 1, 0, 0] = 129.057
        convert[0, 2, 0, 0] = 25.064
        diff.mul_(convert).div_(256)
        diff = diff.sum(dim=1, keepdim=True)
    '''
    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            convert = diff.new(1, 3, 1, 1)
            convert[0, 0, 0, 0] = 65.738
            convert[0, 1, 0, 0] = 129.057
            convert[0, 2, 0, 0] = 25.064
            diff.mul_(convert).div_(256)
            diff = diff.sum(dim=1, keepdim=True)
    else:
        shave = scale + 6
    '''
    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean()

    return -10 * math.log10(mse)

def make_optimizer(args, my_model):
    trainable = filter(lambda x: x.requires_grad, my_model.parameters())

    if args.optimizer == 'SGD':
        optimizer_function = optim.SGD
        kwargs = {'momentum': args.momentum}
    elif args.optimizer == 'ADAM':
        optimizer_function = optim.Adam
        kwargs = {
            'betas': (args.beta1, args.beta2),
            'eps': args.epsilon
        }
    elif args.optimizer == 'RMSprop':
        optimizer_function = optim.RMSprop
        kwargs = {'eps': args.epsilon}

    kwargs['lr'] = args.lr
    kwargs['weight_decay'] = args.weight_decay

    return optimizer_function(trainable, **kwargs)

def make_scheduler(args, my_optimizer):
    if args.decay_type == 'step':
        scheduler = lrs.StepLR(
            my_optimizer,
            step_size=args.lr_decay,
            gamma=args.gamma
        )
    elif args.decay_type.find('step') >= 0:
        milestones = args.decay_type.split('_')
        milestones.pop(0)
        milestones = list(map(lambda x: int(x), milestones))
        scheduler = lrs.MultiStepLR(
            my_optimizer,
            milestones=milestones,
            gamma=args.gamma
        )

    return scheduler
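
For reference, calc_psnr above computes the usual

    PSNR = -10 * log10(MSE)

where the SR/HR difference is first divided by rgb_range, a border of scale pixels is cropped on each side, and, for 3-channel images, the difference is projected onto the luminance (Y) channel using the standard RGB-to-Y (YCbCr) weights (65.738, 129.057, 25.064) / 256 before the mean squared error is taken.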

 trainer.py

import os
import math
from decimal import Decimal

import utility

import torch
from torch.autograd import Variable
from tqdm import tqdm

class Trainer():
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load != '.':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            )
            for _ in range(len(ckp.log)): self.scheduler.step()

        self.error_last = 1e8

    def train(self):
        self.scheduler.step()
        self.loss.step()
        epoch = self.scheduler.last_epoch + 1
        lr = self.scheduler.get_lr()[0]

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
        )
        self.loss.start_log()
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()
        for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):
            lr, hr = self.prepare([lr, hr])
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            sr = self.model(lr, idx_scale)
            loss = self.loss(sr, hr)
            if loss.item() < self.args.skip_threshold * self.error_last:
                loss.backward()
                self.optimizer.step()
            else:
                print('Skip this batch {}! (Loss: {})'.format(
                    batch + 1, loss.item()
                ))

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]

    def test(self):
        epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)))
        self.model.eval()

        timer_test = utility.timer()
        with torch.no_grad():
            for idx_scale, scale in enumerate(self.scale):
                eval_acc = 0
                self.loader_test.dataset.set_scale(idx_scale)
                tqdm_test = tqdm(self.loader_test, ncols=80)
                for idx_img, (lr, hr, filename, _) in enumerate(tqdm_test):
                    filename = filename[0]
                    no_eval = (hr.nelement() == 1)
                    if not no_eval:
                        lr, hr = self.prepare([lr, hr])
                    else:
                        lr = self.prepare([lr])[0]

                    sr = self.model(lr, idx_scale)
                    sr = utility.quantize(sr, self.args.rgb_range)

                    save_list = [sr]
                    if not no_eval:
                        eval_acc += utility.calc_psnr(
                            sr, hr, scale, self.args.rgb_range,
                            benchmark=self.loader_test.dataset.benchmark
                        )
                        save_list.extend([lr, hr])

                    if self.args.save_results:
                        self.ckp.save_results(filename, save_list, scale)

                self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test)
                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                        self.args.data_test,
                        scale,
                        self.ckp.log[-1, idx_scale],
                        best[0][idx_scale],
                        best[1][idx_scale] + 1
                    )
                )

        self.ckp.write_log(
            'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )
        if not self.args.test_only:
            self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))

    def prepare(self, l, volatile=False):
        device = torch.device('cpu' if self.args.cpu else 'cuda')
        def _prepare(tensor):
            if self.args.precision == 'half': tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(_l) for _l in l]

    def terminate(self):
        if self.args.test_only:
            self.test()
            return True
        else:
            epoch = self.scheduler.last_epoch + 1
            return epoch >= self.args.epochs
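
Note that this section covers the training scaffolding; the RCAN network definition itself (rcan.py under model/ in the official repository) is not reproduced here. As a rough orientation, the sketch below shows how the pieces from Section 1 fit together, reusing the RCAB/CALayer sketch given there. It is a simplified illustration with names chosen for this post, not the repository's exact code: it uses a single-stage PixelShuffle upscaler and omits the mean-shift layers, and the defaults mirror the option.py values (n_resgroups=10, n_resblocks=20, n_feats=64, reduction=16).

import torch.nn as nn

# RCAB refers to the sketch in Section 1.

class ResidualGroup(nn.Module):
    """One RG: B RCABs followed by a conv, wrapped in a short skip connection."""
    def __init__(self, n_feats=64, n_resblocks=20, reduction=16):
        super(ResidualGroup, self).__init__()
        blocks = [RCAB(n_feats, reduction) for _ in range(n_resblocks)]
        blocks.append(nn.Conv2d(n_feats, n_feats, 3, padding=1))
        self.body = nn.Sequential(*blocks)

    def forward(self, x):
        return x + self.body(x)                                   # short skip connection (SSC)

class RCANSketch(nn.Module):
    """Shallow feature extraction -> RIR (G RGs + LSC) -> upscaling -> reconstruction."""
    def __init__(self, scale=4, n_colors=3, n_feats=64,
                 n_resgroups=10, n_resblocks=20, reduction=16):
        super(RCANSketch, self).__init__()
        self.head = nn.Conv2d(n_colors, n_feats, 3, padding=1)    # shallow feature extraction
        groups = [ResidualGroup(n_feats, n_resblocks, reduction)
                  for _ in range(n_resgroups)]
        groups.append(nn.Conv2d(n_feats, n_feats, 3, padding=1))
        self.body = nn.Sequential(*groups)                        # RIR deep feature extraction
        self.tail = nn.Sequential(                                # upscaling + reconstruction
            nn.Conv2d(n_feats, n_feats * scale * scale, 3, padding=1),
            nn.PixelShuffle(scale),
            nn.Conv2d(n_feats, n_colors, 3, padding=1)
        )

    def forward(self, x):
        x = self.head(x)
        x = x + self.body(x)                                      # long skip connection (LSC)
        return self.tail(x)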

3. Testing the Network

Upscaling images by a factor of 4 with the model gives the following results:

Input image:

[input image]

Output image:

[4x output image]

Input image:

[input image]

Output image:

[4x output image]
