
UNet Code Walkthrough

Step 1: as usual, import the required libraries.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

Create a convolutional block class.

    class UNetConvBlock(nn.Module):
        def __init__(self, in_chans, out_chans, padding, batch_norm):
            super(UNetConvBlock, self).__init__()
            block = []
            # First 3x3 convolution: padding=1 preserves the spatial size, padding=0 trims 2 pixels
            block.append(nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=int(padding)))
            block.append(nn.ReLU())
            if batch_norm:
                block.append(nn.BatchNorm2d(out_chans))
            # Second 3x3 convolution
            block.append(nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=int(padding)))
            block.append(nn.ReLU())
            if batch_norm:
                block.append(nn.BatchNorm2d(out_chans))
            self.block = nn.Sequential(*block)

        def forward(self, x):
            out = self.block(x)
            return out

This is the convolutional block used at every stage of the network: two 3x3 convolutions, each followed by a ReLU and, optionally, batch normalization.
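A quick shape check makes the size arithmetic concrete (a minimal sketch; the 572x572 input follows the paper's unpadded setting):

    x = torch.randn(1, 1, 572, 572)  # one single-channel image
    block = UNetConvBlock(1, 64, padding=False, batch_norm=False)
    print(block(x).shape)  # torch.Size([1, 64, 568, 568]): each unpadded 3x3 conv trims 2 pixels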

Create the upsampling block.

    class UNetUpBlock(nn.Module):
        def __init__(self, in_chans, out_chans, up_mode, padding, batch_norm):
            super(UNetUpBlock, self).__init__()
            if up_mode == 'upconv':
                # Learned upsampling with a transposed convolution
                self.up = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2)
            elif up_mode == 'upsample':
                # Fixed bilinear upsampling, then a 1x1 convolution to reduce the channel count
                self.up = nn.Sequential(
                    nn.Upsample(mode='bilinear', scale_factor=2),
                    nn.Conv2d(in_chans, out_chans, kernel_size=1),
                )
            self.conv_block = UNetConvBlock(in_chans, out_chans, padding, batch_norm)

There are two ways to upsample: transposed convolution and bilinear interpolation. The up_mode argument selects which one is used.
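A minimal sketch comparing the two options; both double the spatial size and halve the channel count here:

    x = torch.randn(1, 128, 28, 28)

    upconv = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
    print(upconv(x).shape)  # torch.Size([1, 64, 56, 56])

    upsample = nn.Sequential(
        nn.Upsample(mode='bilinear', scale_factor=2),
        nn.Conv2d(128, 64, kernel_size=1),
    )
    print(upsample(x).shape)  # torch.Size([1, 64, 56, 56])

The transposed convolution learns its upsampling weights, while bilinear interpolation is fixed and relies on the following 1x1 convolution for its learnable parameters.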

    # Method of UNetUpBlock
    def centre_crop(self, layer, target_size):
        # Crop the (N, C, H, W) feature map `layer` to the centred target_size = (H, W)
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])]

This implements the cropping operation. Notice that the two ends of a skip connection have different spatial sizes: because the convolutions are unpadded, the encoder feature map is larger than the decoder's, so, as the paper describes, we crop the encoder map down to the decoder's size before concatenating.
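As a concrete check (a minimal sketch; the 64-to-56 crop matches the deepest skip connection in the paper's figure):

    block = UNetUpBlock(1024, 512, 'upconv', padding=False, batch_norm=False)
    bridge = torch.randn(1, 512, 64, 64)  # encoder feature map
    print(block.centre_crop(bridge, (56, 56)).shape)  # torch.Size([1, 512, 56, 56])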

    # Method of UNetUpBlock
    def forward(self, x, bridge):
        up = self.up(x)  # upsample the decoder features
        crop1 = self.centre_crop(bridge, up.shape[2:])  # crop the skip connection to match
        out = torch.cat([up, crop1], 1)  # concatenate along the channel dimension
        out = self.conv_block(out)
        return out
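Putting the pieces together, a quick end-to-end check of one decoder step, with illustrative shapes:

    block = UNetUpBlock(128, 64, 'upconv', padding=False, batch_norm=False)
    x = torch.randn(1, 128, 52, 52)  # decoder features from the level below
    bridge = torch.randn(1, 64, 136, 136)  # encoder features via the skip connection
    print(block(x, bridge).shape)  # torch.Size([1, 64, 100, 100])

Upsampling gives 64 channels at 104x104, the cropped bridge contributes another 64 channels, and the conv block maps the concatenated 128 channels back to 64 while trimming the size to 100x100.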

 

Create the UNet.

    class UNet(nn.Module):
        def __init__(
            self,
            in_channels=1,
            n_classes=2,
            depth=5,
            wf=6,
            padding=False,
            batch_norm=False,
            up_mode='upconv',
        ):
            super(UNet, self).__init__()
            assert up_mode in ('upconv', 'upsample')
            self.padding = padding
            self.depth = depth
            prev_channels = in_channels
            self.down_path = nn.ModuleList()
            for i in range(depth):  # i = 0, 1, 2, 3, 4
                self.down_path.append(UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm))
                prev_channels = 2 ** (wf + i)  # wf + i sets the channel count at depth i (64, 128, ... for wf=6)
            self.up_path = nn.ModuleList()
            for i in reversed(range(depth - 1)):
                self.up_path.append(UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm))
                prev_channels = 2 ** (wf + i)
            # 1x1 convolution mapping the final feature map to per-class scores
            self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

        def forward(self, x):
            blocks = []
            for i, down in enumerate(self.down_path):
                x = down(x)
                if i != len(self.down_path) - 1:
                    # Save the feature map for the skip connection, then downsample
                    blocks.append(x)
                    x = F.max_pool2d(x, 2)
            for i, up in enumerate(self.up_path):
                x = up(x, blocks[-i - 1])  # pair each up block with the matching encoder output
            return self.last(x)
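Finally, a sanity check of the full model (a minimal sketch): with the paper's unpadded setting, a 572x572 input yields a 388x388 output; with padding=True, the spatial size is preserved.

    model = UNet(in_channels=1, n_classes=2, depth=5, wf=6, padding=False, up_mode='upconv')
    x = torch.randn(1, 1, 572, 572)
    print(model(x).shape)  # torch.Size([1, 2, 388, 388]), as in the paper

    padded = UNet(padding=True)
    print(padded(torch.randn(1, 1, 256, 256)).shape)  # torch.Size([1, 2, 256, 256])

With padding=True, choose an input side length divisible by 2**(depth-1) so that every max-pool divides evenly and the output matches the input size.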

 
