class ResidualDenseBlock(nn.Module):
    """Residual Dense Block, the building unit of the RRDB block in ESRGAN.

    The first four 3x3 convolutions each see the concatenation (along the
    channel axis) of the block input and every earlier growth output, and each
    contributes ``num_grow_ch`` new channels. The fifth convolution fuses all of
    them back to ``num_feat`` channels, and that result is added to the input
    as a residual scaled by 0.2.

    Args:
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_grow_ch (int): Channels added by each growth convolution. Default: 32.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        # Conv2d(in_channels, out_channels, kernel_size, stride, padding).
        # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # NOTE: the custom weight initialisation
        # (default_init_weights([...conv1..conv5], 0.1)) is disabled in this
        # version, so PyTorch's default Conv2d initialisation is used.

    def forward(self, x):
        # Dense connectivity: every growth conv consumes the input plus all
        # previous growth outputs concatenated on the channel dimension.
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        fused = self.conv5(torch.cat(features, 1))
        # Residual scaled by 0.2, an empirical choice for better performance.
        return fused * 0.2 + x
# Demo: one residual dense block with the default sizes, plus a random
# 64-channel input of matching channel count (X is not passed through RDB here).
RDB = ResidualDenseBlock(num_feat=64, num_grow_ch=32)
X = torch.rand(1, 64, 256, 256)
class RRDB(nn.Module):
    """Residual in Residual Dense Block, used in the RRDB-Net of ESRGAN.

    Three ResidualDenseBlocks are applied one after another, and their output
    is added back to the block input as a residual scaled by 0.2.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = x
        for rdb in (self.rdb1, self.rdb2, self.rdb3):
            out = rdb(out)
        # Empirical residual scaling factor (0.2) for better performance.
        return out * 0.2 + x
# Demo input: a random batch of one 64-channel 256 x 256 feature map.
X = torch.rand((1, 64, 256, 256))
def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking blocks of the same type.

    Args:
        basic_block (type): ``nn.Module`` subclass instantiated once per block.
        num_basic_block (int): Number of blocks to stack.
        **kwarg: Keyword arguments forwarded to every ``basic_block(**kwarg)`` call.

    Returns:
        nn.Sequential: The stacked blocks, each a separate instance, in order.
    """
    return nn.Sequential(*(basic_block(**kwarg) for _ in range(num_basic_block)))


def pixel_unshuffle(x, scale):
    """Pixel unshuffle, the inverse of pixel shuffle.

    Moves every ``scale x scale`` spatial patch into the channel dimension:
    ``(b, c, hh, hw) -> (b, c * scale**2, hh // scale, hw // scale)``.
    Output channel ``ci * scale**2 + i * scale + j`` holds the input pixels at
    ``(h * scale + i, w * scale + j)`` of input channel ``ci``. This is the same
    ordering as ``torch.nn.functional.pixel_unshuffle``.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio; must be positive and divide both hh and hw.

    Returns:
        Tensor: The pixel-unshuffled feature, shape (b, c * scale**2, hh // scale, hw // scale).

    Raises:
        ValueError: If ``scale`` is not positive or does not divide the spatial size.
    """
    b, c, hh, hw = x.size()
    # Validate up front: otherwise `view` below fails with an opaque size-mismatch
    # error (or `%` / `//` raise ZeroDivisionError for scale == 0).
    if scale <= 0 or hh % scale or hw % scale:
        raise ValueError(
            f'pixel_unshuffle: spatial size ({hh}, {hw}) must be divisible by a '
            f'positive scale, but got scale={scale}')
    h = hh // scale
    w = hw // scale
    out_channel = c * (scale**2)
    # [b, c, h, scale, w, scale] -> [b, c, scale, scale, h, w] -> [b, c*scale*scale, h, w]
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
# Demo: a 2x pixel-unshuffle folds each 2 x 2 spatial patch into channels,
# turning (1, 64, 256, 256) into (1, 256, 128, 128).
x = torch.rand(1, 64, 256, 256)
feat = pixel_unshuffle(x, scale=2)
print(feat.size())  # torch.Size([1, 256, 128, 128])
class RRDBNet(nn.Module):
    """Network of Residual in Residual Dense Blocks, as used in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    The trunk always upsamples by 4 (two nearest-neighbour x2 steps). ESRGAN is
    extended to scale x2 and scale x1 by first applying pixel-unshuffle (the
    inverse of pixel shuffle). Pixel-unshuffle shrinks the spatial size by 2
    (scale 2) or 4 (scale 1) and multiplies the channel count by 4 or 16,
    before the input enters the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Overall upsampling factor; one of 1, 2 or 4. Default: 4.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_block (int): Number of RRDB blocks in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.

    Raises:
        ValueError: If ``scale`` is not 1, 2 or 4. Any other value would
            silently yield a x4 network.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super().__init__()
        if scale not in (1, 2, 4):
            raise ValueError(f'RRDBNet only supports scale 1, 2 or 4, but got {scale}')
        self.scale = scale
        if scale == 2:
            # forward applies pixel_unshuffle(x, 2): channels x4.
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            # forward applies pixel_unshuffle(x, 4): channels x16.
            num_in_ch = num_in_ch * 16
        # Shallow feature extraction: one 3x3 / stride 1 / padding 1 conv,
        # [n, num_in_ch, h, w] -> [n, num_feat, h, w].
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        # Deep feature extraction: num_block RRDBs followed by one 3x3 conv;
        # channel count and spatial size stay at [n, num_feat, h, w].
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # Upsampling / reconstruction: two (nearest x2 interpolation + conv)
        # steps, then two more 3x3 convs.
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            # Height/width halved, channels x4.
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            # Height/width divided by 4, channels x16.
            feat = pixel_unshuffle(x, scale=4)
        else:
            # scale == 4: the input goes straight into the trunk.
            feat = x
        # Shallow features.
        feat = self.conv_first(feat)
        # Deep features.
        body_feat = self.conv_body(self.body(feat))
        # Long residual connection: shallow output + deep output.
        feat = feat + body_feat
        # Upsample x2 twice (overall x4 relative to the trunk's input size).
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
# Demo: build a default x4 RRDBNet with 64 input and 64 output channels.
# The instance gets its own name; the original `RRDBNet = RRDBNet(64, 64)`
# rebound the class name to an instance, making the class unusable afterwards.
rrdb_net = RRDBNet(64, 64)
X = torch.rand(1, 64, 256, 256)
# Generator adversarial loss (relativistic average GAN).
# NOTE(review): this is an excerpt from a model method. `self` is not defined
# at module level, so these lines only run inside that method.
# Discriminator scores for the ground-truth images, detached so no gradient
# flows back through the real branch.
real_d_pred = self.net_d(self.gt).detach()
# Discriminator scores for the generator output (gradients reach the generator).
fake_g_pred = self.net_d(self.output)
# Real scores relative to the mean fake score, paired with label False.
# Assumes cri_gan's 2nd positional arg is the target-is-real flag; TODO confirm.
l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False)
# Fake scores relative to the mean real score, paired with label True.
l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False)
# Generator adversarial loss: average of the two relativistic terms.
l_g_gan = (l_g_real + l_g_fake) / 2
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm


