当前位置:   article > 正文

1--图像分割之UNet_unet图像分割

unet图像分割

参考来源

https://mp.weixin.qq.com/s?__biz=MzIxODg1OTk1MA==&mid=2247484552&idx=1&sn=419a4ec2a98be9bd3b8690ed855d0f33&chksm=97e55449a092dd5fb304a110566ff99fe838683570a6063934a5067c57ad7b0019592ee519af&scene=178&cur_album_id=1350801111883612160#rd(强烈安利其他文章,写的我这个小白都能基本看懂)

1.简介

那么肯定有人问了,图像分割鼻祖不是FCN吗,怎么是UNet,因为机缘巧合下第一个所接触到的就是UNet啦,不过也没事,后期会慢慢补充,接下来回归正题。

字如其名,图像分割就是对图像进行分割,更专业的来说分为语义分割,实例分割还有全景分割。

语义分割是什么?

语义分割(semantic segmentation) : 就是按照“语义”给图像上目标类别中的每一点打一个标签,使得不同种类的东西在图像上被区分开来。可以理解成像素级别的分类任务,直白点,就是对每个像素点进行分类。

简而言之,我们的目标是给定一幅RGB彩色图像(高x宽x3)或一幅灰度图像(高x宽x1),输出一个分割图谱,其中包括每个像素的类别标注(高x宽x1)。具体如下图所示:

图片

注意:为了视觉上清晰,上面的预测图是一个低分辨率的图。在实际应用中,分割标注的分辨率需要与原始图像的分辨率相同。

这里对图片分为五类:Person(人)、Purse(包)、Plants/Grass(植物/草)、Sidewalk(人行道)、Building/Structures(建筑物)。

而实例分割就是在语义分割的基础上需要区分每个不同实例,简单的来说比如一幅图像里有三个人,语义分割就是需要把所有人抠出来,而实例分割需要区分哪个是你,哪个是我,哪个是他,全景分割的话本人不太了解,有兴趣的话可以参考下相关资料。

2.UNet网络结构及pytorch实现

医学图像分割一般会存在以下几个问题:

  • 数据量少
  • 图像尺寸大,分辨率高
  • 要求分割的结果精度高

UNet的提出为以上问题的解决奠定了基石。

UNet网络结构,最主要的两个特点是:U型网络结构和Skip Connection跳层连接。

图片

UNet是一个对称的网络结构,左侧为下采样,右侧为上采样。

按照功能可以将左侧的一系列下采样操作称为encoder,将右侧的一系列上采样操作称为decoder。

Skip Connection中间四条灰色的平行线,Skip Connection就是在上采样的过程中,融合下采样过过程中的feature map。

Skip Connection用到的融合的操作也很简单,就是将feature map的通道进行叠加,俗称Concat。

Concat操作也很好理解,举个例子:一本大小为10cm*10cm,厚度为3cm的书A,和一本大小为10cm*10cm,厚度为4cm的书B。

将书A和书B,边缘对齐地摞在一起。这样就得到了,大小为10cm*10cm厚度为7cm的一摞书,类似这种:

图片

 这种“摞在一起”的操作,就是Concat。

同样道理,对于feature map,一个大小为256*256*64的feature map,即feature map的w(宽)为256,h(高)为256,c(通道数)为64。和一个大小为256*256*32的feature map进行Concat融合,就会得到一个大小为256*256*96的feature map。

在实际使用中,Concat融合的两个feature map的大小不一定相同,例如256*256*64的feature map和240*240*32的feature map进行Concat。

这种时候,就有两种办法:

第一种:将大256*256*64的feature map进行裁剪,裁剪为240*240*64的feature map,比如上下左右,各舍弃8 pixel,裁剪后再进行Concat,得到240*240*96的feature map。

第二种:将小240*240*32的feature map进行padding操作,padding为256*256*32的feature map,比如上下左右,各补8 pixel,padding后再进行Concat,得到256*256*96的feature map。

UNet采用的Concat方案就是第二种,将小的feature map进行padding,padding的方式是补0,一种常规的常量填充。

代码部分采用模块化设计:

DoubleConv模块:

先看下连续两次的卷积操作。

图片

从UNet网络中可以看出,不管是下采样过程还是上采样过程,每一层都会连续进行两次卷积操作,这种操作在UNet网络中重复很多次,可以单独写一个DoubleConv模块:

  1. import torch.nn as nn
  2. class DoubleConv(nn.Module):
  3. """(convolution => [BN] => ReLU) * 2"""
  4. def __init__(self, in_channels, out_channels):
  5. super().__init__()
  6. self.double_conv = nn.Sequential(
  7. nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
  8. nn.BatchNorm2d(out_channels),
  9. nn.ReLU(inplace=True),
  10. nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
  11. nn.BatchNorm2d(out_channels),
  12. nn.ReLU(inplace=True)
  13. )
  14. def forward(self, x):
  15. return self.double_conv(x)

解释下,上述的Pytorch代码:torch.nn.Sequential是一个时序容器,Modules 会以它们传入的顺序被添加到容器中。比如上述代码的操作顺序:卷积->BN->ReLU->卷积->BN->ReLU。

DoubleConv模块的in_channels和out_channels可以灵活设定,以便扩展使用。

如上图所示的网络,in_channels设为1,out_channels为64。

输入图片大小为572*572,经过步长为1,padding为0的3*3卷积,得到570*570的feature map,再经过一次卷积得到568*568的feature map。

计算公式:O=(H−F+2×P)/S+1

