[BN] => ReLU) * 2 连续两次的卷积操作:U-net网络中,下采样和上采样过程,每一层都会连续进行两次卷积操作 """ def __init__(self, in_chann_pytorch u-net">
赞
踩
U-Net网络实际上是用于语义分割的,但深度学习在语义分割上的开山之作是 FCN,即全卷积神经网络。此节主要是想通过写一遍代码来捋一遍 U-Net 网络的结构。故此节不对网络结构具体的每部分作深入说明。主要是实现以下网络的每部分结构。
网络的结构图:
总体说明:由于在U-Net网络中,会有多次两次连续卷积的操作,故将其单独写成一个模块。其网络主要由以下这些模块构成:DoubleConv 模块、下采样模块、上采样模块、输出模块。具体实现如下.
import torch.nn as nn import torch.nn.functional as F import torch class DoubleConv(nn.Module): """ 1. DoubleConv 模块 (convolution => [BN] => ReLU) * 2 连续两次的卷积操作:U-net网络中,下采样和上采样过程,每一层都会连续进行两次卷积操作 """ def __init__(self, in_channels, out_channels): super().__init__() # torch.nn.Sequential是一个时序容器,Modules 会以它们传入的顺序被添加到容器中。 # 此处:卷积->BN->ReLU->卷积->BN->ReLU self.double_conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.double_conv(x) class Down(nn.Module): """ 2. Down(下采样)模块 Downscaling with maxpool then double conv maxpool池化层,进行下采样,再接DoubleConv模块 """ def __init__(self, in_channels, out_channels): super().__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool2d(2), # 池化层 DoubleConv(in_channels, out_channels) # DoubleConv模块 ) def forward(self, x): return self.maxpool_conv(x) class Up(nn.Module): """ 3. Up(上采样)模块 Upscaling then double conv """ """ __init__初始化函数定义了上采样方法以及卷积采用DoubleConv 上采样,定义了两种方法:Upsample和ConvTranspose2d,也就是双线性插值和反卷积。 """ def __init__(self, in_channels, out_channels, bilinear=True): super().__init__() # if bilinear, use the normal convolutions to reduce the number of channels if bilinear: self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) else: self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) # 反卷积(2*2 => 4*4) self.conv = DoubleConv(in_channels, out_channels) def forward(self, x1, x2): """ x1接收的是上采样的数据,x2接收的是特征融合的数据 特征融合方法就是,先对小的feature map进行padding,再进行concat(通道叠加) :param x1: :param x2: :return: """ x1 = self.up(x1) # input is CHW diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) x = torch.cat([x2, x1], dim=1) return self.conv(x) class OutConv(nn.Module): """ 4. OutConv模块 UNet网络的输出需要根据分割数量,整合输出通道(若最后的通道为2,即分类为2的情况) """ def __init__(self, in_channels, out_channels): super(OutConv, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) def forward(self, x): return self.conv(x) """ UNet网络用到的模块即以上4个模块 根据UNet网络结构,设置每个模块的输入输出通道个数以及调用顺序 """ class UNet(nn.Module): def __init__(self, n_channels, n_classes, bilinear = False): super(UNet, self).__init__() self.n_channels = n_channels self.n_classes = n_classes self.bilinear = bilinear self.inc = DoubleConv(n_channels, 64) self.down1 = Down(64, 128) self.down2 = Down(128, 256) self.down3 = Down(256, 512) self.down4 = Down(512, 1024) self.up1 = Up(1024, 512, bilinear) self.up2 = Up(512, 256, bilinear) self.up3 = Up(256, 128, bilinear) self.up4 = Up(128, 64, bilinear) self.outc = OutConv(64, n_classes) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5, x4) x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) logits = self.outc(x) return logits if __name__ == '__main__': net = UNet(n_channels=3, n_classes=2) print(net)
结果:
UNet( (inc): DoubleConv( (double_conv): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1)) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)) (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): ReLU(inplace=True) ) ) (down1): Down( (maxpool_conv): Sequential( (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (1): DoubleConv( (double_conv): Sequential( (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1)) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1)) (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): ReLU(inplace=True) ) ) ) ) (down2): Down( (maxpool_conv): Sequential( (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (1): DoubleConv( (double_conv): Sequential( (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1)) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1)) (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): ReLU(inplace=True) ) ) ) ) (down3): Down( (maxpool_conv): Sequential( (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (1): DoubleConv( (double_conv): Sequential( (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1)) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1)) (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): ReLU(inplace=True) ) ) ) ) (down4): Down( (maxpool_conv): Sequential( (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (1): DoubleConv( (double_conv): Sequential( (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1)) (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) (3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1)) (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): ReLU(inplace=True) ) ) ) ) (up1): Up( (up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2)) (conv): DoubleConv( (double_conv): Sequential( (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1)) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1)) (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): ReLU(inplace=True) ) ) ) (up2): Up( (up): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2)) (conv): DoubleConv( (double_conv): Sequential( (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1)) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1)) (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): ReLU(inplace=True) ) ) ) (up3): Up( (up): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2)) (conv): DoubleConv( (double_conv): Sequential( (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1)) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1)) (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): ReLU(inplace=True) ) ) ) (up4): Up( (up): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2)) (conv): DoubleConv( (double_conv): Sequential( (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1)) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)) (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): ReLU(inplace=True) ) ) ) (outc): OutConv( (conv): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1)) ) )
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。