当前位置:   article > 正文

【超详细小白必懂】Pytorch 直接加载ResNet50模型和参数实现迁移学习_torchvision.models.resnet50

torchvision.models.resnet50

Torchvision.models包里面包含了常见的各种基础模型架构,主要包括以下几种:(我们以ResNet50模型作为此次演示的例子)

AlexNet
VGG
ResNet
SqueezeNet
DenseNet
Inception v3
GoogLeNet
ShuffleNet v2
MobileNet v2
ResNeXt
Wide ResNet
MNASNet

首先加载ResNet50模型,如果如果需要加载模型本身的参数,需要使用pretrained=True,代码如下

  1. import torchvision
  2. from torchvision import models
  3. resnet50 = models.resnet50(pretrained=True) #pretrained=True 加载模型以及训练过的参数
  4. print(resnet50) # 打印输出观察一下resnet50到底是怎么样的结构

打印输出后ResNet50部分结构如下图,其中红框的全连接层是需要关注的点。全连接层中,“resnet50” 的out_features=1000,这也就是说可以进行class=1000的分类。

 由于我们正常所使用的分类场景大概率与resnet50的分类数不一样,所以在调用时,要使用out_features=分类数进行调整。假设我们采用CIFAR10数据集(10 class)进行测试,那么我们就需要修改全连接层,out_features=10。具体代码如下:

  1. resnet50 = models.resnet50(pretrained=True)
  2. num_ftrs = resnet50.fc.in_features
  3. for param in resnet50.parameters():
  4. param.requires_grad = False #False:冻结模型的参数,也就是采用该模型已经训练好的原始参数。只需要训练我们自己定义的Linear层
  5. #保持in_features不变,修改out_features=10
  6. resnet50.fc = nn.Sequential(nn.Linear(num_ftrs,10),
  7. nn.LogSoftmax(dim=1))

一个简单完整的 CIFAR10+ResNet50 训练代码如下:

  1. import torch
  2. import torchvision
  3. from torch import nn
  4. from torch.utils.data import DataLoader
  5. from torchvision import models
  6. #下载CIFAR10数据集
  7. train_data = torchvision.datasets.CIFAR10(root="../data",train=True,transform=torchvision.transforms.ToTensor(),
  8. download=False)
  9. test_data = torchvision.datasets.CIFAR10(root="../data",train=False,transform=torchvision.transforms.ToTensor(),
  10. download=False)
  11. train_data_size = len(train_data)
  12. test_data_size = len(test_data)
  13. print("The size of Train_data is {}".format(train_data_size))
  14. print("The size of Test_data is {}".format(test_data_size))
  15. #dataloder进行数据集的加载
  16. train_dataloader = DataLoader(train_data,batch_size=128)
  17. test_dataloader = DataLoader(test_data,batch_size=128)
  18. resnet50 = models.resnet50(pretrained=True)
  19. num_ftrs = resnet50.fc.in_features
  20. for param in resnet50.parameters():
  21. param.requires_grad = False #False:冻结模型的参数,
  22. # 也就是采用该模型已经训练好的原始参数。
  23. #只需要训练我们自己定义的Linear层
  24. resnet50.fc = nn.Sequential(nn.Linear(num_ftrs,10),
  25. nn.LogSoftmax(dim=1))
  26. # 网络模型cuda
  27. if torch.cuda.is_available():
  28. resnet50 = resnet50.cuda()
  29. #loss
  30. loss_fn = nn.CrossEntropyLoss()
  31. if torch.cuda.is_available():
  32. loss_fn = loss_fn.cuda()
  33. #optimizer
  34. learning_rate = 0.01
  35. optimizer = torch.optim.SGD(resnet50.parameters(),lr=learning_rate,)
  36. #设置网络训练的一些参数
  37. #记录训练的次数
  38. total_train_step = 0
  39. #记录测试的次数
  40. total_test_step = 0
  41. #训练的轮数
  42. epoch = 10
  43. for i in range(epoch):
  44. print("-------第{}轮训练开始-------".format(i+1))
  45. resnet50.train()
  46. #训练步骤开始
  47. for data in train_dataloader:
  48. imgs, targets = data
  49. if torch.cuda.is_available():
  50. # 图像cuda;标签cuda
  51. # 训练集和测试集都要有
  52. imgs = imgs.cuda()
  53. targets = targets.cuda()
  54. outputs = resnet50(imgs)
  55. loss = loss_fn(outputs, targets)
  56. # 优化器优化模型
  57. optimizer.zero_grad()
  58. loss.backward()
  59. optimizer.step()
  60. total_train_step = total_train_step + 1
  61. if total_train_step % 100 == 0:
  62. print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))
  63. #writer.add_scalar("train_loss", loss.item(), total_train_step)
  64. #测试集
  65. total_test_loss = 0
  66. with torch.no_grad():
  67. for data in test_dataloader:
  68. imgs, targets = data
  69. if torch.cuda.is_available():
  70. # 图像cuda;标签cuda
  71. # 训练集和测试集都要有
  72. imgs = imgs.cuda()
  73. targets = targets.cuda()
  74. outputs = resnet50(imgs)
  75. loss = loss_fn(outputs,targets)
  76. total_test_loss += loss.item()
  77. total_test_step += 1
  78. if total_test_step % 100 ==0:
  79. print("测试次数:{},Loss:{}".format(total_test_step,total_test_loss))

完美!!!!!

剩下的大家可以举一反三,继续探索。。。。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/217587
推荐阅读
相关标签
  

闽ICP备14008679号