当前位置:   article > 正文

【笔记】pytorch 修改预训练模型:实测加载预训练模型与模型随机初始化差别不大_想要更改resnet50的全连接层的输入特征数

想要更改resnet50的全连接层的输入特征数

文章目录

1. pytorch 预训练模型

卷积神经网络的训练是耗时的,很多场合不可能每次都从随机初始化参数开始训练网络。pytorch中自带几种常用的深度学习网络预训练模型,如VGG、ResNet等。往往为了加快学习的进度,在训练的初期我们直接加载pre-train模型中预先训练好的参数,model的加载如下所示:

  import torchvision.models as models

  # Each assignment below is an alternative: keep the one you need.
  # pretrained=True loads ImageNet-pretrained weights (torchvision downloads
  # them on first use). NOTE: in torchvision >= 0.13 `pretrained` is
  # deprecated in favor of `weights=`, e.g.
  # models.resnet50(weights=models.ResNet50_Weights.DEFAULT).

  # ResNet family. models.ResNet itself is the generic class: its constructor
  # requires a block type and per-stage layer counts and has no `pretrained`
  # argument, so pretrained models come from the factory functions.
  model = models.resnet18(pretrained=True)
  model = models.resnet34(pretrained=True)
  model = models.resnet50(pretrained=True)

  # VGG family. Likewise models.VGG(features, ...) is the generic class with
  # no `pretrained` argument; use the factory functions.
  model = models.vgg11(pretrained=True)
  model = models.vgg16(pretrained=True)
  model = models.vgg16_bn(pretrained=True)

2. 修改全连接层类别数目

预训练模型以 resnet50 为例。

  import torch.nn as nn
  import torchvision

  model = torchvision.models.resnet50(pretrained=True)
  # Input size of the existing final fc layer (2048 for resnet50); the
  # replacement layer must accept the same number of features.
  fc_features = model.fc.in_features
  # Replace the last layer with a freshly initialized 10-class classifier;
  # every other layer keeps its pretrained weights.
  model.fc = nn.Linear(fc_features, 10)
  print(model.fc)

或者在构建模型时直接通过 num_classes 参数传入类别个数。注意下面的写法使用了 pretrained=False,即不加载预训练权重,整个网络(而不仅是全连接层)都是随机初始化的:

import torchvision

# num_classes builds the final fc layer with 10 outputs directly. Because
# pretrained=False, no pretrained weights are loaded: every layer, not only
# fc, starts from random initialization.
# Inside an nn.Module's __init__, store it as an attribute instead
# (e.g. self.resnet) so it is registered as a sub-module.
model = torchvision.models.resnet50(pretrained=False, num_classes=10)

3. 修改某一层卷积

预训练模型以 resnet50 为例。

  import torch.nn as nn
  import torchvision

  model = torchvision.models.resnet50(pretrained=True)
  # Replace the first convolution so the network accepts 4-channel input
  # instead of 3. Kernel size, stride, padding and the 64 output channels are
  # unchanged, so the rest of the network still fits. The new layer is
  # randomly initialized: the pretrained conv1 weights are discarded.
  model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)

4. 修改某几层卷积

4.1 去掉后两层(fc层和pooling层)

预训练模型以 resnet50 为例。
nn.Module 提供了 children() 方法,它按定义顺序返回模型的各个直接子模块(即每一层的网络结构)。把它转换成列表后切片,再用 nn.Sequential 重新组装即可。去除最后两层的方法如下:

  import torch.nn as nn
  import torchvision

  # pretrained=False gives random weights; pass pretrained=True to reuse the
  # ImageNet-pretrained parameters.
  resnet_50_s = torchvision.models.resnet50(pretrained=False)
  # children() yields the direct sub-modules in definition order; the last two
  # are the global pooling layer and the fc layer, so [:-2] keeps only the
  # convolutional feature extractor.
  resnet_layer = nn.Sequential(*list(resnet_50_s.children())[:-2])
  # Inside an nn.Module's __init__, assign this to an attribute
  # (e.g. self.resnet) so it is registered as a sub-module.

