当前位置:   article > 正文

《动手学深度学习Pytorch版》之DenseNet代码理解_conv_blok

conv_blok

稠密块由多个 conv_block 组成,每块使⽤相同的输出通道数。但在前向计算时,我们将每块的输⼊和输出在通道维上连结

一、模块介绍

1、卷积块conv_block

conv_block的作用:批量归⼀化、激活和卷积。由卷积参数可知,卷积后数据尺寸(高和宽)不变,仅改变输出通道数。

def conv_block(in_channels, out_channels):
    blk = nn.Sequential(nn.BatchNorm2d(in_channels), 
                        nn.ReLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
    return blk
  • 1
  • 2
  • 3
  • 4
  • 5

2、稠密块DenseBlock

DenseBlock的作用:在前向计算时,我们将每块的输⼊和输出在通道维上连结。

class DenseBlock(nn.Module):
    def __init__(self, num_convs, in_channels, out_channels):
        super(DenseBlock, self).__init__()
        net = []
        for i in range(num_convs):
            in_c = in_channels + i * out_channels  #为了将输入和输出在通道维上连结,第一个卷极块输出通道数为13,因此确保第二次卷积通道数为13(正好是上一个卷极块的输入通道数和输出通道数之和)
            net.append(conv_block(in_c, out_channels))
        self.net = nn.ModuleList(net)
        self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结
        return X
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

举例1

net = DenseBlock(2,3,10)
net
  • 1
  • 2

输出1

DenseBlock(
  (net): ModuleList(
    (0): Sequential(
      (0): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): ReLU()
      (2): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): Sequential(
      (0): BatchNorm2d(13, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): ReLU()
      (2): Conv2d(13, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

可见,经过一次卷积后,输出通道数变为输入数据的输入通道数和输出通道数之和。下一次卷积的输入通道数,变为上一次卷积的输出通道数。
举例2
我们定义⼀个有2个输出通道数为10的卷积块。使⽤通道数为3的输⼊时,我们会得到通道数为 3 + 2 × 10 = 23 3+2\times10 = 23 3+2×10=23的输出。卷积块的通道数控制了输出通道数相对于输⼊通道数的增⻓,因此也被称为增⻓率(growth rate)。

blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
Y.shape
  • 1
  • 2
  • 3
  • 4

输出2

torch.Size([4, 23, 8, 8])
  • 1

3、过渡块transition_block

transition_block的作用:降低模型复杂度。由于每个稠密块都会带来通道数的增加,使⽤过多则会带来过于复杂的模型。过渡层⽤来控制模型复杂度。它通过 1 × 1 1\times1 1×1卷积层来减⼩通道数,并使⽤步幅为2的平均池化层减半⾼和宽。

def transition_block(in_channels, out_channels):
    blk = nn.Sequential(
            nn.BatchNorm2d(in_channels), 
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.AvgPool2d(kernel_size=2, stride=2))
    return blk
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

举例

blk = transition_block(23, 10)
blk(Y).shape
  • 1
  • 2

输出

torch.Size([4, 10, 4, 4])
  • 1

二、DENSNET模型

1、DenseNet首先使用同ResNet⼀样的单卷积层和最大池化层

注意到这里只能接收输入通道数为1的数据(卷积层的输入通道数为1)。

net = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64), 
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
  • 1
  • 2
  • 3
  • 4
  • 5

2、 DenseNet使用4个稠密块

类似于ResNet接下来使⽤的4个残差块。同ResNet⼀样,我们可以设置每个稠密块使⽤多少个卷积层。这⾥我们设成4,从⽽与上⼀节的ResNet-18保持⼀致。

num_channels, growth_rate = 64, 32  # num_channels为当前的通道数
num_convs_in_dense_blocks = [4, 4, 4, 4]

for i, num_convs in enumerate(num_convs_in_dense_blocks):
    DB = DenseBlock(num_convs, num_channels, growth_rate)
    net.add_module("DenseBlosk_%d" % i, DB)
    # 上一个稠密块的输出通道数
    num_channels = DB.out_channels
    # 在稠密块之间加入通道数减半的过渡层
    if i != len(num_convs_in_dense_blocks) - 1:
        net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

3、同ResNet⼀样,最后接上全局池化层和全连接层来输出

net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10))) 
  • 1
  • 2
  • 3
  • 4

举例
可以看到X的通道数为1

X = torch.rand((10, 1, 96, 96))
for name, layer in net.named_children():
    X = layer(X)
    print(name, ' output shape:\t', X.shape)
  • 1
  • 2
  • 3
  • 4

输出

0  output shape:	 torch.Size([10, 64, 48, 48])
1  output shape:	 torch.Size([10, 64, 48, 48])
2  output shape:	 torch.Size([10, 64, 48, 48])
3  output shape:	 torch.Size([10, 64, 24, 24])
DenseBlosk_0  output shape:	 torch.Size([10, 192, 24, 24])
transition_block_0  output shape:	 torch.Size([10, 96, 12, 12])
DenseBlosk_1  output shape:	 torch.Size([10, 224, 12, 12])
transition_block_1  output shape:	 torch.Size([10, 112, 6, 6])
DenseBlosk_2  output shape:	 torch.Size([10, 240, 6, 6])
transition_block_2  output shape:	 torch.Size([10, 120, 3, 3])
DenseBlosk_3  output shape:	 torch.Size([10, 248, 3, 3])
BN  output shape:	 torch.Size([10, 248, 3, 3])
relu  output shape:	 torch.Size([10, 248, 3, 3])
global_avg_pool  output shape:	 torch.Size([10, 248, 1, 1])
fc  output shape:	 torch.Size([10, 10])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/373308?site
推荐阅读
相关标签
  

闽ICP备14008679号