
Implementing DenseNet with PyTorch


First, here is the complete code for the network:

```python
import torch
import torch.nn as nn


class conv_block(nn.Module):
    def __init__(self, in_channel, growth_rate):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
            # 1x1 bottleneck convolution: 4*growth_rate output channels
            nn.Conv2d(in_channel, 4 * growth_rate, kernel_size=(1, 1), bias=False),
            # 3x3 convolution: growth_rate new feature channels
            nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=(3, 3), padding=1, bias=False)
        )

    def forward(self, x):
        out = self.conv(x)
        # Concatenate the new features with the input along the channel dimension
        x = torch.cat([x, out], dim=1)
        return x


class transition(nn.Module):
    def __init__(self, in_channel, theta=0.5):
        super(transition, self).__init__()
        self.conv = nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
            # 1x1 convolution compresses the channel count by the factor theta
            nn.Conv2d(in_channel, int(theta * in_channel), kernel_size=(1, 1)),
            # 2x2 average pooling halves the spatial size
            nn.AvgPool2d(2, 2)
        )

    def forward(self, x):
        return self.conv(x)


class densenet(nn.Module):
    def __init__(self, in_channel, classes_num, block_layers, growth_rate=32, theta=0.5):
        super(densenet, self).__init__()
        channels = 64
        self.growth_rate = growth_rate
        # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, channels, kernel_size=(7, 7), stride=2, padding=3),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Four dense blocks, each followed by a transition layer (except the last)
        self.DB1, channels = self._make_dense_block(channels, num=block_layers[0])
        self.TL1 = transition(channels, theta)
        channels = int(channels * theta)
        self.DB2, channels = self._make_dense_block(channels, num=block_layers[1])
        self.TL2 = transition(channels, theta)
        channels = int(channels * theta)
        self.DB3, channels = self._make_dense_block(channels, num=block_layers[2])
        self.TL3 = transition(channels, theta)
        channels = int(channels * theta)
        self.DB4, channels = self._make_dense_block(channels, num=block_layers[3])
        self.global_average_pool = nn.Sequential(
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.fc = nn.Sequential(
            nn.Flatten(1, -1),
            nn.Linear(channels, classes_num)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.DB1(x)
        x = self.TL1(x)
        x = self.DB2(x)
        x = self.TL2(x)
        x = self.DB3(x)
        x = self.TL3(x)
        x = self.DB4(x)
        x = self.global_average_pool(x)
        x = self.fc(x)
        return x

    def _make_dense_block(self, in_channel, num):
        # Stack `num` conv_blocks; each adds growth_rate channels to the running total
        layers = []
        channels = in_channel
        for i in range(num):
            block = conv_block(channels, self.growth_rate)
            channels += self.growth_rate
            layers.append(block)
        return nn.Sequential(*layers), channels
```

Here is the code to create an instance of the network:

```python
net = densenet(in_channel=3, classes_num=10, block_layers=[6, 12, 24, 16],
               growth_rate=32, theta=0.5)
```

This creates a DenseNet-121, intended for a dataset of RGB images (3 input channels) classified into 10 classes.
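
As a quick sanity check (a minimal sketch, assuming the classes above are defined in the same file), a dummy batch can be pushed through the network to confirm the output shape:

```python
import torch

x = torch.randn(4, 3, 224, 224)   # a dummy batch of four 3-channel 224x224 images
out = net(x)
print(out.shape)                   # expected: torch.Size([4, 10]), one score per class
```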

The structure of the network is as follows:

[Figure: DenseNet architecture diagram (green boxes: conv_block; blue boxes: dense blocks; red boxes: transition layers)]

Dense block implementation

The conv_block in the code corresponds to the green boxes in the figure above. It first uses a 1x1 convolution to reduce the number of channels and hence the parameter count; the paper uses 4*k as the number of output channels of this convolution, where k is the number of channels each conv_block outputs, i.e. the growth rate in the paper, a fixed value. This is followed by a 3x3 convolution with k output channels.
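
For example (a minimal sketch, assuming an input with 64 channels and growth rate k = 32), the convolutional part of one conv_block first expands to 4*k = 128 channels through the 1x1 bottleneck and then outputs k = 32 channels through the 3x3 convolution:

```python
import torch

blk = conv_block(in_channel=64, growth_rate=32)
x = torch.randn(1, 64, 56, 56)
h = blk.conv(x)        # BN -> ReLU -> 1x1 conv (128 channels) -> 3x3 conv (32 channels)
print(h.shape)         # torch.Size([1, 32, 56, 56])
```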

The defining feature of DenseNet is feature reuse, which appears in the code as torch.cat([x, out], dim=1) in conv_block's forward: the output of the current block (call it the i-th block, with k channels) is concatenated along the channel dimension (dim=1) with the block's input, whose channel count is the dense block's original input channels in_channel plus (i-1)*k.

Note the difference from ResNet: ResNet adds the output to the input element-wise, whereas here the channels are concatenated.
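
A minimal sketch of the difference (with arbitrary example shapes):

```python
import torch

a = torch.randn(1, 64, 56, 56)    # block input
b = torch.randn(1, 32, 56, 56)    # new features produced by the block

# DenseNet-style: concatenate along the channel dimension; channel counts may differ
dense_out = torch.cat([a, b], dim=1)
print(dense_out.shape)            # torch.Size([1, 96, 56, 56])

# ResNet-style: element-wise addition; both tensors must have identical shapes
res_out = a + torch.randn(1, 64, 56, 56)
print(res_out.shape)              # torch.Size([1, 64, 56, 56])
```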

Repeating conv_block several times gives a dense block, the blue boxes in the figure above, implemented by the _make_dense_block method of the densenet class, as replayed in the sketch below.
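
The channel bookkeeping can be checked by hand (a minimal sketch, assuming the first dense block of DenseNet-121: 64 input channels, 6 conv_blocks, k = 32):

```python
import torch
import torch.nn as nn

channels = 64
layers = []
for i in range(6):                      # same loop as in _make_dense_block
    layers.append(conv_block(channels, growth_rate=32))
    channels += 32
dense_block = nn.Sequential(*layers)

x = torch.randn(1, 64, 56, 56)
print(dense_block(x).shape)             # torch.Size([1, 256, 56, 56]), i.e. 64 + 6*32 channels
print(channels)                         # 256
```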

Transition layer implementation

This is the part framed in red in the figure above, implemented by the transition class in the code.

It first uses a 1x1 convolution as a bottleneck to reduce the number of channels (and hence the parameter count), then a 2x2 average pooling with stride 2 to shrink the feature map.

Here theta is the compression hyperparameter from the paper: the output channel count is theta*in_channel. To actually reduce the parameter count, theta is chosen between 0 and 1; the paper sets it to 0.5.

Note in particular that the channel arguments of nn.Conv2d() must be integers, while theta*in_channel is a float, so it has to be converted with int().
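
For instance (a minimal sketch, assuming the 256-channel output of the first dense block of DenseNet-121), a transition layer with theta = 0.5 compresses the channels to int(0.5*256) = 128 and halves the spatial size:

```python
import torch

tl = transition(in_channel=256, theta=0.5)
x = torch.randn(1, 256, 56, 56)
print(tl(x).shape)   # torch.Size([1, 128, 28, 28]): channels compressed, spatial size halved
```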

Note

Although DenseNet has fewer parameters than networks of comparable depth, it uses more GPU memory; CUDA may sometimes report out-of-memory errors, in which case reduce the batch size appropriately.
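
A minimal sketch (with a hypothetical in-memory dataset and an arbitrary batch size) of where that adjustment is made:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset of 100 RGB images; lower batch_size (e.g. 64 -> 16) if CUDA runs out of memory
dataset = TensorDataset(torch.randn(100, 3, 224, 224), torch.randint(0, 10, (100,)))
loader = DataLoader(dataset, batch_size=16, shuffle=True)
```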

Thanks to @视觉盛宴 for the approach to building this network!
