赞
踩
一个weight layer由一个卷积层一个bn层组成。
当ch_in与,ch_out不等时,通过代码使得[b,ch_in,h,w] -> [b,ch_out,h,w],把,ch_in变成,ch_out。
forward中x与out不等时,在x前加一个extra()。
我们4个block中h和w是变化的,只是在此处表达的时候没有变。
我们进行一个小测试
blk=ResBlk(64,128)
tmp=torch.randn(2,64,32,32)
out=blk(tmp)
print(out.shape)
我们的channel越来越大,我们的长和宽保持不变,最终导致我们的参数量越来越大。
我们需要长和宽减半,我们需要在参数部分添加stride,stride为1时,输入和输出非常接近,当为2时,有可能输出为输入的一半。
blk=ResBlk(64,128,stride=2)
tmp=torch.randn(2,64,32,32)
out=blk(tmp)
print(out.shape)
blk=ResBlk(64,128,stride=4)
tmp=torch.randn(2,64,32,32)
out=blk(tmp)
print(out.shape)
如果是match,就不会报错。
进行人为的调试:
print('after conv:', x.shape)
x=self.outlay(x)
修改参数:
self.conv1=nn.Sequential( nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0), nn.BatchNorm2d(64) ) # followed 4 blocks #[b,64,h,w]->[b,128,h,w] self.blk1=ResBlk(64,128,stride=2) # [b,128,h,w]->[b,2556,h,w] self.blk2=ResBlk(128,256,stride=2) # [b,256,h,w]->[b,512,h,w] self.blk3=ResBlk(256,512,stride=2) # [b,512,h,w]->[b,1024,h,w] self.blk4=ResBlk(512,512,stride=2) self.outlay=nn.Linear(512*1*1,10)
整体是先对数据做一个预处理,然后进行4个block,每一个block都由2个卷积和一个短接层组成,处理过程中数据的channel会慢慢增加,但是长和宽会减少,得到(512,512),再把这个(512)打平后送入全连接层,做一个分类的任务。这就是ResNet的一个基本结构。
import torch from torch import nn from torch.nn import functional as F class ResBlk(nn.Module): ''' resnet block ''' def __init__(self,ch_in,ch_out,stride=1): ''' :param ch_in: :param ch_out: ''' super(ResBlk, self).__init__() self.con1=nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1) self.bn1=nn.BatchNorm2d(ch_out) self.con2=nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1) self.bn2=nn.BatchNorm2d(ch_out) self.extra=nn.Sequential() if ch_out != ch_in: self.extra=nn.Sequential( nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride), nn.BatchNorm2d(ch_out) ) def forward(self,x): ''' :param x:[b,ch,h,w] :return: ''' out=F.relu(self.bn1(self.con1(x))) out=self.bn2(self.con2(out)) # short cut # extra model:[b,ch_in,h,w] with [b,ch_out,h,w] out=self.extra(x)+out return out class ResNet18(nn.Module): def __init__(self): super(ResNet18, self).__init__() self.conv1=nn.Sequential( nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0), nn.BatchNorm2d(64) ) # followed 4 blocks #[b,64,h,w]->[b,128,h,w] self.blk1=ResBlk(64,128,stride=2) # [b,128,h,w]->[b,2556,h,w] self.blk2=ResBlk(128,256,stride=2) # [b,256,h,w]->[b,512,h,w] self.blk3=ResBlk(256,512,stride=2) # [b,512,h,w]->[b,1024,h,w] self.blk4=ResBlk(512,512,stride=2) self.outlay=nn.Linear(512*1*1,10) def forward(self,x): ''' :param x: :return: ''' x=F.relu(self.conv1(x)) # [b,64,h,w]->[b,1024,h,w] x=self.blk1(x) x=self.blk2(x) x=self.blk3(x) x=self.blk4(x) # print('after conv:', x.shape) # x=self.outlay(x) x=F.adaptive_avg_pool2d(x,[1,1]) x=x.view(x.size(0),-1) x=self.outlay(x) return x def main(): blk=ResBlk(64,128,stride=4) tmp=torch.randn(2,64,32,32) out=blk(tmp) print('block:',out.shape) x=torch.randn(2,3,32,32) model=ResNet18() out=model(x) print('resnet:',out.shape) if __name__ == '__main__': main()
参考up主:https://www.bilibili.com/video/BV1J3411C7zd?vd_source=a0d4f7000e77468aec70dc618794d26f
实线与虚线的区别就是相加的维度是否相同。
对于右面,[56,56,64]与[28,28,128]维度不同,高和宽通过stride=2改变,深度64到128通过1×1的卷积核进行升维。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。