赞
踩
随着深度学习领域中各类算法的迅速发展,卷积神经网络(CNN)被广泛应用在了分类任务上,输出的结果是整个图像的类标签。在生物医学领域,医生需要对病人的病灶区域进行病理分析,这时需要一种更先进的网络模型,即能通过少量的图片训练集,就能实现对像素点类别的预测,并且可以对像素点进行着色绘图,形成更复杂、严谨的判断。于是U-Net网络被设计了出来。
U-Net网络结构最早由Ronneberger等人于2015年提出。该图像的核心思想是引入了跳跃连接,使得图像分割的精度大大提升。
U-Net网络的主要结构包括了解码器、编码器、瓶颈层三个部分。
编码器:包括了四个程序块。每个程序块都包括 3 × 3 3\times3 3×3 的卷积(使用Relu激活函数),步长为 2 2 2 的 2 × 2 2\times2 2×2 的池化层(下采样)。每个程序块处理后,特征图逐步减小。
解码器: 与编码器部分对称,也包括四个程序块,每个程序块包括步长为 2 2 2 的 2 × 2 2\times2 2×2 的上采样操作,然后与编码部分进行特征映射级联(Concatenate),即拼接,最后通过两个 3 × 3 3\times3 3×3 的卷积(Relu)。
瓶颈层:包含两个 3 × 3 3\times3 3×3 的卷积层。
最后经过一个
1
×
1
1\times1
1×1的卷积层得到最后的输出。
如图所示,该网络模型形似字母“U”,故称为U-Net。
整体过程:
先对图片进行卷积和池化。比如说一开始输入的图片大小是
224
×
224
224\times224
224×224,进过四次池化后,分别得到
112
×
112
112\times112
112×112 ,
56
×
56
56\times56
56×56 ,
28
×
28
28 \times 28
28×28,
14
×
14
14 \times 14
14×14 四个不同尺寸的特征图。然后对
14
×
14
14\times 14
14×14 的特征图做上采样,得到
28
×
28
28\times28
28×28 的特征图。将这个
28
×
28
28\times28
28×28的特征图与之前池化得到的
28
×
28
28\times28
28×28 特征图进行通道上的拼接(concat),然后再对拼接之后的特征图做卷积和上采样,得到
56
×
56
56\times 56
56×56 的特征图,然后再与之前的
56
×
56
56\times56
56×56 拼接,卷积然后再上采样,经过四次就就可以得到一个与原输入图像大小相同的图片了。
在本图片上的U-Net中,它输入大小为 572 × 572 572\times 572 572×572, 而输出大小为 388 × 388 388 \times 388 388×388, 那是因为它在卷积过程中没有加padding层所造成的。
import torch import torch.nn as nn import torch.nn.functional as F # Double Convolution class DoubleConv2d(nn.Module): def __init__(self, inputChannel, outputChannel): super(DoubleConv2d, self).__init__() self.conv = nn.Sequential( nn.Conv2d(inputChannel, outputChannel, kernel_size=3, padding=1), nn.BatchNorm2d(outputChannel), nn.ReLU(True), nn.Conv2d(outputChannel, outputChannel, kernel_size=3, padding=1), nn.BatchNorm2d(outputChannel), nn.ReLU(True) ) def forward(self, x): out = self.conv(x) return out # Down Sampling class DownSampling(nn.Module): def __init__(self): super(DownSampling, self).__init__() self.down = nn.MaxPool2d(kernel_size=2) def forward(self, x): out = self.down(x) return out # Up Sampling class UpSampling(nn.Module): # Use the deconvolution def __init__(self, inputChannel, outputChannel): super(UpSampling, self).__init__() self.up = nn.Sequential( nn.ConvTranspose2d(inputChannel, outputChannel, kernel_size=2, stride=2), nn.BatchNorm2d(outputChannel) ) def forward(self, x, y): x =self.up(x) diffY = y.size()[2] - x.size()[2] diffX = y.size()[3] - x.size()[3] x = F.pad(x, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) out = torch.cat([y, x], dim=1) return out class Unet(nn.Module): def __init__(self): super(Unet, self).__init__() self.layer1 = DoubleConv2d(1, 64) self.layer2 = DoubleConv2d(64, 128) self.layer3 = DoubleConv2d(128, 256) self.layer4 = DoubleConv2d(256, 512) self.layer5 = DoubleConv2d(512, 1024) self.layer6 = DoubleConv2d(1024, 512) self.layer7 = DoubleConv2d(512, 256) self.layer8 = DoubleConv2d(256, 128) self.layer9 = DoubleConv2d(128, 64) self.layer10 = nn.Conv2d(64, 2, kernel_size=3, padding=1) # The last output layer self.down = DownSampling() self.up1 = UpSampling(1024, 512) self.up2 = UpSampling(512, 256) self.up3 = UpSampling(256, 128) self.up4 = UpSampling(128, 64) def forward(self, x): conv1 = self.layer1(x) down1 = self.down(conv1) conv2 = self.layer2(down1) down2 = self.down(conv2) conv3 = self.layer3(down2) down3 = self.down(conv3) conv4 = self.layer4(down3) down4 = self.down(conv4) conv5 = self.layer5(down4) up1 = self.up1(conv5, conv4) conv6 = self.layer6(up1) up2 = self.up2(conv6, conv3) conv7 = self.layer7(up2) up3 = self.up3(conv7, conv2) conv8 = self.layer8(up3) up4 = self.up4(conv8, conv1) conv9 = self.layer9(up4) out = self.layer10(conv9) return out # Test part mynet = Unet() # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # mynet.to(device) input = torch.rand(3, 1, 572, 572) # output = mynet(input.to(device)) output = mynet(input) print(output.shape) # (3,2,572,572)
https://www.jianshu.com/p/a73f74992b1a
https://arxiv.org/pdf/1505.04597v1.pdf
https://blog.csdn.net/qq_34107425/article/details/110184747
https://blog.csdn.net/weixin_41857483/article/details/120768804
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。