在去掉预训练resnet模型的后两层(fc层和pooling层)后,新添加一个上采样层、池化层和分类层,构建网络代码如下。注意示例中写的是 pretrained=False,此时并不会加载预训练权重;若要使用预训练参数,需改为 pretrained=True:

  import torch
  import torch.nn as nn
  import torchvision


  class Net_resnet50_upsample(nn.Module):
      """ResNet-50 feature extractor followed by upsampling, pooling and an MLP head.

      A 1x1 convolution first maps single-channel input to the 3 channels the
      ResNet stem expects.

      Args:
          num_classes: Number of output logits (default 10, as before).
          pretrained: Whether to load torchvision's pretrained ResNet-50
              weights into the backbone. Defaults to False (random init).

      NOTE: the head's input size 2048 * 4 * 4 assumes the backbone emits 7x7
      feature maps (presumably 224x224 input): 7 -> 14 after upsampling, then
      AvgPool2d(7, stride=2) yields 4x4. Other input sizes will fail in self.fc.
      """

      def __init__(self, num_classes=10, pretrained=False):
          super().__init__()
          self.conv = nn.Conv2d(1, 3, kernel_size=1)
          resnet_50_s = torchvision.models.resnet50(pretrained=pretrained)
          # Drop the final pooling and fc layers; keep the convolutional stages.
          self.resnet = nn.Sequential(*list(resnet_50_s.children())[:-2])
          self.up7to14 = nn.UpsamplingNearest2d(scale_factor=2)
          self.avgpool = nn.AvgPool2d(7, stride=2)
          self.fc = nn.Sequential(
              nn.Linear(2048 * 4 * 4, 1024),
              nn.ReLU(inplace=True),
              nn.Linear(1024, 128),
              nn.ReLU(inplace=True),
              nn.Linear(128, num_classes),
          )

      def forward(self, x):
          x = self.conv(x)
          x = self.resnet(x)
          x = self.up7to14(x)
          x = self.avgpool(x)
          # Flatten (N, C, H, W) -> (N, C*H*W) for the linear head.
          x = torch.flatten(x, 1)
          return self.fc(x)

4.2 增减多个卷积层

