
Semantic Segmentation Series 2: Unet (PyTorch Implementation)

Unet was published at MICCAI, and its paper title states its purpose plainly: biomedical image segmentation.

"U-Net: Convolutional Networks for Biomedical Image Segmentation"

Unet is quite similar to the FCN covered earlier in this series; put differently, FCN influenced Unet, and in turn the structural design of many later semantic segmentation networks.

Unet

Network Design

The Unet architecture is as elegant as its name suggests: a U-shaped network. The image is downsampled four times and then upsampled four times back to the original size, and each pair of corresponding downsampling and upsampling stages is linked by a skip connection. Compared with FCN, this stage-by-stage connected U-shaped architecture is more elegant: because every upsampling step fuses the features of the corresponding downsampling stage, Unet recovers pixel-level detail better.

After each fusion, the features pass through a series of convolution layers that process the detail in the feature map, letting the model learn this information to assemble a more precise output. The core of the fusion is sketched below.
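Concretely, each skip connection is just a channel-wise concatenation of the upsampled decoder feature with the encoder feature at the same resolution, followed by convolutions. A minimal sketch (the tensor sizes here are illustrative, matching the fourth stage of the implementation later in this post):

```python
import torch

up = torch.randn(1, 512, 28, 28)      # upsampled decoder feature (illustrative size)
skip = torch.randn(1, 512, 28, 28)    # encoder feature at the same resolution
fused = torch.cat([up, skip], dim=1)  # concatenate along channels: 512 + 512
print(fused.shape)  # torch.Size([1, 1024, 28, 28]); the following convs reduce it back
```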

Figure 1: Unet network architecture

Tricks

The authors also built a few tricks into Unet's design to help the model train.

Figure 2: The overlap-tile strategy

The authors call this the overlap-tile strategy, which allows seamless segmentation of arbitrarily large images by means of overlapping tiles (see Figure 2). To predict pixels in the border region of the image, the missing context is extrapolated by mirroring the input image. This tiling strategy is important for applying the network to large images. For example, to predict the region inside the yellow box, the data inside the blue box is fed in as input; if part of the blue box falls outside the image, that part is filled in by mirroring, which supplies the context for the yellow region.

As for why this is done, I think there are two main reasons:

        First, as the authors mention in the paper, the input images have high resolution and take up a lot of GPU memory, and this sliding-window style of prediction eases the burden on the GPU to some extent. (These are medical images, after all, which tend to demand high resolution; forcibly resizing them down to a low resolution easily loses information.)

        Second, the original Unet design uses no padding at all. The deeper the feature map, the more convolution layers it has passed through, so the more padding operations would have accumulated, meaning deep features are the most affected by padding, which degrades the image borders. Without padding, however, the feature maps shrink with every convolution, so the upsampled result no longer matches the input size. To solve this, the authors "bluntly" extend the input image by mirroring, so that the upsampled output ends up the same size as the original. A rough sketch of the mirroring idea follows.
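As a rough illustration of the mirroring idea (not the authors' exact tiling code), reflect-padding with NumPy synthesizes the missing context around the border:

```python
import numpy as np

# A tiny 4x4 "image"; pad 2 pixels on each side by mirroring the border
img = np.arange(16).reshape(4, 4)
padded = np.pad(img, pad_width=2, mode='reflect')
print(padded.shape)  # (8, 8): border context is synthesized by reflection
```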

Results

Figure 3: Unet results on the ISBI cell tracking challenge

Reproducing the Unet Model

Let's reproduce the Unet model with PyTorch.

Imports and Model Definition

```python
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision import models
from tqdm import tqdm
import warnings
import os.path as osp
```
```python
import torch
import torch.nn as nn

class Unet(nn.Module):
    def __init__(self, num_classes):
        super(Unet, self).__init__()
        # Encoder: stage_1 keeps full resolution; stages 2-5 each begin with a 2x max-pool
        self.stage_1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.stage_2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.stage_3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.stage_4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )
        self.stage_5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
        )
        # Decoder: transposed convolutions double the spatial resolution
        self.upsample_4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1)
        )
        self.upsample_3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1)
        )
        self.upsample_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
        )
        self.upsample_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        )
        # Convolutions applied after each skip-connection concatenation
        self.stage_up_4 = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.stage_up_3 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.stage_up_2 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.stage_up_1 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.final = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=3, padding=1),
        )

    def forward(self, x):
        x = x.float()
        # Encoder (downsampling) path
        stage_1 = self.stage_1(x)
        stage_2 = self.stage_2(stage_1)
        stage_3 = self.stage_3(stage_2)
        stage_4 = self.stage_4(stage_3)
        stage_5 = self.stage_5(stage_4)
        # Decoder: 1024 -> 512
        up_4 = self.upsample_4(stage_5)
        # 512 + 512 -> 512
        up_4_conv = self.stage_up_4(torch.cat([up_4, stage_4], dim=1))
        # 512 -> 256
        up_3 = self.upsample_3(up_4_conv)
        # 256 + 256 -> 256
        up_3_conv = self.stage_up_3(torch.cat([up_3, stage_3], dim=1))
        up_2 = self.upsample_2(up_3_conv)
        up_2_conv = self.stage_up_2(torch.cat([up_2, stage_2], dim=1))
        up_1 = self.upsample_1(up_2_conv)
        up_1_conv = self.stage_up_1(torch.cat([up_1, stage_1], dim=1))
        output = self.final(up_1_conv)
        return output
```

A quick sanity check:

```python
device = torch.device("cuda:0")
model = Unet(num_classes=2)
model = model.to(device)
a = torch.ones([2, 3, 224, 224])
a = a.to(device)
model(a).shape
```
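With the toy input above, this should print torch.Size([2, 2, 224, 224]): the four max-poolings take 224 down to 14, the four transposed convolutions bring it back up to 224, and the final convolution maps 64 channels to num_classes.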

For convenience, the model built in this post does not follow the mirror padding and overlap-tile strategy of the Unet paper; instead, padding keeps the feature-map sizes matched between downsampling and upsampling. The output size should therefore equal the input size.

Building the CamVid Dataset

The CamVid dataset is used here.

```python
# Imports
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import os.path as osp
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

torch.manual_seed(17)

# Custom dataset: CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
    """
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ])
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

    def __getitem__(self, i):
        # Read data; masks are stored as RGB, so keep a single channel as the label map
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array(Image.open(self.masks_fps[i]).convert('RGB'))
        image = self.transform(image=image, mask=mask)
        return image['image'], image['mask'][:, :, 0]

    def __len__(self):
        return len(self.ids)

# Dataset paths -- set DATA_DIR according to your own layout
DATA_DIR = r'dataset\camvid'
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')

train_dataset = CamVidDataset(
    x_train_dir,
    y_train_dir,
)
val_dataset = CamVidDataset(
    x_valid_dir,
    y_valid_dir,
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True, drop_last=True)
```
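Assuming the CamVid files are laid out as above, a quick way to verify the pipeline is to pull one batch and check the shapes:

```python
# Fetch a single batch to confirm the preprocessing pipeline works
images, masks = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([16, 3, 224, 224])
print(masks.shape)   # expected: torch.Size([16, 224, 224])
```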

Model Training

```python
model = Unet(num_classes=33).cuda()
# model.load_state_dict(torch.load(r"checkpoints/Unet_50.pth"), strict=False)
```
```python
from d2l import torch as d2l
from tqdm import tqdm
import pandas as pd

# Multi-class cross-entropy loss
lossf = nn.CrossEntropyLoss(ignore_index=255)
# Train with an SGD optimizer and a step learning-rate schedule
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1, last_epoch=-1)
# Train for 100 epochs
epochs_num = 100

def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, scheduler,
               devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])

    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []
    time_list = []

    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(
                net, features, labels.long(), loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        scheduler.step()

        print(f"epoch {epoch} --- loss {metric[0] / metric[2]:.3f} --- "
              f"train acc {metric[1] / metric[3]:.3f} --- "
              f"test acc {test_acc:.3f} --- cost time {timer.sum()}")

        # --------- Save training statistics ---------------
        df = pd.DataFrame()
        loss_list.append(metric[0] / metric[2])
        train_acc_list.append(metric[1] / metric[3])
        test_acc_list.append(test_acc)
        epochs_list.append(epoch)
        time_list.append(timer.sum())
        df['epoch'] = epochs_list
        df['loss'] = loss_list
        df['train_acc'] = train_acc_list
        df['test_acc'] = test_acc_list
        df['time'] = time_list
        df.to_excel("savefile/Unet_camvid.xlsx")

        # ---------------- Save the model -------------------
        if np.mod(epoch + 1, 5) == 0:
            torch.save(model.state_dict(), f'checkpoints/Unet_{epoch+1}.pth')

train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num, scheduler)
```
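Once checkpoints are written, inference is straightforward; a minimal sketch (assuming checkpoints/Unet_100.pth was produced by the loop above):

```python
# Minimal inference sketch: load a checkpoint and predict a mask for one sample
model = Unet(num_classes=33).cuda()
model.load_state_dict(torch.load('checkpoints/Unet_100.pth'))
model.eval()

with torch.no_grad():
    image, _ = val_dataset[0]                     # one preprocessed sample
    logits = model(image.unsqueeze(0).cuda())     # add batch dim: (1, 33, 224, 224)
    pred = logits.argmax(dim=1).squeeze(0).cpu()  # per-pixel class indices, (224, 224)
```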

 

Summary

Most medical image segmentation tasks start with Unet as a baseline. The Unet structure is also known as the encoder-decoder structure, which appears in all kinds of semantic segmentation models.

Unet has also spawned a whole family of descendants, including Unet++, attention-Unet, Trans Unet, Swin Unet, and more. These models will be covered later in this series.