class VGGStyleDiscriminator(nn.Module):
    """VGG style discriminator with input size 128 x 128 or 256 x 256.

    It is used to train SRGAN, ESRGAN, and VideoGAN. Each stage is a 3x3
    stride-1 conv, which keeps the spatial size ((n - 3 + 2*1)/1 + 1 = n).
    It is followed by a 4x4 stride-2 conv, which halves it
    ((n - 4 + 2*1)/2 + 1 = n/2). Stacking these brings the feature map down
    to 4 x 4 before the two linear layers.

    Args:
        num_in_ch (int): Channel number of inputs (3 for RGB images).
        num_feat (int): Channel number of base intermediate features (64 in ESRGAN).
        input_size (int): Height and width of the square input; 128 or 256. Default: 128.

    Raises:
        ValueError: If ``input_size`` is not 128 or 256.
    """

    def __init__(self, num_in_ch, num_feat, input_size=128):
        super().__init__()
        self.input_size = input_size
        if self.input_size not in (128, 256):
            raise ValueError(f'input size must be 128 or 256, but received {input_size}')

        self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
        self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)

        self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
        self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
        self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
        self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)

        self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
        self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
        self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
        self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)

        self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
        self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
        self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
        self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        if self.input_size == 256:
            # One extra halving stage so a 256 input also ends at 4 x 4.
            self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
            self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
            self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
            self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
        self.linear2 = nn.Linear(100, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Check both height and width: linear1 expects exactly a 4 x 4 final map,
        # so a wrong width would otherwise fail later with an unclear shape error.
        if x.size(2) != self.input_size or x.size(3) != self.input_size:
            raise ValueError(f'Input size must be identical to input_size, but received {x.size()}.')

        feat = self.lrelu(self.conv0_0(x))
        feat = self.lrelu(self.bn0_1(self.conv0_1(feat)))  # spatial /2  (128 -> 64)

        feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
        feat = self.lrelu(self.bn1_1(self.conv1_1(feat)))  # spatial /4  (64 -> 32)

        feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
        feat = self.lrelu(self.bn2_1(self.conv2_1(feat)))  # spatial /8  (32 -> 16)

        feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
        feat = self.lrelu(self.bn3_1(self.conv3_1(feat)))  # spatial /16 (16 -> 8)

        feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
        feat = self.lrelu(self.bn4_1(self.conv4_1(feat)))  # spatial /32 (8 -> 4)

        if self.input_size == 256:
            feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
            feat = self.lrelu(self.bn5_1(self.conv5_1(feat)))  # spatial /64

        # spatial size: (4, 4); flatten to a num_feat * 8 * 4 * 4 vector per sample.
        feat = feat.view(feat.size(0), -1)
        feat = self.lrelu(self.linear1(feat))  # fully connected + LeakyReLU
        out = self.linear2(feat)  # fully connected, no activation (raw score)
        return out
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。