U-Net is commonly used for medical image segmentation.
[Figure: an example data image and its segmentation label]
The original network downsamples with max pooling. Max pooling and padding are generally thought to destroy spatial position information, so the pooling layers can be replaced with strided convolution layers; this helps preserve position information and also increases network capacity. For upsampling, transposed convolution is best avoided, because it produces checkerboard artifacts that degrade the image's semantic information; nearest-neighbor interpolation can be used in its place. Finally, as the network gets deeper, gradients tend to vanish, so skip connections are added: they effectively mitigate the problem and let information flow more freely through the network. This idea also led to the later proposal of U-Net++, an improved version of U-Net.
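To make these substitutions concrete, here is a minimal sketch (the tensor shapes and channel counts are illustrative only, not taken from the network below) contrasting transposed convolution with nearest-neighbor upsampling, and max pooling with a stride-2 convolution:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 64, 32, 32)  # (N, C, H, W) feature map

# Option 1: transposed convolution -- learnable, but its overlapping
# kernel footprints are what produce checkerboard artifacts.
deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
y1 = deconv(x)                                          # -> (1, 32, 64, 64)

# Option 2: nearest-neighbor interpolation, optionally followed by a
# 3x3 conv -- each output pixel is filled uniformly, so no checkerboard.
y2 = F.interpolate(x, scale_factor=2, mode='nearest')   # -> (1, 64, 64, 64)
refine = nn.Conv2d(64, 32, kernel_size=3, padding=1)
y2 = refine(y2)                                         # -> (1, 32, 64, 64)

# Downsampling: a stride-2 conv can replace 2x2 max pooling, giving a
# learnable, position-aware reduction.
down = nn.Conv2d(64, 64, kernel_size=2, stride=2)
z = down(x)                                             # -> (1, 64, 16, 16)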
For dense-to-dense prediction tasks like this, the network can generally be designed as a dumbbell-shaped (encoder-decoder) structure. After downsampling, what remains of the image is its contour information, which is abstract and high-level and can be seen as a summary; the skip connections can then be seen as re-injecting the details. This is where U-Net improves on the traditional fully convolutional network (FCN). In U-Net++, the authors push the idea further and use this residual-style skip connection extensively. A full PyTorch implementation follows.
import torch
import torch.nn as nn
from torch.nn import functional


# A simple wrapper around the (Conv -> BN -> ReLU) x 2 block used at every level
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=1),
            nn.BatchNorm2d(out_ch),  # BN layer added (not in the original U-Net)
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


# Stride-2 convolution instead of max pooling, as discussed above
class DownSample(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DownSample, self).__init__()
        self.downsample = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 2, padding=0, stride=2)
        )

    def forward(self, x):
        return self.downsample(x)


# Nearest-neighbor interpolation instead of transposed convolution,
# avoiding checkerboard artifacts
class UpSample(nn.Module):
    def __init__(self):
        super(UpSample, self).__init__()

    def forward(self, x):
        return functional.interpolate(x, scale_factor=2, mode='nearest')


class Unet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1):
        super(Unet, self).__init__()
        # Encoder: four downsampling stages
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = DownSample(64, 64)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = DownSample(128, 128)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = DownSample(256, 256)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = DownSample(512, 512)
        # Bottleneck
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: transposed convolution would also work here,
        # e.g. nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.up6 = UpSample()
        self.conv6 = DoubleConv(1024 + 512, 512)  # upsampled + skip channels
        self.up7 = UpSample()
        self.conv7 = DoubleConv(512 + 256, 256)
        self.up8 = UpSample()
        self.conv8 = DoubleConv(256 + 128, 128)
        self.up9 = UpSample()
        self.conv9 = DoubleConv(128 + 64, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)  # 1x1 conv to output channels

    def forward(self, x):
        # Encoder path
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        # Decoder path: upsample, concatenate the skip connection along
        # the channel dimension, then apply the double conv
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        out = torch.sigmoid(c10)  # per-pixel probability for binary segmentation
        return out
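A quick smoke test of the model above, assuming an input whose height and width are divisible by 16 (required by the four stride-2 downsamplings so the skip connections line up):

if __name__ == '__main__':
    net = Unet(in_ch=1, out_ch=1)
    dummy = torch.randn(2, 1, 96, 96)  # batch of two single-channel 96x96 images
    pred = net(dummy)
    print(pred.shape)  # torch.Size([2, 1, 96, 96]); values in (0, 1) from the sigmoid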