H为输入feature map的大小,O为输出feature map的大小,F为卷积核的大小,P为padding的大小,S为步长。

Down模块:

图片

UNet网络一共有4次下采样过程,模块化代码如下:

  1. class Down(nn.Module):
  2. """Downscaling with maxpool then double conv"""
  3. def __init__(self, in_channels, out_channels):
  4. super().__init__()
  5. self.maxpool_conv = nn.Sequential(
  6. nn.MaxPool2d(2),
  7. DoubleConv(in_channels, out_channels)
  8. )
  9. def forward(self, x):
  10. return self.maxpool_conv(x)

其实也就是一个池化,池化后面跟了一个DoubleConv,这里大家可以看着图连贯着想,就会明白为什么在这里加DoubleConv。

至此,UNet网络的左半部分的下采样过程的代码都写好了,接下来是右半部分的上采样过程

Up模块:

上采样过程用到的最多的当然就是上采样了,除了常规的上采样操作,还有进行特征的融合。

图片

这块的代码实现起来也稍复杂一些:

  1. class Up(nn.Module):
  2. """Upscaling then double conv"""
  3. def __init__(self, in_channels, out_channels, bilinear=True):
  4. super().__init__()
  5. # if bilinear, use the normal convolutions to reduce the number of channels
  6. if bilinear:
  7. self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
  8. else:
  9. self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
  10. self.conv = DoubleConv(in_channels, out_channels)
  11. def forward(self, x1, x2):
  12. x1 = self.up(x1)
  13. # input is CHW
  14. diffY = x2.size()[2] - x1.size()[2]
  15. diffX = x2.size()[3] - x1.size()[3]
  16. x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
  17. diffY // 2, diffY - diffY // 2])
  18. # if you have padding issues, see
  19. # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
  20. # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
  21. x = torch.cat([x2, x1], dim=1)
  22. return self.conv(x)

可以分开来看,首先是__init__初始化函数里定义的上采样方法以及卷积采用DoubleConv。上采样,定义了两种方法:Upsample和ConvTranspose2d,也就是双线性插值反卷积,这里可以关注下bilinear这个参数用来控制两种模式。

双线性插值很好理解,示意图:

图片

熟悉双线性插值的朋友对于这幅图应该不陌生,简单地讲:已知Q11、Q12、Q21、Q22四个点坐标,通过Q11和Q21求R1,再通过Q12和Q22求R2,最后通过R1和R2求P,这个过程就是双线性插值。

对于一个feature map而言,其实就是在像素点中间补点,补的点的值是多少,是由相邻像素点的值决定的。

反卷积,顾名思义,就是反着卷积。卷积是让featuer map越来越小,反卷积就是让feature map越来越大,示意图:

图片

下面蓝色为原始图片,周围白色的虚线方块为padding结果,通常为0,上面绿色为卷积后的图片。

这个示意图,就是一个从2*2的feature map->4*4的feature map过程。

在forward前向传播函数中,x1接收的是上采样的数据,x2接收的是特征融合的数据。特征融合方法就是,上文提到的,先对小的feature map进行padding,再进行concat。

OutConv模块:用上述的DoubleConv模块、Down模块、Up模块就可以拼出UNet的主体网络结构了。UNet网络的输出需要根据分割数量,整合输出通道,结果如下图所示:

图片

操作很简单,就是channel的变换,上图展示的是分类为2的情况(通道为2)。

  1. class OutConv(nn.Module):
  2. def __init__(self, in_channels, out_channels):
  3. super(OutConv, self).__init__()
  4. self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
  5. def forward(self, x):
  6. return self.conv(x)

至此,UNet网络用到的模块都已经写好,我们可以将上述的模块代码都放到一个unet_parts.py文件里,然后再创建unet_model.py,根据UNet网络结构,设置每个模块的输入输出通道个数以及调用顺序,编写如下代码:

  1. """ Full assembly of the parts to form the complete network """
  2. """Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""
  3. import torch.nn.functional as F
  4. from unet_parts import *
  5. class UNet(nn.Module):
  6. def __init__(self, n_channels, n_classes, bilinear=False):
  7. super(UNet, self).__init__()
  8. self.n_channels = n_channels
  9. self.n_classes = n_classes
  10. self.bilinear = bilinear
  11. self.inc = DoubleConv(n_channels, 64)
  12. self.down1 = Down(64, 128)
  13. self.down2 = Down(128, 256)
  14. self.down3 = Down(256, 512)
  15. self.down4 = Down(512, 1024)
  16. self.up1 = Up(1024, 512, bilinear)
  17. self.up2 = Up(512, 256, bilinear)
  18. self.up3 = Up(256, 128, bilinear)
  19. self.up4 = Up(128, 64, bilinear)
  20. self.outc = OutConv(64, n_classes)
  21. def forward(self, x):
  22. x1 = self.inc(x)
  23. x2 = self.down1(x1)
  24. x3 = self.down2(x2)
  25. x4 = self.down3(x3)
  26. x5 = self.down4(x4)
  27. x = self.up1(x5, x4)
  28. x = self.up2(x, x3)
  29. x = self.up3(x, x2)
  30. x = self.up4(x, x1)
  31. logits = self.outc(x)
  32. return logits
  33. if __name__ == '__main__':
  34. net = UNet(n_channels=3, n_classes=1)
  35. print(net)

根据需要可以打印一下看看。

剩下的内容见续章。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/87651
推荐阅读
相关标签
  

闽ICP备14008679号