赞
踩
在一些常见的如检测、分类等计算机视觉任务中,基于深度学习的方法取得了很好的结果,其中一些经典模型也往往成为相关任务及比赛的baseline。在pytorch的视觉库torchvision中,提供了models模块供我们直接调用这些经典网络,如VGG,Resnet等。使用中往往不能直接使用现成的模型,需要进行一些修改。实际上我们可以很方便的在pytorch中使用并修改模型。
直接通过models调用即可,如
from torchvision import models
res101 = models.resnet101(pretrained=True)
vgg19 = models.vgg19_bn(pretrained=True)
如下如所示,models模块的__init__.py 包含了一系列不同的网络结构
以及网络模型的不同层数的结构,如resnet50, resnet101, vgg16, vgg19等
我们只需查阅手册或源码寻找是否有这个网络模型,有的话直接拿来用即可。参数 pre_trained为True时表示模型参数是在ImageNet预训练过的,否则就是随机初始化的参数。
在首次使用时,pytorch会自动下载模型文件,保存在用户cache目录内
在参加一些图像检测、分类、分割比赛时,或者一些不需要大幅修改网络结构的场景,可以直接采用pytorch自带的网络结构,无需自行搭建。
在分类问题上,模型的最后一层一般是一个全连接层,输出的神经元个数就是类别信息,最后输出结果是一个浮点向量,大小表示某一类别的可能性,数值越大说明越倾向于分为该类。
显然直接使用预训练的网络不加修改那么总类别数就是固定的,当我们使用的场景类别数不一致时,就要自行修改模型的最后一层。那么如何进行替换和修改呢?
我们知道,在自定义网络结构时,通常是:
class myModel(Module):
def __init__(self):
# 模型结构
self.conv1 = xxxx
self.fc1 = xxxxx
self.m = nn.Sequential(a,b,c...)
def forward(self,x):
# 前向传播
这样的形式。换言之,模型的每一层都记录在了这个模型类的实例的成员变量里。因此只要我们知道要修改的那一层叫什么名字,就能够进行修改。
例如 对resnet,最后一层全连接层就叫fc,所以我们可以:
res101 = models.resnet101(pretrained=True)
numFit = res101.fc.in_features
res101.fc = nn.Linear(numFit, numClass)
res101.fc就是这个网络的最后一层全连接层,in_feature是输出神经元数量,我们将它修改为输入神经元不变(也不能变,不然就出错了)输出神经元为我们需要的类别数的全连接网络。
有时候分类任务不光要输出类别,也要输出置信度,通常置信度就是分类为这个类别的概率,既然是概率,就要满足
0
≤
P
i
≤
1
,
Σ
i
=
1
N
P
i
=
1
0\le P_{i}\le 1, \Sigma_{i=1}^{N}P_{i} = 1
0≤Pi≤1,Σi=1NPi=1
由于全连接网络直接输出的结果往往不能称之为“置信度”(只有大小之分,不满足0-1之间,和也不是1),通常会在后面加一层softmax作为激活函数,这样输出结果就是一个概率值了:
res101.fc = nn.Sequential(nn.Linear(numFit, numClass), nn.Softmax(dim=1))
以上是最简单的模型某一层就是一个单独的成员变量的情况。
那如果模型把好多东西塞进了一个Sequential怎么办呢?
例如vgg:
我们在torchvision/moduls/vgg.py 中找到VGG类的定义:
显然分类相关的3层全连接、激活函数、dropout都在一个Sequential类、名字叫做classifier的成员变量里,这种情况,我们需要把整个classifier都复写吗?
答案是不需要的。
我们知道Sequential同样继承自nn.Module类,这个类有一个成员变量叫做_modules
这是一个有序字典,存放了模块名称 - 模块内容 的键值对。
每次新添加一层,都会做一次
self._modules[name] = module
这个操作。
这个name这里很有意思,一般我们很少给每一层网络都起一个名字,那默认的名字实际上是该模块索引的字符串形式。比如上述的vgg的classifier,它的第一层全连接,名字叫做’1’,最后一层name是’6’。这个名字部分以后有时间专门讨论一下。
回到这个问题,Sequential继承自nn.Module,自然也有这个字典。
所以对于vgg,我们可以:
vgg19 = models.vgg19_bn(pretrained=True)
vgg19.classifier._modules['6'] = nn.Sequential(nn.Linear(4096, numClass), nn.Softmax(dim=1))
就可以将最后一层全连接层替换掉了。
中间其他层也可以用类似的方式替换。
总结来说,pytorch提供的网络模型还是比较实用的,对于不需要大幅修改的网络结构只要直接调用再局部修改就可以,满足一了一些简单的深度学习需求场景,可以不需要自己重新写一遍网路结构了。
在修改方面,基于pytorch模型的定义方式,我们只要知道其模型结构,这一点可以直接查找pytorch这部分的源码,了解到成员变量的名字,如果是Sequential,可以再通过_modules这个字典查找,都将能够较容易的找到被修改的那一层,直接替换成我们需要的结构即可。当然,替换后与训练的权重就不见了,取而代之的是随机初始化权重。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。