RCAN: Residual Channel Attention Network
The depth of convolutional neural networks (CNNs) is a crucial factor for image super-resolution (SR). However, we observe that deeper image SR networks are harder to train. The low-resolution inputs and features contain abundant low-frequency information, which is treated equally across channels and therefore hinders the representational ability of CNNs. To address these problems, we propose a very deep residual channel attention network (RCAN). Specifically, we propose a residual in residual (RIR) structure to form a very deep network, consisting of several residual groups with long skip connections; each residual group contains a number of residual blocks with short skip connections. Meanwhile, RIR allows abundant low-frequency information to be bypassed through the multiple skip connections, so that the main network can focus on learning high-frequency information. On top of this, we propose a channel attention mechanism that adaptively rescales channel-wise features by considering interdependencies among channels. Extensive experiments show that our RCAN achieves better accuracy and visual quality than previous state-of-the-art methods.
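The channel attention mechanism described above boils down to three steps: squeeze each feature map into a single descriptor with global average pooling, pass the descriptors through a small channel-downscaling/upscaling bottleneck, and use a sigmoid gate to rescale every channel of the input. A minimal PyTorch sketch is given below; it follows the paper's description, but since this post does not reproduce the model file itself, the class and variable names are illustrative rather than the repository's own.

import torch.nn as nn

class CALayer(nn.Module):
    """Channel attention: rescale channels using globally pooled statistics."""
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # squeeze: one descriptor per channel via global average pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # excitation: channel-downscaling and channel-upscaling 1x1 convolutions
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)   # N x C x 1 x 1 channel descriptors
        y = self.conv_du(y)    # per-channel weights in (0, 1)
        return x * y           # rescale the input channel-wise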
Background:
Solution:
Network architecture:
RCAN consists of four main parts: shallow feature extraction, residual-in-residual (RIR) deep feature extraction, an upscaling module, and a reconstruction part.
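A minimal, runnable stand-in that keeps the same four stages is sketched below. The real RCAN replaces the plain convolutions in the body with residual groups of channel-attention blocks and builds the upsampler from the --scale argument; everything here (class name, layer sizes, single-step PixelShuffle) is a simplification for illustration only.

import torch
import torch.nn as nn

class TinyRCAN(nn.Module):
    """A deliberately small model that mirrors RCAN's four stages."""
    def __init__(self, n_colors=3, n_feats=64, scale=4):
        super(TinyRCAN, self).__init__()
        # 1) shallow feature extraction
        self.head = nn.Conv2d(n_colors, n_feats, 3, padding=1)
        # 2) deep feature extraction (RCAN uses residual groups of RCABs here)
        self.body = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats, n_feats, 3, padding=1)
        )
        # 3) upscaling module (sub-pixel convolution) and 4) reconstruction
        self.tail = nn.Sequential(
            nn.Conv2d(n_feats, n_feats * scale * scale, 3, padding=1),
            nn.PixelShuffle(scale),
            nn.Conv2d(n_feats, n_colors, 3, padding=1)
        )

    def forward(self, x):
        x = self.head(x)
        res = self.body(x) + x       # long skip connection around the deep part
        return self.tail(res)

if __name__ == '__main__':
    lr = torch.zeros(1, 3, 48, 48)
    print(TinyRCAN(scale=4)(lr).shape)   # torch.Size([1, 3, 192, 192])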
Terminology:
Paper link:
Reference articles:
Source code link:
data/__init__.py
- from importlib import import_module
-
- from dataloader import MSDataLoader
- from torch.utils.data.dataloader import default_collate
-
- class Data:
- def __init__(self, args):
- kwargs = {}
- # not training on the CPU: pin memory for faster host-to-GPU transfer
- if not args.cpu:
- kwargs['collate_fn'] = default_collate
- kwargs['pin_memory'] = True
- # training on the CPU
- else:
- kwargs['collate_fn'] = default_collate
- kwargs['pin_memory'] = False
-
- self.loader_train = None
- if not args.test_only:
- # .lower() converts the dataset name to lowercase so it matches the module file name
- module_train = import_module('data.' + args.data_train.lower())
- # getattr() returns the named attribute of an object (here, the dataset class)
- trainset = getattr(module_train, args.data_train)(args)
- self.loader_train = MSDataLoader(
- args,
- trainset,
- batch_size=args.batch_size,
- shuffle=True,
- **kwargs
- )
-
- # standard benchmark test sets
- if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']:
- if not args.benchmark_noise:
- module_test = import_module('data.benchmark')
- testset = getattr(module_test, 'Benchmark')(args, train=False)
- else:
- module_test = import_module('data.benchmark_noise')
- testset = getattr(module_test, 'BenchmarkNoise')(
- args,
- train=False
- )
-
- else:
- module_test = import_module('data.' + args.data_test.lower())
- testset = getattr(module_test, args.data_test)(args, train=False)
-
- # the custom MSDataLoader mainly needs args and the dataset
- self.loader_test = MSDataLoader(
- args,
- testset,
- batch_size=1,
- shuffle=False,
- **kwargs
- )
-
- '''
- class MSDataLoader(DataLoader):
- def __init__(
- self, args, dataset, batch_size=1, shuffle=False,
- sampler=None, batch_sampler=None,
- collate_fn=default_collate, pin_memory=False, drop_last=False,
- timeout=0, worker_init_fn=None):
- super(MSDataLoader, self).__init__(
- dataset, batch_size=batch_size, shuffle=shuffle,
- sampler=sampler, batch_sampler=batch_sampler,
- num_workers=args.n_threads, collate_fn=collate_fn,
- pin_memory=pin_memory, drop_last=drop_last,
- timeout=timeout, worker_init_fn=worker_init_fn)
- self.scale = args.scale
- def __iter__(self):
- return _MSDataLoaderIter(self)
- '''
benchmark.py
- import os
-
- from data import common
- from data import srdata
-
- import numpy as np
- import scipy.misc as misc
-
- import torch
- import torch.utils.data as data
-
- class Benchmark(srdata.SRData):
- def __init__(self, args, train=True):
- super(Benchmark, self).__init__(args, train, benchmark=True)
-
- # scan the disk to build the HR/LR file lists
- def _scan(self):
- list_hr = []
- list_lr = [[] for _ in self.scale]
- for entry in os.scandir(self.dir_hr):
- # os.path.splitext separates the file name from its extension
- # e.g. os.path.splitext('abc.txt') returns ('abc', '.txt')
- # filename keeps only the base name
- filename = os.path.splitext(entry.name)[0]
- # filename + self.ext is the complete file name
- # os.path.join concatenates path components (it accepts multiple arguments)
- # the appended path is self.dir_hr joined with (filename + self.ext)
- list_hr.append(os.path.join(self.dir_hr, filename + self.ext))
- for si, s in enumerate(self.scale):
- list_lr[si].append(os.path.join(
- self.dir_lr,
- 'X{}/{}x{}{}'.format(s, filename, s, self.ext)
- ))
-
- # sort the collected paths in ascending order
- list_hr.sort()
- for l in list_lr:
- l.sort()
-
- return list_hr, list_lr
-
- # set the dataset paths and the file extension
- def _set_filesystem(self, dir_data):
- self.apath = os.path.join(dir_data, 'benchmark', self.args.data_test)
- self.dir_hr = os.path.join(self.apath, 'HR')
- self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
- self.ext = '.png'
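From _set_filesystem and _scan above, the benchmark data is expected on disk in the following layout (Set5 at scale 4 shown; 'baby.png' is only an example file name, and the 'x4' suffix of the LR file comes from the 'X{}/{}x{}{}' format string):

<dir_data>/benchmark/Set5/HR/baby.png
<dir_data>/benchmark/Set5/LR_bicubic/X4/babyx4.png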
common.py
- import random
-
- import numpy as np
- import skimage.io as sio
- import skimage.color as sc
- import skimage.transform as st
-
- import torch
- from torchvision import transforms
-
- def get_patch(img_in, img_tar, patch_size, scale, multi_scale=False):
- # shape gives the image height, width and number of color channels,
- # so shape[:2] takes the first two dimensions: height and width
-
- ih, iw = img_in.shape[:2]
-
- p = scale if multi_scale else 1
- tp = p * patch_size
- ip = tp // scale
-
- ix = random.randrange(0, iw - ip + 1)
- iy = random.randrange(0, ih - ip + 1)
- tx, ty = scale * ix, scale * iy
-
- img_in = img_in[iy:iy + ip, ix:ix + ip, :]
- img_tar = img_tar[ty:ty + tp, tx:tx + tp, :]
-
- return img_in, img_tar
-
- # set the number of color channels
- def set_channel(l, n_channel):
- def _set_channel(img):
- if img.ndim == 2:
- # np.expand_dims(a, axis): a is a numpy array, axis is where the new dimension is inserted
- # here it adds a trailing channel dimension to a grayscale image
- img = np.expand_dims(img, axis=2)
-
- c = img.shape[2]
- if n_channel == 1 and c == 3:
- img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
- elif n_channel == 3 and c == 1:
- # numpy.concatenate((a1, a2, ...), axis) joins several arrays at once;
- # here the single channel is repeated n_channel times along the channel axis
- img = np.concatenate([img] * n_channel, 2)
-
- return img
-
- return [_set_channel(_l) for _l in l]
-
- # convert numpy arrays (HWC) to torch tensors (CHW)
- def np2Tensor(l, rgb_range):
- def _np2Tensor(img):
- # ascontiguousarray copies a non-contiguous array into contiguous memory for faster processing
- # img.transpose((2, 0, 1)) reorders the axes from (H, W, C) to (C, H, W)
- np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
- tensor = torch.from_numpy(np_transpose).float()
- tensor.mul_(rgb_range / 255)
-
- return tensor
-
- return [_np2Tensor(_l) for _l in l]
-
- def add_noise(x, noise='.'):
- if noise != '.':
- noise_type = noise[0]
- noise_value = int(noise[1:])
- if noise_type == 'G':
- noises = np.random.normal(scale=noise_value, size=x.shape)
- noises = noises.round()
- elif noise_type == 'S':
- noises = np.random.poisson(x * noise_value) / noise_value
- noises = noises - noises.mean(axis=0).mean(axis=0)
-
- x_noise = x.astype(np.int16) + noises.astype(np.int16)
- x_noise = x_noise.clip(0, 255).astype(np.uint8)
- return x_noise
- else:
- return x
-
- def augment(l, hflip=True, rot=True):
- hflip = hflip and random.random() < 0.5
- vflip = rot and random.random() < 0.5
- rot90 = rot and random.random() < 0.5
-
- def _augment(img):
- if hflip: img = img[:, ::-1, :]
- if vflip: img = img[::-1, :, :]
- if rot90: img = img.transpose(1, 0, 2)
-
- return img
-
- return [_augment(_l) for _l in l]
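To make the patch arithmetic in get_patch concrete: with the defaults used later in option.py (--patch_size 192, --scale 4) and single-scale training, tp = 192 is the HR patch size, ip = 192 // 4 = 48 is the LR patch size, and the HR crop origin (tx, ty) is scale times the LR origin (ix, iy), so the two crops stay aligned. A small usage sketch follows; the blank arrays stand in for real images, and it assumes the script is run from the code directory so that `from data import common` resolves:

import numpy as np
from data import common

lr = np.zeros((128, 128, 3), dtype=np.uint8)   # stand-in LR image
hr = np.zeros((512, 512, 3), dtype=np.uint8)   # matching x4 HR image

lr_patch, hr_patch = common.get_patch(lr, hr, patch_size=192, scale=4)
print(lr_patch.shape, hr_patch.shape)          # (48, 48, 3) (192, 192, 3)

lr_patch, hr_patch = common.augment([lr_patch, hr_patch])
lr_t, hr_t = common.np2Tensor([lr_patch, hr_patch], rgb_range=255)
print(lr_t.shape, hr_t.shape)                  # [3, 48, 48] and [3, 192, 192]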
demo.py
- import os
-
- from data import common
-
- import numpy as np
- import scipy.misc as misc
-
- import torch
- import torch.utils.data as data
-
- class Demo(data.Dataset):
- def __init__(self, args, train=False):
- self.args = args
- self.name = 'Demo'
- self.scale = args.scale
- self.idx_scale = 0
- self.train = False
- self.benchmark = False
-
- self.filelist = []
- for f in os.listdir(args.dir_demo):
- if f.find('.png') >= 0 or f.find('.jp') >= 0:
- self.filelist.append(os.path.join(args.dir_demo, f))
- self.filelist.sort()
-
- def __getitem__(self, idx):
- filename = os.path.split(self.filelist[idx])[-1]
- filename, _ = os.path.splitext(filename)
- lr = misc.imread(self.filelist[idx])
- lr = common.set_channel([lr], self.args.n_colors)[0]
-
- return common.np2Tensor([lr], self.args.rgb_range)[0], -1, filename
-
- def __len__(self):
- return len(self.filelist)
-
- def set_scale(self, idx_scale):
- self.idx_scale = idx_scale
-
srdata.py
- import os
-
- from data import common
-
- import numpy as np
- import scipy.misc as misc
-
- import torch
- import torch.utils.data as data
-
- class SRData(data.Dataset):
- def __init__(self, args, train=True, benchmark=False):
- self.args = args
- self.train = train
- self.split = 'train' if train else 'test'
- self.benchmark = benchmark
- self.scale = args.scale
- self.idx_scale = 0
-
- self._set_filesystem(args.dir_data)
-
- def _load_bin():
- self.images_hr = np.load(self._name_hrbin())
- self.images_lr = [
- np.load(self._name_lrbin(s)) for s in self.scale
- ]
-
- if args.ext == 'img' or benchmark:
- self.images_hr, self.images_lr = self._scan()
- elif args.ext.find('sep') >= 0:
- self.images_hr, self.images_lr = self._scan()
- if args.ext.find('reset') >= 0:
- print('Preparing separated binary files')
- for v in self.images_hr:
- hr = misc.imread(v)
- name_sep = v.replace(self.ext, '.npy')
- np.save(name_sep, hr)
- for si, s in enumerate(self.scale):
- for v in self.images_lr[si]:
- lr = misc.imread(v)
- name_sep = v.replace(self.ext, '.npy')
- np.save(name_sep, lr)
-
- self.images_hr = [
- v.replace(self.ext, '.npy') for v in self.images_hr
- ]
- self.images_lr = [
- [v.replace(self.ext, '.npy') for v in self.images_lr[i]]
- for i in range(len(self.scale))
- ]
-
- elif args.ext.find('bin') >= 0:
- try:
- if args.ext.find('reset') >= 0:
- raise IOError
- print('Loading a binary file')
- _load_bin()
- except:
- print('Preparing a binary file')
- bin_path = os.path.join(self.apath, 'bin')
- if not os.path.isdir(bin_path):
- os.mkdir(bin_path)
-
- list_hr, list_lr = self._scan()
- hr = [misc.imread(f) for f in list_hr]
- np.save(self._name_hrbin(), hr)
- del hr
- for si, s in enumerate(self.scale):
- lr_scale = [misc.imread(f) for f in list_lr[si]]
- np.save(self._name_lrbin(s), lr_scale)
- del lr_scale
- _load_bin()
- else:
- print('Please define data type')
-
- def _scan(self):
- raise NotImplementedError
-
- def _set_filesystem(self, dir_data):
- raise NotImplementedError
-
- def _name_hrbin(self):
- raise NotImplementedError
-
- def _name_lrbin(self, scale):
- raise NotImplementedError
-
- def __getitem__(self, idx):
- lr, hr, filename = self._load_file(idx)
- lr, hr = self._get_patch(lr, hr)
- lr, hr = common.set_channel([lr, hr], self.args.n_colors)
- lr_tensor, hr_tensor = common.np2Tensor([lr, hr], self.args.rgb_range)
- return lr_tensor, hr_tensor, filename
-
- def __len__(self):
- return len(self.images_hr)
-
- def _get_index(self, idx):
- return idx
-
- def _load_file(self, idx):
- idx = self._get_index(idx)
- lr = self.images_lr[self.idx_scale][idx]
- hr = self.images_hr[idx]
- if self.args.ext == 'img' or self.benchmark:
- filename = hr
- lr = misc.imread(lr)
- hr = misc.imread(hr)
- elif self.args.ext.find('sep') >= 0:
- filename = hr
- lr = np.load(lr)
- hr = np.load(hr)
- else:
- filename = str(idx + 1)
-
- filename = os.path.splitext(os.path.split(filename)[-1])[0]
-
- return lr, hr, filename
-
- def _get_patch(self, lr, hr):
- patch_size = self.args.patch_size
- scale = self.scale[self.idx_scale]
- multi_scale = len(self.scale) > 1
- if self.train:
- lr, hr = common.get_patch(
- lr, hr, patch_size, scale, multi_scale=multi_scale
- )
- lr, hr = common.augment([lr, hr])
- lr = common.add_noise(lr, self.args.noise)
- else:
- ih, iw = lr.shape[0:2]
- hr = hr[0:ih * scale, 0:iw * scale]
-
- return lr, hr
-
- def set_scale(self, idx_scale):
- self.idx_scale = idx_scale
-
div2k.py
- import os
-
- from data import common
- from data import srdata
-
- import numpy as np
- import scipy.misc as misc
-
- import torch
- import torch.utils.data as data
-
- class DIV2K(srdata.SRData):
- def __init__(self, args, train=True):
- super(DIV2K, self).__init__(args, train)
- self.repeat = args.test_every // (args.n_train // args.batch_size)
-
- def _scan(self):
- list_hr = []
- list_lr = [[] for _ in self.scale]
- if self.train:
- idx_begin = 0
- idx_end = self.args.n_train
- else:
- idx_begin = self.args.n_train
- idx_end = self.args.offset_val + self.args.n_val
-
- for i in range(idx_begin + 1, idx_end + 1):
- filename = '{:0>4}'.format(i)
- list_hr.append(os.path.join(self.dir_hr, filename + self.ext))
- for si, s in enumerate(self.scale):
- list_lr[si].append(os.path.join(
- self.dir_lr,
- 'X{}/{}x{}{}'.format(s, filename, s, self.ext)
- ))
-
- return list_hr, list_lr
-
- def _set_filesystem(self, dir_data):
- self.apath = dir_data + '/DIV2K'
- self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
- self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
- self.ext = '.png'
-
- def _name_hrbin(self):
- return os.path.join(
- self.apath,
- 'bin',
- '{}_bin_HR.npy'.format(self.split)
- )
-
- def _name_lrbin(self, scale):
- return os.path.join(
- self.apath,
- 'bin',
- '{}_bin_LR_X{}.npy'.format(self.split, scale)
- )
-
- def __len__(self):
- if self.train:
- return len(self.images_hr) * self.repeat
- else:
- return len(self.images_hr)
-
- def _get_index(self, idx):
- if self.train:
- return idx % len(self.images_hr)
- else:
- return idx
-
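A worked example of the repeat factor in DIV2K.__init__, using the defaults from option.py below: with test_every = 1000, n_train = 800 and batch_size = 16, one pass over the training images takes 800 // 16 = 50 batches, so repeat = 1000 // 50 = 20. During training __len__ therefore reports 800 * 20 = 16000 samples, i.e. exactly 1000 batches between two evaluations, and _get_index maps every index back into the range [0, 800) with the modulo.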
loss/__init__.py
- import os
- from importlib import import_module
-
- import matplotlib
- matplotlib.use('Agg')
- import matplotlib.pyplot as plt
-
- import numpy as np
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- class Loss(nn.modules.loss._Loss):
- def __init__(self, args, ckp):
- super(Loss, self).__init__()
- print('Preparing loss function:')
-
- self.n_GPUs = args.n_GPUs
- self.loss = []
- # nn.ModuleList holds arbitrary nn.Module subclasses (e.g. nn.Conv2d, nn.Linear)
- # and supports the usual list operations such as append and extend.
- # Unlike a plain Python list, modules added to an nn.ModuleList are
- # automatically registered on the network, so their parameters are
- # added to the network's parameters as well.
- self.loss_module = nn.ModuleList()
- # split('+') splits the loss configuration string into weighted terms
- for loss in args.loss.split('+'):
- weight, loss_type = loss.split('*')
- if loss_type == 'MSE':
- loss_function = nn.MSELoss()
- elif loss_type == 'L1':
- loss_function = nn.L1Loss()
- elif loss_type.find('VGG') >= 0:
- module = import_module('loss.vgg')
- loss_function = getattr(module, 'VGG')(
- loss_type[3:],
- rgb_range=args.rgb_range
- )
- elif loss_type.find('GAN') >= 0:
- module = import_module('loss.adversarial')
- loss_function = getattr(module, 'Adversarial')(
- args,
- loss_type
- )
-
- self.loss.append({
- 'type': loss_type,
- 'weight': float(weight),
- 'function': loss_function}
- )
- if loss_type.find('GAN') >= 0:
- self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
-
- if len(self.loss) > 1:
- self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
-
- for l in self.loss:
- if l['function'] is not None:
- print('{:.3f} * {}'.format(l['weight'], l['type']))
- self.loss_module.append(l['function'])
-
- self.log = torch.Tensor()
-
- device = torch.device('cpu' if args.cpu else 'cuda')
- self.loss_module.to(device)
- if args.precision == 'half': self.loss_module.half()
- if not args.cpu and args.n_GPUs > 1:
- self.loss_module = nn.DataParallel(
- self.loss_module, range(args.n_GPUs)
- )
-
- if args.load != '.': self.load(ckp.dir, cpu=args.cpu)
-
- def forward(self, sr, hr):
- losses = []
- for i, l in enumerate(self.loss):
- if l['function'] is not None:
- loss = l['function'](sr, hr)
- effective_loss = l['weight'] * loss
- losses.append(effective_loss)
- self.log[-1, i] += effective_loss.item()
- elif l['type'] == 'DIS':
- self.log[-1, i] += self.loss[i - 1]['function'].loss
-
- loss_sum = sum(losses)
- if len(self.loss) > 1:
- self.log[-1, -1] += loss_sum.item()
-
- return loss_sum
-
- def step(self):
- for l in self.get_loss_module():
- if hasattr(l, 'scheduler'):
- l.scheduler.step()
-
- def start_log(self):
- self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
-
- def end_log(self, n_batches):
- self.log[-1].div_(n_batches)
-
- def display_loss(self, batch):
- n_samples = batch + 1
- log = []
- for l, c in zip(self.loss, self.log[-1]):
- log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
-
- return ''.join(log)
-
- def plot_loss(self, apath, epoch):
- axis = np.linspace(1, epoch, epoch)
- for i, l in enumerate(self.loss):
- label = '{} Loss'.format(l['type'])
- fig = plt.figure()
- plt.title(label)
- plt.plot(axis, self.log[:, i].numpy(), label=label)
- plt.legend()
- plt.xlabel('Epochs')
- plt.ylabel('Loss')
- plt.grid(True)
- plt.savefig('{}/loss_{}.pdf'.format(apath, l['type']))
- plt.close(fig)
-
- def get_loss_module(self):
- if self.n_GPUs == 1:
- return self.loss_module
- else:
- return self.loss_module.module
-
- def save(self, apath):
- torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
- torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
-
- def load(self, apath, cpu=False):
- if cpu:
- kwargs = {'map_location': lambda storage, loc: storage}
- else:
- kwargs = {}
-
- self.load_state_dict(torch.load(
- os.path.join(apath, 'loss.pt'),
- **kwargs
- ))
- self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
- for l in self.get_loss_module():
- if hasattr(l, 'scheduler'):
- for _ in range(len(self.log)): l.scheduler.step()
-
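The loss configuration string is parsed by the two split calls above: '+'-separated terms of the form weight*type. With the default '1*L1' only an L1 term is built. As a hypothetical example, --loss '1*L1+0.05*VGG54+0.1*GAN' would add an L1 loss with weight 1, a VGG perceptual loss with weight 0.05 (the '54' after 'VGG' is passed as conv_index to select the feature layer), and an adversarial loss with weight 0.1; because a GAN term is present, an extra 'DIS' entry is appended for logging the discriminator loss, and since there is more than one term a 'Total' entry is appended as well.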
adversarial.py
- import utility
- from model import common
- from loss import discriminator
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.optim as optim
- from torch.autograd import Variable
-
- class Adversarial(nn.Module):
- def __init__(self, args, gan_type):
- super(Adversarial, self).__init__()
- self.gan_type = gan_type
- self.gan_k = args.gan_k
- self.discriminator = discriminator.Discriminator(args, gan_type)
- if gan_type != 'WGAN_GP':
- self.optimizer = utility.make_optimizer(args, self.discriminator)
- else:
- self.optimizer = optim.Adam(
- self.discriminator.parameters(),
- betas=(0, 0.9), eps=1e-8, lr=1e-5
- )
- self.scheduler = utility.make_scheduler(args, self.optimizer)
-
- def forward(self, fake, real):
- fake_detach = fake.detach()
-
- self.loss = 0
- for _ in range(self.gan_k):
- self.optimizer.zero_grad()
- d_fake = self.discriminator(fake_detach)
- d_real = self.discriminator(real)
- if self.gan_type == 'GAN':
- label_fake = torch.zeros_like(d_fake)
- label_real = torch.ones_like(d_real)
- loss_d \
- = F.binary_cross_entropy_with_logits(d_fake, label_fake) \
- + F.binary_cross_entropy_with_logits(d_real, label_real)
- elif self.gan_type.find('WGAN') >= 0:
- loss_d = (d_fake - d_real).mean()
- if self.gan_type.find('GP') >= 0:
- epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)
- hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
- hat.requires_grad = True
- d_hat = self.discriminator(hat)
- gradients = torch.autograd.grad(
- outputs=d_hat.sum(), inputs=hat,
- retain_graph=True, create_graph=True, only_inputs=True
- )[0]
- gradients = gradients.view(gradients.size(0), -1)
- gradient_norm = gradients.norm(2, dim=1)
- gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
- loss_d += gradient_penalty
-
- # Discriminator update
- self.loss += loss_d.item()
- loss_d.backward()
- self.optimizer.step()
-
- if self.gan_type == 'WGAN':
- for p in self.discriminator.parameters():
- p.data.clamp_(-1, 1)
-
- self.loss /= self.gan_k
-
- d_fake_for_g = self.discriminator(fake)
- if self.gan_type == 'GAN':
- loss_g = F.binary_cross_entropy_with_logits(
- d_fake_for_g, label_real
- )
- elif self.gan_type.find('WGAN') >= 0:
- loss_g = -d_fake_for_g.mean()
-
- # Generator loss
- return loss_g
-
- def state_dict(self, *args, **kwargs):
- state_discriminator = self.discriminator.state_dict(*args, **kwargs)
- state_optimizer = self.optimizer.state_dict()
-
- return dict(**state_discriminator, **state_optimizer)
-
- # Some references
- # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
- # OR
- # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
discriminator.py
- from model import common
-
- import torch.nn as nn
-
- class Discriminator(nn.Module):
- def __init__(self, args, gan_type='GAN'):
- super(Discriminator, self).__init__()
-
- in_channels = 3
- out_channels = 64
- depth = 7
- #bn = not gan_type == 'WGAN_GP'
- bn = True
- act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- m_features = [
- common.BasicBlock(args.n_colors, out_channels, 3, bn=bn, act=act)
- ]
- for i in range(depth):
- in_channels = out_channels
- if i % 2 == 1:
- stride = 1
- out_channels *= 2
- else:
- stride = 2
- m_features.append(common.BasicBlock(
- in_channels, out_channels, 3, stride=stride, bn=bn, act=act
- ))
-
- self.features = nn.Sequential(*m_features)
-
- patch_size = args.patch_size // (2**((depth + 1) // 2))
- m_classifier = [
- nn.Linear(out_channels * patch_size**2, 1024),
- act,
- nn.Linear(1024, 1)
- ]
- self.classifier = nn.Sequential(*m_classifier)
-
- def forward(self, x):
- features = self.features(x)
- output = self.classifier(features.view(features.size(0), -1))
-
- return output
-
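A quick size check of the discriminator, assuming the initial BasicBlock uses stride 1 (as in the EDSR-style common module) and the default --patch_size 192: the loop applies stride 2 at i = 0, 2, 4, 6, so the spatial size is halved four times and patch_size // 2**((depth + 1) // 2) = 192 // 16 = 12, while the channel count doubles at i = 1, 3, 5 to reach 512. The first linear layer therefore receives 512 * 12 * 12 features.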
vgg.py
- from model import common
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torchvision.models as models
- from torch.autograd import Variable
-
- class VGG(nn.Module):
- def __init__(self, conv_index, rgb_range=1):
- super(VGG, self).__init__()
- # pretrained=True loads the ImageNet-pretrained VGG19 weights
- vgg_features = models.vgg19(pretrained=True).features
- modules = [m for m in vgg_features]
- if conv_index == '22':
- self.vgg = nn.Sequential(*modules[:8])
- elif conv_index == '54':
- self.vgg = nn.Sequential(*modules[:35])
-
- vgg_mean = (0.485, 0.456, 0.406)
- vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
- self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
- self.vgg.requires_grad = False
-
- def forward(self, sr, hr):
- def _forward(x):
- x = self.sub_mean(x)
- x = self.vgg(x)
- return x
-
- vgg_sr = _forward(sr)
- with torch.no_grad():
- vgg_hr = _forward(hr.detach())
-
- loss = F.mse_loss(vgg_sr, vgg_hr)
-
- return loss
dataloader.py
- import sys
- import threading
- import queue
- import random
- import collections
-
- import torch
- import torch.multiprocessing as multiprocessing
-
- from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
- _remove_worker_pids, _error_if_any_worker_fails
- from torch.utils.data.dataloader import DataLoader
- from torch.utils.data.dataloader import _DataLoaderIter
-
- from torch.utils.data.dataloader import ExceptionWrapper
- from torch.utils.data.dataloader import _use_shared_memory
- from torch.utils.data.dataloader import _worker_manager_loop
- from torch.utils.data.dataloader import numpy_type_map
- from torch.utils.data.dataloader import default_collate
- from torch.utils.data.dataloader import pin_memory_batch
- from torch.utils.data.dataloader import _SIGCHLD_handler_set
- from torch.utils.data.dataloader import _set_SIGCHLD_handler
-
- if sys.version_info[0] == 2:
- import Queue as queue
- else:
- import queue
-
- def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
- global _use_shared_memory
- _use_shared_memory = True
- _set_worker_signal_handlers()
-
- torch.set_num_threads(1)
- torch.manual_seed(seed)
- while True:
- r = index_queue.get()
- if r is None:
- break
- idx, batch_indices = r
- try:
- idx_scale = 0
- if len(scale) > 1 and dataset.train:
- idx_scale = random.randrange(0, len(scale))
- dataset.set_scale(idx_scale)
-
- samples = collate_fn([dataset[i] for i in batch_indices])
- samples.append(idx_scale)
-
- except Exception:
- data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
- else:
- data_queue.put((idx, samples))
-
- class _MSDataLoaderIter(_DataLoaderIter):
- def __init__(self, loader):
- self.dataset = loader.dataset
- self.scale = loader.scale
- self.collate_fn = loader.collate_fn
- self.batch_sampler = loader.batch_sampler
- self.num_workers = loader.num_workers
- self.pin_memory = loader.pin_memory and torch.cuda.is_available()
- self.timeout = loader.timeout
- self.done_event = threading.Event()
-
- self.sample_iter = iter(self.batch_sampler)
-
- if self.num_workers > 0:
- self.worker_init_fn = loader.worker_init_fn
- self.index_queues = [
- multiprocessing.Queue() for _ in range(self.num_workers)
- ]
- self.worker_queue_idx = 0
- self.worker_result_queue = multiprocessing.SimpleQueue()
- self.batches_outstanding = 0
- self.worker_pids_set = False
- self.shutdown = False
- self.send_idx = 0
- self.rcvd_idx = 0
- self.reorder_dict = {}
-
- base_seed = torch.LongTensor(1).random_()[0]
- self.workers = [
- multiprocessing.Process(
- target=_ms_loop,
- args=(
- self.dataset,
- self.index_queues[i],
- self.worker_result_queue,
- self.collate_fn,
- self.scale,
- base_seed + i,
- self.worker_init_fn,
- i
- )
- )
- for i in range(self.num_workers)]
-
- if self.pin_memory or self.timeout > 0:
- self.data_queue = queue.Queue()
- if self.pin_memory:
- maybe_device_id = torch.cuda.current_device()
- else:
- # do not initialize cuda context if not necessary
- maybe_device_id = None
- self.worker_manager_thread = threading.Thread(
- target=_worker_manager_loop,
- args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
- maybe_device_id))
- self.worker_manager_thread.daemon = True
- self.worker_manager_thread.start()
- else:
- self.data_queue = self.worker_result_queue
-
- for w in self.workers:
- w.daemon = True # ensure that the worker exits on process exit
- w.start()
-
- _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
- _set_SIGCHLD_handler()
- self.worker_pids_set = True
-
- # prime the prefetch loop
- for _ in range(2 * self.num_workers):
- self._put_indices()
-
- class MSDataLoader(DataLoader):
- def __init__(
- self, args, dataset, batch_size=1, shuffle=False,
- sampler=None, batch_sampler=None,
- collate_fn=default_collate, pin_memory=False, drop_last=False,
- timeout=0, worker_init_fn=None):
-
- super(MSDataLoader, self).__init__(
- dataset, batch_size=batch_size, shuffle=shuffle,
- sampler=sampler, batch_sampler=batch_sampler,
- num_workers=args.n_threads, collate_fn=collate_fn,
- pin_memory=pin_memory, drop_last=drop_last,
- timeout=timeout, worker_init_fn=worker_init_fn)
-
- self.scale = args.scale
-
- def __iter__(self):
- return _MSDataLoaderIter(self)
main.py
- import torch
-
- import utility
- import data
- import model
- import loss
- from option import args
- from trainer import Trainer
-
- torch.manual_seed(args.seed)
- checkpoint = utility.checkpoint(args)
-
- if checkpoint.ok:
- loader = data.Data(args)
- model = model.Model(args, checkpoint)
- loss = loss.Loss(args, checkpoint) if not args.test_only else None
- t = Trainer(args, loader, model, loss, checkpoint)
- while not t.terminate():
- t.train()
- t.test()
-
- checkpoint.done()
-
option.py
- import argparse
- import template
-
- parser = argparse.ArgumentParser(description='EDSR and MDSR')
-
- parser.add_argument('--debug', action='store_true',
- help='Enables debug mode')
- parser.add_argument('--template', default='.',
- help='You can set various templates in option.py')
-
- # Hardware specifications
- parser.add_argument('--n_threads', type=int, default=3,
- help='number of threads for data loading')
- parser.add_argument('--cpu', action='store_true',
- help='use cpu only')
- parser.add_argument('--n_GPUs', type=int, default=1,
- help='number of GPUs')
- parser.add_argument('--seed', type=int, default=1,
- help='random seed')
-
- # Data specifications
- parser.add_argument('--dir_data', type=str, default='/home/yulun/data/SR/traindata/DIV2K/bicubic',
- help='dataset directory')
- parser.add_argument('--dir_demo', type=str, default='../test',
- help='demo image directory')
- parser.add_argument('--data_train', type=str, default='DIV2K',
- help='train dataset name')
- parser.add_argument('--data_test', type=str, default='DIV2K',
- help='test dataset name')
- parser.add_argument('--benchmark_noise', action='store_true',
- help='use noisy benchmark sets')
- parser.add_argument('--n_train', type=int, default=800,
- help='number of training set')
- parser.add_argument('--n_val', type=int, default=5,
- help='number of validation set')
- parser.add_argument('--offset_val', type=int, default=800,
- help='validation index offset')
- parser.add_argument('--ext', type=str, default='sep_reset',
- help='dataset file extension')
- parser.add_argument('--scale', default='4',
- help='super resolution scale')
- parser.add_argument('--patch_size', type=int, default=192,
- help='output patch size')
- parser.add_argument('--rgb_range', type=int, default=255,
- help='maximum value of RGB')
- parser.add_argument('--n_colors', type=int, default=3,
- help='number of color channels to use')
- parser.add_argument('--noise', type=str, default='.',
- help='Gaussian noise std.')
- parser.add_argument('--chop', action='store_true',
- help='enable memory-efficient forward')
-
- # Model specifications
- parser.add_argument('--model', default='RCAN',
- help='model name')
-
- parser.add_argument('--act', type=str, default='relu',
- help='activation function')
- parser.add_argument('--pre_train', type=str, default='.',
- help='pre-trained model directory')
- parser.add_argument('--extend', type=str, default='.',
- help='pre-trained model directory')
- parser.add_argument('--n_resblocks', type=int, default=20,
- help='number of residual blocks')
- parser.add_argument('--n_feats', type=int, default=64,
- help='number of feature maps')
- parser.add_argument('--res_scale', type=float, default=1,
- help='residual scaling')
- parser.add_argument('--shift_mean', default=True,
- help='subtract pixel mean from the input')
- parser.add_argument('--precision', type=str, default='single',
- choices=('single', 'half'),
- help='FP precision for test (single | half)')
-
- # Training specifications
- parser.add_argument('--reset', action='store_true',
- help='reset the training')
- parser.add_argument('--test_every', type=int, default=1000,
- help='do test per every N batches')
- parser.add_argument('--epochs', type=int, default=1000,
- help='number of epochs to train')
- parser.add_argument('--batch_size', type=int, default=16,
- help='input batch size for training')
- parser.add_argument('--split_batch', type=int, default=1,
- help='split the batch into smaller chunks')
- parser.add_argument('--self_ensemble', action='store_true',
- help='use self-ensemble method for test')
- parser.add_argument('--test_only', action='store_true',
- help='set this option to test the model')
- parser.add_argument('--gan_k', type=int, default=1,
- help='k value for adversarial loss')
-
- # Optimization specifications
- parser.add_argument('--lr', type=float, default=1e-4,
- help='learning rate')
- parser.add_argument('--lr_decay', type=int, default=200,
- help='learning rate decay per N epochs')
- parser.add_argument('--decay_type', type=str, default='step',
- help='learning rate decay type')
- parser.add_argument('--gamma', type=float, default=0.5,
- help='learning rate decay factor for step decay')
- parser.add_argument('--optimizer', default='ADAM',
- choices=('SGD', 'ADAM', 'RMSprop'),
- help='optimizer to use (SGD | ADAM | RMSprop)')
- parser.add_argument('--momentum', type=float, default=0.9,
- help='SGD momentum')
- parser.add_argument('--beta1', type=float, default=0.9,
- help='ADAM beta1')
- parser.add_argument('--beta2', type=float, default=0.999,
- help='ADAM beta2')
- parser.add_argument('--epsilon', type=float, default=1e-8,
- help='ADAM epsilon for numerical stability')
- parser.add_argument('--weight_decay', type=float, default=0,
- help='weight decay')
-
- # Loss specifications
- parser.add_argument('--loss', type=str, default='1*L1',
- help='loss function configuration')
- parser.add_argument('--skip_threshold', type=float, default='1e6',
- help='skipping batch that has large error')
-
- # Log specifications
- parser.add_argument('--save', type=str, default='test',
- help='file name to save')
- parser.add_argument('--load', type=str, default='.',
- help='file name to load')
- parser.add_argument('--resume', type=int, default=0,
- help='resume from specific checkpoint')
- parser.add_argument('--print_model', action='store_true',
- help='print model')
- parser.add_argument('--save_models', action='store_true',
- help='save all intermediate models')
- parser.add_argument('--print_every', type=int, default=100,
- help='how many batches to wait before logging training status')
- parser.add_argument('--save_results', action='store_true',
- help='save output results')
-
- # options for residual group and feature channel reduction
- parser.add_argument('--n_resgroups', type=int, default=10,
- help='number of residual groups')
- parser.add_argument('--reduction', type=int, default=16,
- help='number of feature maps reduction')
- # options for test
- parser.add_argument('--testpath', type=str, default='../test/DIV2K_val_LR_our',
- help='dataset directory for testing')
- parser.add_argument('--testset', type=str, default='Set5',
- help='dataset name for testing')
-
- args = parser.parse_args()
- template.set_template(args)
-
- args.scale = list(map(lambda x: int(x), args.scale.split('+')))
-
- if args.epochs == 0:
- args.epochs = 1e8
-
- for arg in vars(args):
- if vars(args)[arg] == 'True':
- vars(args)[arg] = True
- elif vars(args)[arg] == 'False':
- vars(args)[arg] = False
-
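Putting these options together, typical invocations might look like the following. Only flags defined above are used; the data directory, experiment name and pre-trained checkpoint path are placeholders to adapt to your own setup:

# train RCAN x4 on DIV2K
python main.py --model RCAN --scale 4 --patch_size 192 --dir_data <path-to-data> --save RCAN_x4 --reset

# test a trained model on Set5 and save the SR outputs
python main.py --model RCAN --scale 4 --data_test Set5 --pre_train <path-to-checkpoint> --test_only --save_results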
template.py
- def set_template(args):
- # Set the templates here
- if args.template.find('jpeg') >= 0:
- args.data_train = 'DIV2K_jpeg'
- args.data_test = 'DIV2K_jpeg'
- args.epochs = 200
- args.lr_decay = 100
-
- if args.template.find('EDSR_paper') >= 0:
- args.model = 'EDSR'
- args.n_resblocks = 32
- args.n_feats = 256
- args.res_scale = 0.1
-
- if args.template.find('MDSR') >= 0:
- args.model = 'MDSR'
- args.patch_size = 48
- args.epochs = 650
-
- if args.template.find('DDBPN') >= 0:
- args.model = 'DDBPN'
- args.patch_size = 128
- args.scale = '4'
-
- args.data_test = 'Set5'
-
- args.batch_size = 20
- args.epochs = 1000
- args.lr_decay = 500
- args.gamma = 0.1
- args.weight_decay = 1e-4
-
- args.loss = '1*MSE'
-
- if args.template.find('GAN') >= 0:
- args.epochs = 200
- args.lr = 5e-5
- args.lr_decay = 150
-
utility.py
- import os
- import math
- import time
- import datetime
- from functools import reduce
-
- import matplotlib
- matplotlib.use('Agg')
- import matplotlib.pyplot as plt
-
- import numpy as np
- import scipy.misc as misc
-
- import torch
- import torch.optim as optim
- import torch.optim.lr_scheduler as lrs
-
- class timer():
- def __init__(self):
- self.acc = 0
- self.tic()
-
- def tic(self):
- self.t0 = time.time()
-
- def toc(self):
- return time.time() - self.t0
-
- def hold(self):
- self.acc += self.toc()
-
- def release(self):
- ret = self.acc
- self.acc = 0
-
- return ret
-
- def reset(self):
- self.acc = 0
-
- class checkpoint():
- def __init__(self, args):
- self.args = args
- self.ok = True
- self.log = torch.Tensor()
- now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
-
- if args.load == '.':
- if args.save == '.': args.save = now
- self.dir = '../experiment/' + args.save
- else:
- self.dir = '../experiment/' + args.load
- if not os.path.exists(self.dir):
- args.load = '.'
- else:
- self.log = torch.load(self.dir + '/psnr_log.pt')
- print('Continue from epoch {}...'.format(len(self.log)))
-
- if args.reset:
- os.system('rm -rf ' + self.dir)
- args.load = '.'
-
- def _make_dir(path):
- if not os.path.exists(path): os.makedirs(path)
-
- _make_dir(self.dir)
- _make_dir(self.dir + '/model')
- _make_dir(self.dir + '/results')
-
- open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
- self.log_file = open(self.dir + '/log.txt', open_type)
- with open(self.dir + '/config.txt', open_type) as f:
- f.write(now + '\n\n')
- for arg in vars(args):
- f.write('{}: {}\n'.format(arg, getattr(args, arg)))
- f.write('\n')
-
- def save(self, trainer, epoch, is_best=False):
- trainer.model.save(self.dir, epoch, is_best=is_best)
- trainer.loss.save(self.dir)
- trainer.loss.plot_loss(self.dir, epoch)
-
- self.plot_psnr(epoch)
- torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
- torch.save(
- trainer.optimizer.state_dict(),
- os.path.join(self.dir, 'optimizer.pt')
- )
-
- def add_log(self, log):
- self.log = torch.cat([self.log, log])
-
- def write_log(self, log, refresh=False):
- print(log)
- self.log_file.write(log + '\n')
- if refresh:
- self.log_file.close()
- self.log_file = open(self.dir + '/log.txt', 'a')
-
- def done(self):
- self.log_file.close()
-
- def plot_psnr(self, epoch):
- axis = np.linspace(1, epoch, epoch)
- label = 'SR on {}'.format(self.args.data_test)
- fig = plt.figure()
- plt.title(label)
- for idx_scale, scale in enumerate(self.args.scale):
- plt.plot(
- axis,
- self.log[:, idx_scale].numpy(),
- label='Scale {}'.format(scale)
- )
- plt.legend()
- plt.xlabel('Epochs')
- plt.ylabel('PSNR')
- plt.grid(True)
- plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
- plt.close(fig)
-
- def save_results(self, filename, save_list, scale):
- filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)
- postfix = ('SR', 'LR', 'HR')
- for v, p in zip(save_list, postfix):
- normalized = v[0].data.mul(255 / self.args.rgb_range)
- ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
- misc.imsave('{}{}.png'.format(filename, p), ndarr)
-
- def quantize(img, rgb_range):
- pixel_range = 255 / rgb_range
- return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
-
- def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
- diff = (sr - hr).data.div(rgb_range)
- shave = scale
- if diff.size(1) > 1:
- convert = diff.new(1, 3, 1, 1)
- convert[0, 0, 0, 0] = 65.738
- convert[0, 1, 0, 0] = 129.057
- convert[0, 2, 0, 0] = 25.064
- diff.mul_(convert).div_(256)
- diff = diff.sum(dim=1, keepdim=True)
- '''
- if benchmark:
- shave = scale
- if diff.size(1) > 1:
- convert = diff.new(1, 3, 1, 1)
- convert[0, 0, 0, 0] = 65.738
- convert[0, 1, 0, 0] = 129.057
- convert[0, 2, 0, 0] = 25.064
- diff.mul_(convert).div_(256)
- diff = diff.sum(dim=1, keepdim=True)
- else:
- shave = scale + 6
- '''
- valid = diff[:, :, shave:-shave, shave:-shave]
- mse = valid.pow(2).mean()
-
- return -10 * math.log10(mse)
-
- def make_optimizer(args, my_model):
- trainable = filter(lambda x: x.requires_grad, my_model.parameters())
-
- if args.optimizer == 'SGD':
- optimizer_function = optim.SGD
- kwargs = {'momentum': args.momentum}
- elif args.optimizer == 'ADAM':
- optimizer_function = optim.Adam
- kwargs = {
- 'betas': (args.beta1, args.beta2),
- 'eps': args.epsilon
- }
- elif args.optimizer == 'RMSprop':
- optimizer_function = optim.RMSprop
- kwargs = {'eps': args.epsilon}
-
- kwargs['lr'] = args.lr
- kwargs['weight_decay'] = args.weight_decay
-
- return optimizer_function(trainable, **kwargs)
-
- def make_scheduler(args, my_optimizer):
- if args.decay_type == 'step':
- scheduler = lrs.StepLR(
- my_optimizer,
- step_size=args.lr_decay,
- gamma=args.gamma
- )
- elif args.decay_type.find('step') >= 0:
- milestones = args.decay_type.split('_')
- milestones.pop(0)
- milestones = list(map(lambda x: int(x), milestones))
- scheduler = lrs.MultiStepLR(
- my_optimizer,
- milestones=milestones,
- gamma=args.gamma
- )
-
- return scheduler
-
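A note on calc_psnr above: for RGB inputs the difference is first projected onto the luma (Y) channel with the ITU-R BT.601 weights, Y = (65.738*R + 129.057*G + 25.064*B) / 256, computed on values already divided by rgb_range; a border of scale pixels is then shaved from every side and the result is PSNR = -10 * log10(MSE). Because the signal peak is 1 after the normalization, the usual 20*log10(MAX) term disappears. The commented-out block shows the earlier variant that applies the Y conversion only for benchmark data and otherwise shaves scale + 6 pixels.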
trainer.py
- import os
- import math
- from decimal import Decimal
-
- import utility
-
- import torch
- from torch.autograd import Variable
- from tqdm import tqdm
-
- class Trainer():
- def __init__(self, args, loader, my_model, my_loss, ckp):
- self.args = args
- self.scale = args.scale
-
- self.ckp = ckp
- self.loader_train = loader.loader_train
- self.loader_test = loader.loader_test
- self.model = my_model
- self.loss = my_loss
- self.optimizer = utility.make_optimizer(args, self.model)
- self.scheduler = utility.make_scheduler(args, self.optimizer)
-
- if self.args.load != '.':
- self.optimizer.load_state_dict(
- torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
- )
- for _ in range(len(ckp.log)): self.scheduler.step()
-
- self.error_last = 1e8
-
- def train(self):
- self.scheduler.step()
- self.loss.step()
- epoch = self.scheduler.last_epoch + 1
- lr = self.scheduler.get_lr()[0]
-
- self.ckp.write_log(
- '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
- )
- self.loss.start_log()
- self.model.train()
-
- timer_data, timer_model = utility.timer(), utility.timer()
- for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):
- lr, hr = self.prepare([lr, hr])
- timer_data.hold()
- timer_model.tic()
-
- self.optimizer.zero_grad()
- sr = self.model(lr, idx_scale)
- loss = self.loss(sr, hr)
- if loss.item() < self.args.skip_threshold * self.error_last:
- loss.backward()
- self.optimizer.step()
- else:
- print('Skip this batch {}! (Loss: {})'.format(
- batch + 1, loss.item()
- ))
-
- timer_model.hold()
-
- if (batch + 1) % self.args.print_every == 0:
- self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
- (batch + 1) * self.args.batch_size,
- len(self.loader_train.dataset),
- self.loss.display_loss(batch),
- timer_model.release(),
- timer_data.release()))
-
- timer_data.tic()
-
- self.loss.end_log(len(self.loader_train))
- self.error_last = self.loss.log[-1, -1]
-
- def test(self):
- epoch = self.scheduler.last_epoch + 1
- self.ckp.write_log('\nEvaluation:')
- self.ckp.add_log(torch.zeros(1, len(self.scale)))
- self.model.eval()
-
- timer_test = utility.timer()
- with torch.no_grad():
- for idx_scale, scale in enumerate(self.scale):
- eval_acc = 0
- self.loader_test.dataset.set_scale(idx_scale)
- tqdm_test = tqdm(self.loader_test, ncols=80)
- for idx_img, (lr, hr, filename, _) in enumerate(tqdm_test):
- filename = filename[0]
- no_eval = (hr.nelement() == 1)
- if not no_eval:
- lr, hr = self.prepare([lr, hr])
- else:
- lr = self.prepare([lr])[0]
-
- sr = self.model(lr, idx_scale)
- sr = utility.quantize(sr, self.args.rgb_range)
-
- save_list = [sr]
- if not no_eval:
- eval_acc += utility.calc_psnr(
- sr, hr, scale, self.args.rgb_range,
- benchmark=self.loader_test.dataset.benchmark
- )
- save_list.extend([lr, hr])
-
- if self.args.save_results:
- self.ckp.save_results(filename, save_list, scale)
-
- self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test)
- best = self.ckp.log.max(0)
- self.ckp.write_log(
- '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
- self.args.data_test,
- scale,
- self.ckp.log[-1, idx_scale],
- best[0][idx_scale],
- best[1][idx_scale] + 1
- )
- )
-
- self.ckp.write_log(
- 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
- )
- if not self.args.test_only:
- self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))
-
- def prepare(self, l, volatile=False):
- device = torch.device('cpu' if self.args.cpu else 'cuda')
- def _prepare(tensor):
- if self.args.precision == 'half': tensor = tensor.half()
- return tensor.to(device)
-
- return [_prepare(_l) for _l in l]
-
- def terminate(self):
- if self.args.test_only:
- self.test()
- return True
- else:
- epoch = self.scheduler.last_epoch + 1
- return epoch >= self.args.epochs
-
The results of using the model to upscale images by 4x are shown below:
Input image:
Output image:
Input image:
Output image: