The network structure is shown in the figure:
Code implementation (based on PyTorch):
Import the required packages:
from math import sqrt
import torch
from torch import nn
import torch.nn.functional as F
Define the convolution block:
Two convolution operations are defined, each using a 3x3 kernel with stride 1 and padding 1. The output of each convolution is batch-normalized and then passed through a ReLU activation. When using this block, you must specify the number of input channels (in_channel) and output channels (out_channel).
class Conv_Block(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Conv_Block, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        outputs = self.conv1(input)
        outputs = self.conv2(outputs)
        return outputs
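A quick shape check for the block (not from the original post; the tensor sizes below are arbitrary assumptions used only for illustration). With a 3x3 kernel, stride 1, and padding 1, the spatial size is preserved and only the channel count changes:

block = Conv_Block(in_channel=3, out_channel=64)
x = torch.randn(1, 3, 128, 128)   # one 3-channel 128x128 image (assumed size)
y = block(x)
print(y.shape)                    # torch.Size([1, 64, 128, 128]): spatial size unchanged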
Encoding operation (downsampling): the output of the convolution block is max-pooled.
class UnetDown(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(UnetDown, self).__init__()
        self.conv = Conv_Block(in_channel, out_channel)
        self.down = nn.MaxPool2d(2, 2, ceil_mode=True)

    def forward(self, inputs):
        outputs = self.conv(inputs)
        outputs = self.down(outputs)
        return outputs
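A similar shape check for the downsampling block (the 101x101 input is an assumption, chosen odd on purpose to show what ceil_mode=True does):

down = UnetDown(in_channel=3, out_channel=64)
x = torch.randn(1, 3, 101, 101)   # odd spatial size (assumed)
y = down(x)
print(y.shape)                    # torch.Size([1, 64, 51, 51]): ceil_mode=True rounds 101/2 up to 51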
Decoding operation (upsampling): two upsampling options are provided here, ConvTranspose2d and UpsamplingBilinear2d; the former is a learnable transposed convolution, while the latter is parameter-free bilinear interpolation. In addition, because the upsampled result is concatenated with the corresponding encoder feature map, that feature map is padded before the concatenation so the two spatial sizes match (the shape check after the code below illustrates this alignment).
(Note: in this code the decoder first upsamples, then concatenates, and finally applies the convolution block; and in UnetModel a separate convolution block is applied after the last encoding operation, so the overall network structure is still unchanged.)
class UnetUp(nn.Module):
    def __init__(self, in_channel, out_channel, is_deconv=True):
        super(UnetUp, self).__init__()
        self.conv = Conv_Block(in_channel, out_channel)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        outputs2 = self.up(inputs2)
        offset1 = (outputs2.size()[2] - inputs1.size()[2])
        offset2 = (outputs2.size()[3] - inputs1.size()[3])
        # When F.pad is given four values they mean [left, right, top, bottom]:
        # the first two pad the last dimension (width), the last two pad the
        # second-to-last dimension (height).
        padding = [offset2 // 2, (offset2 + 1) // 2, offset1 // 2, (offset1 + 1) // 2]
        # Pad the skip connection to the upsampled size, then concatenate
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))
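The padding-and-concatenation step is the easiest place to get shapes wrong, so here is a small sketch that exercises it. The feature-map sizes and channel counts are assumptions chosen only to make the size mismatch visible, not values from the post:

# F.pad with a 4-element list pads [left, right, top, bottom]
t = torch.zeros(1, 1, 25, 25)
print(F.pad(t, [0, 1, 0, 1]).shape)   # torch.Size([1, 1, 26, 26])

skip = torch.randn(1, 512, 25, 25)    # plays the role of inputs1 (encoder feature map)
deep = torch.randn(1, 1024, 13, 13)   # plays the role of inputs2 (deeper feature map)
up = UnetUp(in_channel=1024, out_channel=512, is_deconv=True)
out = up(skip, deep)                  # 13 -> 26 by transposed conv, skip padded 25 -> 26
print(out.shape)                      # torch.Size([1, 512, 26, 26])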
Finally, define the full UNet model: reading the code side by side with the network-structure figure makes it easy to follow.
class UnetModel(nn.Module):
    def __init__(self, n_classes, in_channels, is_deconv):
        super(UnetModel, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.n_classes = n_classes
        filters = [64, 128, 256, 512, 1024]
        self.down1 = UnetDown(self.in_channels, filters[0])
        self.down2 = UnetDown(filters[0], filters[1])
        self.down3 = UnetDown(filters[1], filters[2])
        self.down4 = UnetDown(filters[2], filters[3])
        self.center = Conv_Block(filters[3], filters[4])
        self.up4 = UnetUp(filters[4], filters[3], self.is_deconv)
        self.up3 = UnetUp(filters[3], filters[2], self.is_deconv)
        self.up2 = UnetUp(filters[2], filters[1], self.is_deconv)
        self.up1 = UnetUp(filters[1], filters[0], self.is_deconv)
        self.final = nn.Conv2d(filters[0], self.n_classes, 1)

    def forward(self, inputs, label_dsp_dim):
        down1 = self.down1(inputs)
        down2 = self.down2(down1)
        down3 = self.down3(down2)
        down4 = self.down4(down3)
        center = self.center(down4)
        up4 = self.up4(down4, center)
        up3 = self.up3(down3, up4)
        up2 = self.up2(down2, up3)
        up1 = self.up1(down1, up2)
        # Crop the decoder output to the requested label size
        up1 = up1[:, :, 1:1 + label_dsp_dim[0], 1:1 + label_dsp_dim[1]].contiguous()
        return self.final(up1)

    # Initialization of parameters
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.ConvTranspose2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
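To tie everything together, a minimal end-to-end sketch of a forward pass (the input size, channel counts, and label_dsp_dim below are assumptions for illustration; the post does not fix them here):

model = UnetModel(n_classes=1, in_channels=1, is_deconv=True)
x = torch.randn(1, 1, 100, 100)             # assumed single-channel 100x100 input
out = model(x, label_dsp_dim=[100, 100])    # decoder output is cropped to the label size
print(out.shape)                            # torch.Size([1, 1, 100, 100])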
Summary: UNet is a classic image-segmentation network. Through its encoder-decoder structure, skip connections, and multi-scale feature fusion, it achieves excellent performance on segmentation tasks. Many networks have been derived from UNet, such as U-Net++, ResUNet, and Dense U-Net, so the next step is to study these variants and see how others modify the architecture. Also, when I first started writing deep-learning code I had no idea where to begin; after studying how others implement their code, I found that two habits make it much easier: 1) write the code with the network-structure figure in front of you, and 2) look up (e.g., on Baidu) the functions for the relevant operations as you go.