有时需要修改网络中间的层次结构,这时可以用参数覆盖的方法:先自己定义一个结构类似的网络,再把预训练模型中名称(以及形状)相匹配的参数复制到自己的网络中。注意,未修改的层必须沿用原网络中的名称,否则对应的预训练参数无法载入。这里以resnet预训练模型举例。

  1. # coding=UTF-8
  2. import torchvision.models as models
  3. import torch
  4. import torch.nn as nn
  5. import math
  6. import torch.utils.model_zoo as model_zoo
  # Bottleneck: a residual block that uses 1x1 convolutions to reduce and then
  # restore the channel count around a 3x3 convolution (see the PyTorch
  # ResNet implementation on GitHub).
  class Bottleneck(nn.Module):
      """Residual bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a skip path.

      Args:
          inplanes: Number of channels of the incoming tensor.
          planes: Width of the inner convolutions; the block emits
              ``planes * expansion`` channels.
          stride: Stride of the 3x3 convolution.
          downsample: Optional module applied to the input on the skip path so
              it matches the main path's output shape; ``None`` means identity.
      """

      expansion = 4

      def __init__(self, inplanes, planes, stride=1, downsample=None):
          super().__init__()
          out_channels = planes * self.expansion
          # Attribute names become state_dict keys; keep them unchanged so
          # pretrained ResNet weights can be loaded by name.
          self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
          self.bn1 = nn.BatchNorm2d(planes)
          self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                 padding=1, bias=False)
          self.bn2 = nn.BatchNorm2d(planes)
          self.conv3 = nn.Conv2d(planes, out_channels, kernel_size=1, bias=False)
          self.bn3 = nn.BatchNorm2d(out_channels)
          self.relu = nn.ReLU(inplace=True)
          self.downsample = downsample
          self.stride = stride

      def forward(self, x):
          out = self.relu(self.bn1(self.conv1(x)))
          out = self.relu(self.bn2(self.conv2(out)))
          out = self.bn3(self.conv3(out))
          identity = x if self.downsample is None else self.downsample(x)
          out += identity
          return self.relu(out)
  # Layers that are not modified must keep the attribute names used by the
  # original ResNet: pretrained weights are matched to this model by
  # state_dict key, and each key is derived from the attribute name.
  class CNN(nn.Module):
      """ResNet backbone with an extra transposed-conv / max-pool stage and a new classifier.

      The stem and layer1..layer4 mirror ResNet so their pretrained weights can
      be copied in by name; the original ``fc`` layer is replaced by ``fclass``.

      Args:
          block: Residual block class (e.g. ``Bottleneck``). It must define an
              ``expansion`` attribute and accept
              ``(inplanes, planes, stride, downsample)``.
          layers: Number of blocks in layer1..layer4, e.g. ``[3, 4, 6, 3]``.
          num_classes: Number of output logits.
      """

      def __init__(self, block, layers, num_classes=9):
          super(CNN, self).__init__()
          # Running channel count, read and updated by _make_layer.
          self.inplanes = 64
          self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                 bias=False)
          self.bn1 = nn.BatchNorm2d(64)
          self.relu = nn.ReLU(inplace=True)
          self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
          self.layer1 = self._make_layer(block, 64, layers[0])
          self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
          self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
          self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
          self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
          # Channels leaving layer4 (2048 for Bottleneck, whose expansion is 4).
          feat_channels = 512 * block.expansion
          # New layer: transposed convolution (channels and spatial size kept).
          self.convtranspose1 = nn.ConvTranspose2d(
              feat_channels, feat_channels, kernel_size=3, stride=1, padding=1,
              output_padding=0, groups=1, bias=False, dilation=1,
          )
          # New layer: max pooling with stride 1 / padding 1 (spatial size kept).
          self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
          # The original fc layer is dropped; fclass has a new name, so no
          # pretrained fc weights are loaded into it.
          self.fclass = nn.Linear(feat_channels, num_classes)
          # He-style init for convolutions, identity init for batch norm.
          for m in self.modules():
              if isinstance(m, nn.Conv2d):
                  n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                  m.weight.data.normal_(0, math.sqrt(2. / n))
              elif isinstance(m, nn.BatchNorm2d):
                  m.weight.data.fill_(1)
                  m.bias.data.zero_()

      def _make_layer(self, block, planes, blocks, stride=1):
          """Stack ``blocks`` residual blocks; only the first may change stride/width.

          A 1x1 conv + BN projection is attached to the first block's skip path
          whenever its input and output shapes differ.
          """
          downsample = None
          if stride != 1 or self.inplanes != planes * block.expansion:
              downsample = nn.Sequential(
                  nn.Conv2d(self.inplanes, planes * block.expansion,
                            kernel_size=1, stride=stride, bias=False),
                  nn.BatchNorm2d(planes * block.expansion),
              )
          layers = [block(self.inplanes, planes, stride, downsample)]
          self.inplanes = planes * block.expansion
          for _ in range(1, blocks):
              layers.append(block(self.inplanes, planes))
          return nn.Sequential(*layers)

      def forward(self, x):
          x = self.conv1(x)
          x = self.bn1(x)
          x = self.relu(x)
          x = self.maxpool(x)
          x = self.layer1(x)
          x = self.layer2(x)
          x = self.layer3(x)
          x = self.layer4(x)
          x = self.avgpool(x)
          # New layers: ConvTranspose2d and MaxPool2d need a 4-D (N, C, H, W)
          # tensor, so keep it 4-D here (spatial size is 1x1 after the adaptive
          # pool) and flatten only right before the linear classifier.
          x = self.convtranspose1(x)
          x = self.maxpool2(x)
          x = torch.flatten(x, 1)
          x = self.fclass(x)
          return x
  # Load torchvision's ResNet-50 WITH its pretrained weights. With
  # pretrained=False the source weights would themselves be random, so nothing
  # pretrained would actually be transferred into cnn.
  resnet50 = models.resnet50(pretrained=True)
  print(resnet50)
  # [3, 4, 6, 3] = number of Bottleneck blocks in layer1..layer4 (ResNet-50);
  # ResNet-101 uses [3, 4, 23, 3].
  cnn = CNN(Bottleneck, [3, 4, 6, 3])

  pretrained_dict = resnet50.state_dict()
  model_dict = cnn.state_dict()
  # Keep only entries whose name AND shape match a tensor in cnn; renamed or
  # resized layers (e.g. fc -> fclass) keep their own initialization instead
  # of making load_state_dict fail on a size mismatch.
  pretrained_dict = {
      k: v
      for k, v in pretrained_dict.items()
      if k in model_dict and v.shape == model_dict[k].shape
  }
  # Overwrite the matching entries and load the merged state dict.
  model_dict.update(pretrained_dict)
  cnn.load_state_dict(model_dict)
  print(cnn)

结果对比:
(原文此处为结果对比截图,图片未能保留。)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/371340
推荐阅读
相关标签
  

闽ICP备14008679号