当前位置:   article > 正文

先有网络模型的使用及修改_修改网络模型的方法 csdn博客

修改网络模型的方法 csdn博客

先有网络模型的使用

使用示例代码:

import torchvision
from torch import nn

# 加载网络

# 这一句话(当pretrained设置为False时)就相当与把网络架构在这里替换了一下,网络模型的参数都是初始化的,是默认的一些参数
vgg16_false = torchvision.models.vgg16(pretrained=False)

# 这一句话(当pretrained设置为True时)网络模型的参数都是在ImageNet数据集上训练好的,就是在ImageNet数据集上能够达到一个比较好的效果
vgg16_true = torchvision.models.vgg16(pretrained=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

vgg16的使用有两个常用参数,分别是pretrainedprocess

  • pretrained - 为True的话,说明这个网络是已经训练好的在训练数据集上有比较好的效果 若为False则说明这个网络是没训练的
  • process - 为True则显示下载神经网络参数的进度条若为False则不显示下载神经网络参数的进度条
    通俗来理解pretrained,就相当于什么呢?比如搭建神经网络卷积层时,你给了一个kernel_size但是并没有kernel_size中的参数,pretrained=True时相当于你得到了一个带参数的卷积核,pretrained=False时相当于你只知道这个卷积核的大小。

先有网络模型的修改(如何利用现有的网络去改动它的一个结构)

1.添加网络层

示例代码如下:

import torchvision
from torch import nn

# 加载网络
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)

vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)

# 如何利用现有的网络去改动他的一个结构

# 1.添加网络层

# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())

# 将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。

# 方式1:在整个网络中直接添加
# vgg16_true.add_module("add_linear",nn.Linear(1000,10))

# 方式2:在相应的模块中添加
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))

print("vgg16_true:\n",vgg16_true)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

运行结果:
在这里插入图片描述
在这里插入图片描述

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。

2.直接修改网络

示例代码如下:

import torchvision
from torch import nn

# 加载网络模型
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)

vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)

# 如何利用现有的网络去改动他的一个结构
# 2.直接修改网络

# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())

# 将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。

# 按顺序对网络进行索引,修改最后的线性层 
vgg16_false.classifier[6] = nn.Linear(4096,10)
print("vgg16_false",vgg16_false)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

运行结果:
在这里插入图片描述
在这里插入图片描述

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/217593
推荐阅读
相关标签
  

闽ICP备14008679号