当前位置:   article > 正文

resnet18[2,2,2,2],resnet34[3,4,6,3],resnet50[3,4,6,3],resnet101[3,4,23,3],resnet152[3,8,36,3]的含义_二值的resnet18

二值的resnet18

        最近在调基于resnet框架提取图像特征的代码时,由于初次接触resnet的代码,对里面的函数调用模块中的[2,2,2,2],[3,4,6,3]的理解比较模糊,比如以下代码:

  def resnet18(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of BasicBlock units with stage depths [2, 2, 2, 2].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet18-5c106cde.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet34(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of BasicBlock units with stage depths [3, 4, 6, 3].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet34-333f7ec4.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet50(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of Bottleneck units with stage depths [3, 4, 6, 3].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet50-19c8e357.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet101(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of Bottleneck units with stage depths [3, 4, 23, 3].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet101-5d3b4d8f.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet152(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of Bottleneck units with stage depths [3, 8, 36, 3].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet152-b121ed2d.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net

         先不管上面,我们先看这张关于resnet系列的结构说明图片:

         看图片中红框标记的区域,有没有豁然开朗:数组中的四个数字依次对应 conv2_x、conv3_x、conv4_x、conv5_x 四个阶段中残差块重复堆叠的次数。resnet18 的 [2,2,2,2] 表示这四个阶段的残差块各重复 2 次;resnet34 的 [3,4,6,3] 同理,表示四个阶段的残差块分别重复 3、4、6、3 次。

         如果能把以上数组里数字的代表意义搞清楚,那么结合关于resnet系列结构的图片,仔细阅读接下来的代码逻辑,会更进一步加深你对resnet的理解:

  def conv3x3(in_planes, out_planes, stride=1):
      """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

      Args:
          in_planes: number of input channels.
          out_planes: number of output channels.
          stride: convolution stride (defaults to 1).
      """
      return nn.Conv2d(
          in_channels=in_planes,
          out_channels=out_planes,
          kernel_size=3,
          stride=stride,
          padding=1,
          bias=False,
      )
  # Residual unit used by the resnet18 and resnet34 factory functions.
  class BasicBlock(nn.Module):
      """Two 3x3 conv + batch-norm stages with a shortcut added before the last ReLU."""

      # Output channels = planes * expansion; read by ResNet when sizing layers.
      expansion = 1

      def __init__(self, inplanes, planes, stride=1, downsample=None):
          """Create the block.

          Args:
              inplanes: channels of the incoming tensor.
              planes: channels produced by both convolutions.
              stride: stride of the first convolution.
              downsample: optional module applied to the input to form the
                  shortcut; when None the input itself is the shortcut.
          """
          super(BasicBlock, self).__init__()
          # The OrderedDict names every sub-module and fixes the order in which
          # nn.Sequential executes them.
          self.group1 = nn.Sequential(OrderedDict([
              ('conv1', conv3x3(inplanes, planes, stride)),
              ('bn1', nn.BatchNorm2d(planes)),
              # inplace=True overwrites the input tensor instead of allocating.
              ('relu1', nn.ReLU(inplace=True)),
              ('conv2', conv3x3(planes, planes)),
              ('bn2', nn.BatchNorm2d(planes)),
          ]))
          self.relu = nn.Sequential(nn.ReLU(inplace=True))
          self.downsample = downsample

      def forward(self, x):
          """Return relu(group1(x) + shortcut(x))."""
          shortcut = x if self.downsample is None else self.downsample(x)
          return self.relu(self.group1(x) + shortcut)
  # Residual unit used by the resnet50, resnet101 and resnet152 factory functions.
  class Bottleneck(nn.Module):
      """1x1 -> 3x3 -> 1x1 convolution stack with a shortcut added before the last ReLU."""

      # Output channels = planes * expansion; read by ResNet when sizing layers.
      expansion = 4

      def __init__(self, inplanes, planes, stride=1, downsample=None):
          """Create the block.

          Args:
              inplanes: channels of the incoming tensor.
              planes: channels of the two inner convolutions; the block
                  outputs ``planes * 4`` channels.
              stride: stride of the 3x3 convolution.
              downsample: optional module applied to the input to form the
                  shortcut; when None the input itself is the shortcut.
          """
          super(Bottleneck, self).__init__()
          widened = planes * 4  # same factor as the class-level ``expansion``
          self.group1 = nn.Sequential(OrderedDict([
              ('conv1', nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)),
              ('bn1', nn.BatchNorm2d(planes)),
              # ReLU zeroes negative values and passes the rest through.
              ('relu1', nn.ReLU(inplace=True)),
              ('conv2', nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                  padding=1, bias=False)),
              ('bn2', nn.BatchNorm2d(planes)),
              ('relu2', nn.ReLU(inplace=True)),
              ('conv3', nn.Conv2d(planes, widened, kernel_size=1, bias=False)),
              ('bn3', nn.BatchNorm2d(widened)),
          ]))
          self.relu = nn.Sequential(nn.ReLU(inplace=True))
          self.downsample = downsample

      def forward(self, x):
          """Return relu(group1(x) + shortcut(x))."""
          shortcut = x if self.downsample is None else self.downsample(x)
          return self.relu(self.group1(x) + shortcut)
  class ResNet(nn.Module):
      """ResNet backbone: a conv/bn/relu/maxpool stem, four residual stages,
      global average pooling and one fully connected classifier.

      Args:
          block: residual block *class*. It must expose an integer
              ``expansion`` attribute and accept
              ``(inplanes, planes, stride=1, downsample=None)``.
          layers: sequence of four ints, the number of blocks stacked in
              layer1..layer4.
          num_classes: size of the classifier output.

      The pooling layer is ``nn.AdaptiveAvgPool2d((1, 1))`` rather than a fixed
      ``nn.AvgPool2d(7)``. When layer4 produces a 7x7 map (e.g. 224x224 input)
      the result is the same, but other input resolutions now also reduce to
      a 1x1 map instead of failing at the classifier. Pooling has no
      parameters, so state-dict keys are unaffected.
      """

      def __init__(self, block, layers, num_classes=1000):
          super(ResNet, self).__init__()
          # Running channel count of the tensor entering the next stage;
          # updated by _make_layer.
          self.inplanes = 64

          stem = OrderedDict()
          stem['conv1'] = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
          stem['bn1'] = nn.BatchNorm2d(64)
          stem['relu1'] = nn.ReLU(inplace=True)
          stem['maxpool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
          self.group1 = nn.Sequential(stem)

          # ``block`` is a class, instantiated inside _make_layer.
          self.layer1 = self._make_layer(block, 64, layers[0])
          self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
          self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
          self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

          # Collapse whatever spatial size layer4 produced down to 1x1.
          self.avgpool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)))
          self.group2 = nn.Sequential(
              OrderedDict([
                  # nn.Linear(in_features, out_features)
                  ('fc', nn.Linear(512 * block.expansion, num_classes)),
              ])
          )

          # Initialise every conv and batch-norm module; self.modules() walks
          # all sub-modules recursively.
          for module in self.modules():
              if isinstance(module, nn.Conv2d):
                  fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                  module.weight.data.normal_(0, math.sqrt(2. / fan_out))
              elif isinstance(module, nn.BatchNorm2d):
                  module.weight.data.fill_(1)
                  module.bias.data.zero_()

      def _make_layer(self, block, planes, blocks, stride=1):
          """Stack ``blocks`` instances of ``block`` into one nn.Sequential stage.

          Only the first block receives ``stride`` and, when needed, a
          projection shortcut. The remaining blocks use the defaults (stride 1,
          identity shortcut) because their input already has
          ``planes * block.expansion`` channels.
          """
          downsample = None
          # A 1x1 conv + batch-norm projection is needed on the shortcut only
          # when the first block changes the spatial size (stride != 1) or the
          # channel count; otherwise the shortcut stays an identity.
          if stride != 1 or self.inplanes != planes * block.expansion:
              downsample = nn.Sequential(
                  nn.Conv2d(self.inplanes, planes * block.expansion,
                            kernel_size=1, stride=stride, bias=False),
                  nn.BatchNorm2d(planes * block.expansion),
              )

          stage = [block(self.inplanes, planes, stride, downsample)]
          self.inplanes = planes * block.expansion
          for _ in range(1, blocks):
              stage.append(block(self.inplanes, planes))
          # Unpack the list so each block becomes a positional child module.
          return nn.Sequential(*stage)

      def forward(self, x):
          """Run the stem, the four stages, pooling and the classifier."""
          x = self.group1(x)
          x = self.layer1(x)
          x = self.layer2(x)
          x = self.layer3(x)
          x = self.layer4(x)
          x = self.avgpool(x)
          # Flatten the 4-D pooled tensor to 2-D before the linear layer.
          x = x.view(x.size(0), -1)
          x = self.group2(x)
          return x
  def resnet18(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of BasicBlock units with stage depths [2, 2, 2, 2].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet18-5c106cde.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet34(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of BasicBlock units with stage depths [3, 4, 6, 3].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet34-333f7ec4.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet50(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of Bottleneck units with stage depths [3, 4, 6, 3].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet50-19c8e357.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet101(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of Bottleneck units with stage depths [3, 4, 23, 3].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet101-5d3b4d8f.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet152(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of Bottleneck units with stage depths [3, 8, 36, 3].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet152-b121ed2d.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet18(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of BasicBlock units with stage depths [2, 2, 2, 2].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet18-5c106cde.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet34(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of BasicBlock units with stage depths [3, 4, 6, 3].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet34-333f7ec4.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet50(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of Bottleneck units with stage depths [3, 4, 6, 3].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet50-19c8e357.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet101(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of Bottleneck units with stage depths [3, 4, 23, 3].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet101-5d3b4d8f.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
  def resnet152(pretrained=False, model_root=None, **kwargs):
      """Build a ResNet made of Bottleneck units with stage depths [3, 8, 36, 3].

      The four numbers are the block counts handed to ``ResNet`` for
      layer1..layer4. Extra keyword arguments are forwarded unchanged to the
      ``ResNet`` constructor.

      When ``pretrained`` is truthy, a checkpoint is read with ``torch.load``
      from a fixed relative path and copied into the model with
      ``load_state_dict``.

      NOTE(review): ``model_root`` is reassigned below, so any value supplied
      by the caller is currently ignored.
      """
      net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
      if not pretrained:
          return net
      model_root = 'model/resnet152-b121ed2d.pth'
      # Read the saved parameters from disk, then copy them into the network.
      net.load_state_dict(torch.load(model_root))
      return net
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/351057
推荐阅读
相关标签
  

闽ICP备14008679号