```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNetConvBlock(nn.Module):
    def __init__(self, in_chans, out_chans, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        block.append(nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())

        if batch_norm:
            block.append(nn.BatchNorm2d(out_chans))

        block.append(nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())

        if batch_norm:
            block.append(nn.BatchNorm2d(out_chans))

        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out
```
This is the convolution block used at each stage of the network: two 3×3 convolutions, each followed by a ReLU and an optional batch normalization.
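As a quick sanity check, here is a minimal sketch of the block in use (the sizes and channel counts are illustrative assumptions, not from the original post):

```python
# Minimal sketch: with padding=False, each 3x3 convolution trims
# 2 pixels of height and width, so a 64x64 input comes out as 60x60.
block = UNetConvBlock(in_chans=1, out_chans=64, padding=False, batch_norm=True)
x = torch.randn(1, 1, 64, 64)   # (batch, channels, H, W)
print(block(x).shape)           # torch.Size([1, 64, 60, 60])
```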
```python
class UNetUpBlock(nn.Module):
    def __init__(self, in_chans, out_chans, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_chans, out_chans, kernel_size=1),
            )
        self.conv_block = UNetConvBlock(in_chans, out_chans, padding, batch_norm)
```
There are two ways to upsample: transposed convolution ('upconv') and bilinear interpolation followed by a 1×1 convolution ('upsample'). The up_mode argument selects which one is used.
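Both options double the spatial resolution while reducing the channel count. A minimal sketch comparing them (the channel and size values here are illustrative assumptions):

```python
# Both up modes map (N, in_chans, H, W) to (N, out_chans, 2H, 2W).
up_a = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)  # learned upsampling
up_b = nn.Sequential(
    nn.Upsample(mode='bilinear', scale_factor=2),            # fixed interpolation
    nn.Conv2d(128, 64, kernel_size=1),                       # 1x1 conv mixes channels
)
x = torch.randn(1, 128, 28, 28)
print(up_a(x).shape, up_b(x).shape)  # both torch.Size([1, 64, 56, 56])
```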
```python
    # (continuing UNetUpBlock)
    def centre_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])]
```
This implements the cropping operation. Note that the feature maps on the two sides of a skip connection differ in size: with unpadded convolutions the encoder map is larger, so, as described in the paper, we centre-crop the encoder map down to the size of the decoder map.
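A small worked example of the arithmetic (the sizes are illustrative assumptions): cropping a 64×64 encoder map to a 56×56 target trims (64 − 56) // 2 = 4 pixels from each border:

```python
bridge = torch.randn(1, 64, 64, 64)      # encoder feature map (N, C, H, W)
diff_y = diff_x = (64 - 56) // 2         # 4 pixels trimmed from each border
cropped = bridge[:, :, diff_y:diff_y + 56, diff_x:diff_x + 56]
print(cropped.shape)                     # torch.Size([1, 64, 56, 56])
```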
```python
    # (continuing UNetUpBlock)
    def forward(self, x, bridge):
        up = self.up(x)                                 # upsample the decoder feature map
        crop1 = self.centre_crop(bridge, up.shape[2:])  # crop the skip connection to match
        out = torch.cat([up, crop1], 1)                 # concatenate along the channel axis
        out = self.conv_block(out)
        return out
```
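Putting the pieces together, a minimal sketch of one up block in action (channel counts and sizes are illustrative assumptions): the decoder map is upsampled 28 → 56, the encoder map is cropped 64 → 56, the two are concatenated, and the unpadded convolutions shrink the result to 52×52:

```python
up_block = UNetUpBlock(128, 64, up_mode='upconv', padding=False, batch_norm=False)
x = torch.randn(1, 128, 28, 28)      # decoder feature map
bridge = torch.randn(1, 64, 64, 64)  # encoder feature map from the skip connection
print(up_block(x, bridge).shape)     # torch.Size([1, 64, 52, 52])
```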
```python
class UNet(nn.Module):
    def __init__(
        self,
        in_channels=1,
        n_classes=2,
        depth=5,
        wf=6,
        padding=False,
        batch_norm=False,
        up_mode='upconv',
    ):
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels

        self.down_path = nn.ModuleList()

        for i in range(depth):  # i = 0, 1, 2, 3, 4
            self.down_path.append(UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm))
            prev_channels = 2 ** (wf + i)  # 2**(wf + i) gives the channel count at this depth

        self.up_path = nn.ModuleList()

        for i in reversed(range(depth - 1)):
            self.up_path.append(UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm))
            prev_channels = 2 ** (wf + i)

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:  # no pooling after the bottleneck
                blocks.append(x)
                x = F.max_pool2d(x, 2)
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])
        return self.last(x)
```
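Finally, a minimal end-to-end sketch with the default configuration (the 572×572 input size follows the U-Net paper; since padding=False, the unpadded convolutions shrink the output to 388×388):

```python
model = UNet(in_channels=1, n_classes=2, depth=5, wf=6,
             padding=False, batch_norm=False, up_mode='upconv')
x = torch.randn(1, 1, 572, 572)   # input size from the U-Net paper
print(model(x).shape)             # torch.Size([1, 2, 388, 388])
# With padding=True the spatial size is preserved instead, e.g.
# a 256x256 input yields a (1, 2, 256, 256) output.
```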