赞
踩
UNet架构该架构看起来像一个"U"。该体系结构由三部分组成:contraction,bottleneck和expansion 部分。contraction部分由许多contraction块组成。每个块接受一个输入,应用两个3X3的卷积层,然后是一个2X2的最大池化。在每个块之后,核或特征映射的数量会加倍,这样体系结构就可以有效地学习复杂的结构。最底层介于contraction层和expansion 层之间。它使用两个3X3 CNN层,然后是2X2 up convolution层。这种架构的核心在于expansion 部分。与contraction层类似,它也包含几个expansion 块。每个块将输入传递到两个3X3 CNN层,然后是2X2上采样层。此外,卷积层使用的每个块的feature map数量得到一半,以保持对称性。每次输入也被相应的收缩层的 feature maps所附加。这个动作将确保在contracting 图像时学习到的特征将被用于重建图像。expansion 块的数量与contraction块的数量相同。之后,生成的映射通过另一个3X3 CNN层,feature map的数量等于所需的segment的数量。
UNet中的损失计算UNet对每个像素使用了一种新颖的损失加权方案,使得分割对象的边缘具有更高的权重。这种损失加权方案帮助U-Net模型以不连续的方式分割生物医学图像中的细胞,以便在binary segmentation map中容易识别单个细胞。首先,在所得图像上应用pixel-wise softmax,然后是交叉熵损失函数。所以我们将每个像素分类为一个类。我们的想法是,即使在分割中,每个像素都必须存在于某个类别中,我们只需要确保它们可以。因此,我们只是将分段问题转换为多类分类问题,与传统的损失函数相比,它表现得非常好。
UNet实现的Python代码Python代码如下:import torchfrom torch import nnimport torch.nn.functional as Fimport torch.optim as optimclass UNet(nn.Module): def contracting_block(self, in_channels, out_channels, kernel_size=3): block = torch.nn.Sequential( torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels), torch.nn.ReLU(), torch.nn.BatchNorm2d(out_channels), torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels), torch.nn.ReLU(), torch.nn.BatchNorm2d(out_channels), ) return block def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3): block = torch.nn.Sequential( torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel), torch.nn.ReLU(), torch.nn.BatchNorm2d(mid_channel), torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel), torch.nn.ReLU(), torch.nn.BatchNorm2d(mid_channel), torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1) ) return block def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3): block = torch.nn.Sequential( torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel), torch.nn.ReLU(), torch.nn.BatchNorm2d(mid_channel), torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel), torch.nn.ReLU(), torch.nn.BatchNorm2d(mid_channel), torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1), torch.nn.ReLU(), torch.nn.BatchNorm2d(out_channels), ) return block def __init__(self, in_channel, out_channel): super(UNet, self).__init__() #Encode self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64) self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2) self.conv_encode2 = self.contracting_block(64, 128) self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2) self.conv_encode3 = self.contracting_block(128, 256) self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2) # Bottleneck self.bottleneck = torch.nn.Sequential( torch.nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512), torch.nn.ReLU(), torch.nn.BatchNorm2d(512), torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512), torch.nn.ReLU(), torch.nn.BatchNorm2d(512), torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1) ) # Decode self.conv_decode3 = self.expansive_block(512, 256, 128) self.conv_decode2 = self.expansive_block(256, 128, 64) self.final_layer = self.final_block(128, 64, out_channel) def crop_and_concat(self, upsampled, bypass, crop=False): if crop: c = (bypass.size()[2] - upsampled.size()[2]) // 2 bypass = F.pad(bypass, (-c, -c, -c, -c)) return torch.cat((upsampled, bypass), 1) def forward(self, x): # Encode encode_block1 = self.conv_encode1(x) encode_pool1 = self.conv_maxpool1(encode_block1) encode_block2 = self.conv_encode2(encode_pool1) encode_pool2 = self.conv_maxpool2(encode_block2) encode_block3 = self.conv_encode3(encode_pool2) encode_pool3 = self.conv_maxpool3(encode_block3) # Bottleneck bottleneck1 = self.bottleneck(encode_pool3) # Decode decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True) cat_layer2 = self.conv_decode3(decode_block3) decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True) cat_layer1 = self.conv_decode2(decode_block2) decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True) final_layer = self.final_layer(decode_block1) return final_layer
以上Python代码中的UNet模块代表了UNet的整体架构。使用contracaction_block和expansive_block分别创建contraction部分和expansion部分。crop_and_concat函数的作用是将contraction层的输出添加到新的expansion层输入中。训练部分的Python代码可以写成unet = Unet(in_channel=1,out_channel=2)#out_channel represents number of segments desiredcriterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD(unet.parameters(), lr = 0.01, momentum=0.99)optimizer.zero_grad() outputs = unet(inputs)# permute such that number of desired segments would be on 4th dimensionoutputs = outputs.permute(0, 2, 3, 1)m = outputs.shape[0]# Resizing the outputs and label to caculate pixel wise softmax lossoutputs = outputs.resize(m*width_out*height_out, 2)labels = labels.resize(m*width_out*height_out)loss = criterion(outputs, labels)loss.backward()optimizer.step()
结论图像分割是一个重要的问题,每天都有一些新的研究论文发表。UNet在这类研究中做出了重大贡献。许多新架构的灵感都来自UNet。在业界,这种体系结构有很多变体,因此有必要理解第一个变体,以便更好地理解它们。本文仅代表作者个人观点,不代表巅云官方发声,对观点有疑义请先联系作者本人进行修改,若内容非法请联系平台管理员,邮箱2522407257@qq.com。更多相关资讯,请到巅云www.yinxi.net学习互联网营销技术请到巅云建站www.yx10011.com。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。