
Modifying pretrained models in PyTorch: does the pretrained model also need to change after modifying an SE-ResNet50 model?


Original link: https://blog.csdn.net/whut_ldz/article/details/78845947

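I. Pretrained models in PyTorch

Training a convolutional neural network is time-consuming, and in many situations it is impractical to start from randomly initialized parameters every time.
PyTorch (through torchvision) ships pretrained versions of several common deep learning networks, such as VGG and ResNet. To speed up learning, we often load these pretrained parameters directly at the start of training. The models are loaded as follows: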


  
  
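    import torchvision.models as models

    # Note: torchvision 0.13 and later deprecate pretrained=True in favour of
    # the weights= argument; the older form is kept here as in the original post.

    # ResNet variants. models.ResNet is the base class, not a factory function,
    # and does not accept pretrained=True, so only the factory functions are used.
    model = models.resnet18(pretrained=True)
    model = models.resnet34(pretrained=True)
    model = models.resnet50(pretrained=True)

    # VGG variants (models.VGG is likewise the base class)
    model = models.vgg11(pretrained=True)
    model = models.vgg16(pretrained=True)
    model = models.vgg16_bn(pretrained=True)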

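II. Modifying a pretrained model

1. Modifying parameters
For simple parameter changes, take the pretrained ResNet model as an example (the ResNet source code is available on GitHub).
The last layer of the ResNet network, the classification layer fc, classifies inputs into 1000 classes. If your own dataset has only 9 classes, modify it as follows: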

  
  
    import torch.nn as nn
    import torchvision.models as models

    # Load the pretrained model
    model = models.resnet50(pretrained=True)
    # Read the number of input features of the existing fc layer
    fc_features = model.fc.in_features
    # Replace fc with a new layer that outputs 9 classes
    model.fc = nn.Linear(fc_features, 9)

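2. Adding or removing convolutional layers
The previous method only works for simple parameter changes. Sometimes we need to change the layer structure of the network itself, and then the only option is parameter overwriting: first define a similar network of your own, then copy the parameters of the pretrained model into it. ResNet is again used as the example.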

  
  
    import math

    import torch.nn as nn
    import torchvision.models as models
    from torchvision.models.resnet import Bottleneck


    class CNN(nn.Module):
        def __init__(self, block, layers, num_classes=9):
            super(CNN, self).__init__()
            self.inplanes = 64
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                   bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.layer1 = self._make_layer(block, 64, layers[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
            self.avgpool = nn.AvgPool2d(7, stride=1)
            # New transposed convolution layer
            self.convtranspose1 = nn.ConvTranspose2d(2048, 2048, kernel_size=3, stride=1, padding=1, output_padding=0, groups=1, bias=False, dilation=1)
            # New max pooling layer
            self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
            # Drop the original fc layer and add a new fclass layer
            self.fclass = nn.Linear(2048, num_classes)

            # Initialise weights; pretrained values overwrite the matching ones later
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

        def _make_layer(self, block, planes, blocks, stride=1):
            downsample = None
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )
            layers = []
            layers.append(block(self.inplanes, planes, stride, downsample))
            self.inplanes = planes * block.expansion
            for i in range(1, blocks):
                layers.append(block(self.inplanes, planes))
            return nn.Sequential(*layers)

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.avgpool(x)
            # Forward pass through the new layers. They expect a 4D tensor, so
            # the tensor is flattened only once, right before fclass.
            x = self.convtranspose1(x)
            x = self.maxpool2(x)
            # Assumes 224x224 inputs, so the flattened size is 2048
            x = x.view(x.size(0), -1)
            x = self.fclass(x)
            return x


    # Load the models
    resnet50 = models.resnet50(pretrained=True)
    cnn = CNN(Bottleneck, [3, 4, 6, 3])
    # Read the parameters
    pretrained_dict = resnet50.state_dict()
    model_dict = cnn.state_dict()
    # Drop the keys in pretrained_dict that do not exist in model_dict
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # Update the existing model_dict
    model_dict.update(pretrained_dict)
    # Load the state_dict we actually need
    cnn.load_state_dict(model_dict)
    # print(resnet50)
    print(cnn)
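The parameter-overwriting step is easier to follow on a tiny example. The sketch below is not the author's code: `Source` and `Target` are hypothetical toy networks standing in for the pretrained ResNet and the custom CNN. It also adds a shape check to the key filter, an extra safeguard that the original listing does not have.

```python
import torch
import torch.nn as nn


class Source(nn.Module):
    """Hypothetical stand-in for the pretrained network."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.fc = nn.Linear(8, 1000)


class Target(nn.Module):
    """Hypothetical stand-in for the modified network: fc is replaced by fclass."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.fclass = nn.Linear(8, 9)


source, target = Source(), Target()
pretrained_dict = source.state_dict()
model_dict = target.state_dict()

# Keep only entries whose key exists in the target and whose shape matches
# (the shape check is an addition for safety, not in the original listing).
filtered = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}

# Overwrite the matching entries, then load the merged dictionary.
model_dict.update(filtered)
target.load_state_dict(model_dict)

copied_keys = sorted(filtered)
print(copied_keys)
```

Layers whose names differ between the two networks (here `fc` and `fclass`) are skipped and keep their freshly initialised weights, which is why the new layers in the modified network still have to be trained.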

That covers the topic. I am a beginner who has only just started out, so please go easy on me.



