当前位置:   article > 正文

【深度之眼】【Pytorch打卡第16天】:模型微调Finetune(迁移学习)_模型微调训多久

模型微调训多久

任务

任务简介

了解transfer learning 与 model finetune

详细介绍

学习模型微调(Finetune)的方法,以及认识Transfer Learning(迁移学习)与Model Finetune之间的关系。


知识点

Transfer Learning & Model Finetune

迁移学习:机器学习分支,研究源域(source domain)的知识如何应用到目标域(target domain)。

模型微调:所谓的模型微调,其实就是模型的迁移学习,在深度学习中,通过不断的迭代,更新卷基层中的权值,这里的权值可以称之为 knowledge , 然后我们可以将这些 knowledge 进行迁移,主要目的是将这些 knowledge 运用到新的模型中,这样既可以减小由于数据量不足导致的过拟合现象,同时又能加快模型的训练速度

具体说来,对于卷积神经网络,我们可以把前面的卷基层,池化层看作是 feature extactor(特征提取) ,是一个非常有共性的部分。得到一系列的feature map。
后面的全连接层,可以称之为 classifier (分类器), 与具体的任务有关。这一部分就需要针对不同的训练任务进行调整,尤其是最后一层需要根据任务进行相应的调整。

PyTorch中的Finetune

模型微调步骤

  1. 获取预训练模型参数—原任务中获取得到的知识
  2. 加载模型(load_state_dict)
  3. 修改输出层

模型微调训练方法

  1. 固定预训练的参数(requires_grad =False;lr=0)
  2. Features Extractor较小学习率(params_group)

实战-Resnet-18 用于二分类

Resnet-18 模型介绍

数据下载
模型下载

蚂蚁蜜蜂二分类数据
训练集:各120~张
验证集:各70~张
在这里插入图片描述

Resnet-18模型结构如下图所示:

前面四层是特征提取,接下来四层(layer1~layer4)是残差网络,然后接avgpool池化层,最后接FC分类(原模型是1000分类,ImageNet上训练的)。

迁移结果分析

(1) 直接训练
如果不采用Resnet-18模型进行Finetune,直接对二分类数据进行训练,得到的Loss曲线

损失值一直在0.6附近,并且得到的Accuracy只有70%

(2) 迁移训练,但不冻结卷积层,固定学习率

# ============================ step 2/5 模型 ============================
# 1/3 构建模型
resnet18_ft = models.resnet18()

# 2/3 加载参数
# flag = 0
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.load_state_dict(state_dict_load)

# 3/3 替换fc层
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)

resnet18_ft.to(device)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

可以看出,损失值最后收敛到在0.2附近,并且在第二个Epoch的Accuracy就达到了90%。

(3)迁移训练,冻结卷基层,固定学习率

从代码看,所谓冻结卷积层,是直接把参数的梯度设置为False

for param in resnet18_ft.parameters():
    param.requires_grad = False
  • 1
  • 2
# ============================ step 2/5 模型 ============================

# 1/3 构建模型
resnet18_ft = models.resnet18()

# 2/3 加载参数
# flag = 0
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/649571
推荐阅读
相关标签
  

闽ICP备14008679号