当前位置:   article > 正文

超分之ESRGAN官方代码解读_residual-in-residual dense block、

residual-in-residual dense block、

改进1:生成器RRDBNet_arch.py.py

  • 引入了没有BN层的Residual-in-Residual Dense Block(RRDB)作为基本网络构建单元
    在这里插入图片描述
    在这里插入图片描述

1.1 RDB(Residual Dense Block)

class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features. (中间特征的通道数)
        num_grow_ch (int): Channels for each growth. (每次增长的通道数)
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        # 输入通道数,输出通道数,卷积核,步长,填充
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # # initialization (初始化每个卷积层的权重参数)
        # default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        # 0.2 是残差缩放的超参数
        return x5 * 0.2 + x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

查看RDB中具体的卷积操作

RDB = ResidualDenseBlock()
RDB
  • 1
  • 2

在这里插入图片描述
测试RDB

X =torch.rand(1, 64, 256, 256)
Y =RDB(X)
Y.shape
  • 1
  • 2
  • 3

在这里插入图片描述

1.2 RRDB(Residual in Residual Dense Block)

class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

查看RRDB中具体的卷积操作

RRDB = RRDB(64)
RRDB
  • 1
  • 2

在这里插入图片描述
测试RDB

X =torch.rand(1, 64, 256, 256)
Y =RRDB(X)
Y.shape
  • 1
  • 2
  • 3

在这里插入图片描述

1.3 RRDBNet(Networks consisting of Residual in Residual Dense Block)

  • make_layer(): 顺序生成指定数量的基本块
  • pixel_unshuffle():像素重组的逆过程 (t通道数增加,长宽缩小)
def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block. 基本块
        num_basic_block (int): number of blocks. 基本块的个数

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.

    [n, c, w, h]  ---> [n, c*scale*scale, w/scale, h/scale]
    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)  #
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)  # [b, c, h/scale, scale, w/scale, scale]
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
x = torch.rand(1, 64, 256, 256)
feat = pixel_unshuffle(x, scale=2)
print(feat.shape)  # [1, 256, 128, 128]
  • 1
  • 2
  • 3

在这里插入图片描述

class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs. (输入通道数)
        num_out_ch (int): Channel number of outputs. (输出通道数)
        num_feat (int): Channel number of intermediate features. (中间特征的通道数)
            Default: 64
        num_block (int): Block number in the trunk network. Defaults: 23 (RDB块的个数)
        num_grow_ch (int): Channels for each growth. Default: 32. (增长的通道数)
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()

        self.scale = scale
        # 默认缩放因子是4,输入通道数为num_in_ch
        if scale == 2:
            # 如果缩放因子为2,输入通道数为2*2倍num_in_ch
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            # 如果缩放因子为1,则输入通道数为4*4倍num_in_ch
            num_in_ch = num_in_ch * 16

        # 浅层特征提取层:1个3×3,步长为1,填充为1的卷积层 [n, c, h, w] -->[n, 64, h, w]
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)

        # 深层特征提取层:23个RRDB + 1个3×3,步长为1,填充为1的卷积层[n, c+64,h, w]-->[n, 64, h, w]
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # 上采样重建层:两个上采样操作(卷积+插值) + 2个3×3,步长为1,填充为1的卷积层
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            # 如果缩放因子为2,长宽变为原来的1/2倍,通道数增加为原来的2倍
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            # 如果缩放因子为1,长宽变为原来的1/4倍,通道数增加为原来的4倍
            feat = pixel_unshuffle(x, scale=4)
        else:
            # 默认情况下,缩放因子为4倍
            feat = x

        # 浅层特征提取层
        feat = self.conv_first(feat)

        # 深层特征提取层
        body_feat = self.conv_body(self.body(feat))

        # 残差连接:浅层输出+深层输出
        feat = feat + body_feat

        # 上采样重建层
        # upsample
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74

查看RRDBNet的网络结构

RRDBNet = RRDBNet(64, 64)
RRDBNet
  • 1
  • 2

在这里插入图片描述
在这里插入图片描述
测试RRDBNet

X =torch.rand(1, 64, 256, 256)
Y =RRDBNet(X)
print(Y.shape)
  • 1
  • 2
  • 3

在这里插入图片描述

改进2:相对判别器

在这里插入图片描述

# gan loss (relativistic gan)  对抗损失
# 原始图像的判别得分
real_d_pred = self.net_d(self.gt).detach()
# 生成图像的判别得分
fake_g_pred = self.net_d(self.output)
# 真实图像的判别分数=原始图像的判别分数-生成图像的判别得分的平均值
l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False)
# 生成图像的判别分数=生成图像的判别分数-真实图像的判别得分的平均值
l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False)
# 生成器网络损失
l_g_gan = (l_g_real + l_g_fake) / 2
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

改进3:使用激活之前的VGG特征计算损失函数

from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm


class VGGStyleDiscriminator(nn.Module):
    """VGG style discriminator with input size 128 x 128 or 256 x 256.

    It is used to train SRGAN, ESRGAN, and VideoGAN.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3. (输入数据的通道数,默认=3)
        num_feat (int): Channel number of base intermediate features.Default: 64. (中间特征的通道数,默认=64)
    """

    def __init__(self, num_in_ch, num_feat, input_size=128):
        super(VGGStyleDiscriminator, self).__init__()
        self.input_size = input_size
        assert self.input_size == 128 or self.input_size == 256, (f'input size must be 128 or 256, but received {input_size}')

        # convx_0:卷积核为3*3,图像尺寸大小不变:(n-k+2p)/s+1 = (n-3+2*1)/1+1 = n
        # convx_1: 卷积核为4*4,步长为2,图像尺寸大小减半:(n-k+2p)/s+1 = (n-4+2*1)/2+1= n/2
        self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
        self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)

        self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
        self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
        self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
        self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)

        self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
        self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
        self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
        self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)

        self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
        self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
        self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
        self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        if self.input_size == 256:
            self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
            self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
            self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
            self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
        self.linear2 = nn.Linear(100, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')

        feat = self.lrelu(self.conv0_0(x))
        feat = self.lrelu(self.bn0_1(self.conv0_1(feat)))  # output spatial size: /2  128/2=64

        feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
        feat = self.lrelu(self.bn1_1(self.conv1_1(feat)))  # output spatial size: /4  64/2=32

        feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
        feat = self.lrelu(self.bn2_1(self.conv2_1(feat)))  # output spatial size: /8  32/2=16

        feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
        feat = self.lrelu(self.bn3_1(self.conv3_1(feat)))  # output spatial size: /16  16/2=8

        feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
        feat = self.lrelu(self.bn4_1(self.conv4_1(feat)))  # output spatial size: /32  8/2=4

        if self.input_size == 256:
            feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
            feat = self.lrelu(self.bn5_1(self.conv5_1(feat)))  # output spatial size: / 64

        # spatial size: (4, 4)
        feat = feat.view(feat.size(0), -1)  # 将张量展成一行,输出为num_feat * 8 * 4 * 4的向量
        feat = self.lrelu(self.linear1(feat))  # 全连接+ReLU激活
        out = self.linear2(feat)  # 只进行全连接,不使用激活函数
        return out
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
  

闽ICP备14008679号