赞
踩
使用示例代码:
import torchvision
from torch import nn
# 加载网络
# 这一句话(当pretrained设置为False时)就相当与把网络架构在这里替换了一下,网络模型的参数都是初始化的,是默认的一些参数
vgg16_false = torchvision.models.vgg16(pretrained=False)
# 这一句话(当pretrained设置为True时)网络模型的参数都是在ImageNet数据集上训练好的,就是在ImageNet数据集上能够达到一个比较好的效果
vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16的使用有两个常用参数,分别是pretrained
和process
。
示例代码如下:
import torchvision from torch import nn # 加载网络 vgg16_false = torchvision.models.vgg16(pretrained=False) print("vgg16_false:\n",vgg16_false) vgg16_true = torchvision.models.vgg16(pretrained=True) print("vgg16_true:\n",vgg16_true) # 如何利用现有的网络去改动他的一个结构 # 1.添加网络层 # 加载CIFAR10数据集 train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor()) # 将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。 # 方式1:在整个网络中直接添加 # vgg16_true.add_module("add_linear",nn.Linear(1000,10)) # 方式2:在相应的模块中添加 vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10)) print("vgg16_true:\n",vgg16_true)
运行结果:
讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。
示例代码如下:
import torchvision from torch import nn # 加载网络模型 vgg16_false = torchvision.models.vgg16(pretrained=False) print("vgg16_false:\n",vgg16_false) vgg16_true = torchvision.models.vgg16(pretrained=True) print("vgg16_true:\n",vgg16_true) # 如何利用现有的网络去改动他的一个结构 # 2.直接修改网络 # 加载CIFAR10数据集 train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor()) # 将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。 # 按顺序对网络进行索引,修改最后的线性层 vgg16_false.classifier[6] = nn.Linear(4096,10) print("vgg16_false",vgg16_false)
运行结果:
讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。