当前位置:   article > 正文

【深度学习】医学图像语义分割最佳方法的全面比较:UNet和UNet++

utils.process_model_params()

作者:Sergey Kolchenko

编译:ronghuaiyang

来源:AI公园

导读

在不同的任务上对比了UNet和UNet++以及使用不同的预训练编码器的效果。

82255bc7e79cfc0b1c678190c3f29053.jpeg

介绍

语义分割是计算机视觉的一个问题,我们的任务是使用图像作为输入,为图像中的每个像素分配一个类。在语义分割的情况下,我们不关心是否有同一个类的多个实例(对象),我们只是用它们的类别来标记它们。有多种关于不同计算机视觉问题的介绍课程,但用一张图片可以总结不同的计算机视觉问题:

ce0f8a08f27671a483f6121bec2200f4.png

语义分割在生物医学图像分析中有着广泛的应用:x射线、MRI扫描、数字病理、显微镜、内窥镜等。https://grand-challenge.org/challenges上有许多不同的有趣和重要的问题有待探索。

从技术角度来看,如果我们考虑语义分割问题,对于N×M×3(假设我们有一个RGB图像)的图像,我们希望生成对应的映射N×M×k(其中k是类的数量)。有很多架构可以解决这个问题,但在这里我想谈谈两个特定的架构,Unet和Unet++。

有许多关于Unet的评论,它如何永远地改变了这个领域。它是一个统一的非常清晰的架构,由一个编码器和一个解码器组成,前者生成图像的表示,后者使用该表示来构建分割。每个空间分辨率的两个映射连接在一起(灰色箭头),因此可以将图像的两种不同表示组合在一起。并且它成功了!

1b80b01acc1c22edb2e54f7248f88060.png

接下来是使用一个训练好的编码器。考虑图像分类的问题,我们试图建立一个图像的特征表示,这样不同的类在该特征空间可以被分开。我们可以(几乎)使用任何CNN,并将其作为一个编码器,从编码器中获取特征,并将其提供给我们的解码器。据我所知,Iglovikov & Shvets 使用了VGG11和resnet34分别为Unet解码器以生成更好的特征和提高其性能。

27762d542ca7758b3fe1f661920a7ae0.png

TernausNet (VGG11 Unet)

Unet++是最近对Unet体系结构的改进,它有多个跳跃连接。

722c7ed582e327b11e9f5d4d2e1ace6c.png

根据论文, Unet++的表现似乎优于原来的Unet。就像在Unet中一样,这里可以使用多个编码器(骨干)来为输入图像生成强特征。

我应该使用哪个编码器?

这里我想重点介绍Unet和Unet++,并比较它们使用不同的预训练编码器的性能。为此,我选择使用胸部x光数据集来分割肺部。这是一个二值分割,所以我们应该给每个像素分配一个类为“1”的概率,然后我们可以二值化来制作一个掩码。首先,让我们看看数据。

6d489d2f16ef313cdac5d18fcfdfa669.png

来自胸片X光数据集的标注数据的例子

这些是非常大的图像,通常是2000×2000像素,有很大的mask,从视觉上看,找到肺不是问题。使用segmentation_models_pytorch库,我们为Unet和Unet++使用100+个不同的预训练编码器。我们做了一个快速的pipeline来训练模型,使用Catalyst (pytorch的另一个库,这可以帮助你训练模型,而不必编写很多无聊的代码)和Albumentations(帮助你应用不同的图像转换)。

  1. 定义数据集和增强。我们将调整图像大小为256×256,并对训练数据集应用一些大的增强。

  1. import albumentations as A
  2. from torch.utils.data import Dataset, DataLoader
  3. from collections import OrderedDict
  4. class ChestXRayDataset(Dataset):
  5.     def __init__(
  6.         self,
  7.         images,
  8.         masks,
  9.             transforms):
  10.         self.images = images
  11.         self.masks = masks
  12.         self.transforms = transforms
  13.     def __len__(self):
  14.         return(len(self.images))
  15.     def __getitem__(self, idx):
  16.         """Will load the mask, get random coordinates around/with the mask,
  17.         load the image by coordinates
  18.         """
  19.         sample_image = imread(self.images[idx])
  20.         if len(sample_image.shape) == 3:
  21.             sample_image = sample_image[..., 0]
  22.         sample_image = np.expand_dims(sample_image, 2) / 255
  23.         sample_mask = imread(self.masks[idx]) / 255
  24.         if len(sample_mask.shape) == 3:
  25.             sample_mask = sample_mask[..., 0]  
  26.         augmented = self.transforms(image=sample_image, mask=sample_mask)
  27.         sample_image = augmented['image']
  28.         sample_mask = augmented['mask']
  29.         sample_image = sample_image.transpose(201)  # channels first
  30.         sample_mask = np.expand_dims(sample_mask, 0)
  31.         data = {'features': torch.from_numpy(sample_image.copy()).float(),
  32.                 'mask': torch.from_numpy(sample_mask.copy()).float()}
  33.         return(data)
  34.     
  35. def get_valid_transforms(crop_size=256):
  36.     return A.Compose(
  37.         [
  38.             A.Resize(crop_size, crop_size),
  39.         ],
  40.         p=1.0)
  41. def light_training_transforms(crop_size=256):
  42.     return A.Compose([
  43.         A.RandomResizedCrop(height=crop_size, width=crop_size),
  44.         A.OneOf(
  45.             [
  46.                 A.Transpose(),
  47.                 A.VerticalFlip(),
  48.                 A.HorizontalFlip(),
  49.                 A.RandomRotate90(),
  50.                 A.NoOp()
  51.             ], p=1.0),
  52.     ])
  53. def medium_training_transforms(crop_size=256):
  54.     return A.Compose([
  55.         A.RandomResizedCrop(height=crop_size, width=crop_size),
  56.         A.OneOf(
  57.             [
  58.                 A.Transpose(),
  59.                 A.VerticalFlip(),
  60.                 A.HorizontalFlip(),
  61.                 A.RandomRotate90(),
  62.                 A.NoOp()
  63.             ], p=1.0),
  64.         A.OneOf(
  65.             [
  66.                 A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
  67.                 A.NoOp()
  68.             ], p=1.0),
  69.     ])
  70. def heavy_training_transforms(crop_size=256):
  71.     return A.Compose([
  72.         A.RandomResizedCrop(height=crop_size, width=crop_size),
  73.         A.OneOf(
  74.             [
  75.                 A.Transpose(),
  76.                 A.VerticalFlip(),
  77.                 A.HorizontalFlip(),
  78.                 A.RandomRotate90(),
  79.                 A.NoOp()
  80.             ], p=1.0),
  81.         A.ShiftScaleRotate(p=0.75),
  82.         A.OneOf(
  83.             [
  84.                 A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
  85.                 A.NoOp()
  86.             ], p=1.0),
  87.     ])
  88. def get_training_trasnforms(transforms_type):
  89.     if transforms_type == 'light':
  90.         return(light_training_transforms())
  91.     elif transforms_type == 'medium':
  92.         return(medium_training_transforms())
  93.     elif transforms_type == 'heavy':
  94.         return(heavy_training_transforms())
  95.     else:
  96.         raise NotImplementedError("Not implemented transformation configuration")
  1. 定义模型和损失函数。这里我们使用带有regnety_004编码器的Unet++,并使用RAdam + Lookahed优化器使用DICE + BCE损失之和进行训练。

  1. import torch
  2. import segmentation_models_pytorch as smp
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from catalyst import dl, metrics, core, contrib, utils
  6. import torch.nn as nn
  7. from skimage.io import imread
  8. import os
  9. from sklearn.model_selection import train_test_split
  10. from catalyst.dl import  CriterionCallback, MetricAggregationCallback
  11. encoder = 'timm-regnety_004'
  12. model = smp.UnetPlusPlus(encoder, classes=1, in_channels=1)
  13. #model.cuda()
  14. learning_rate = 5e-3
  15. encoder_learning_rate = 5e-3 / 10
  16. layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}
  17. model_params = utils.process_model_params(model, layerwise_params=layerwise_params)
  18. base_optimizer = contrib.nn.RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
  19. optimizer = contrib.nn.Lookahead(base_optimizer)
  20. scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=10)
  21. criterion = {
  22.     "dice": DiceLoss(mode='binary'),
  23.     "bce": nn.BCEWithLogitsLoss()
  24. }
  1. 定义回调函数并训练!

  1. callbacks = [
  2.     # Each criterion is calculated separately.
  3.     CriterionCallback(
  4.        input_key="mask",
  5.         prefix="loss_dice",
  6.         criterion_key="dice"
  7.     ),
  8.     CriterionCallback(
  9.         input_key="mask",
  10.         prefix="loss_bce",
  11.         criterion_key="bce"
  12.     ),
  13.     # And only then we aggregate everything into one loss.
  14.     MetricAggregationCallback(
  15.         prefix="loss",
  16.         mode="weighted_sum"
  17.         metrics={
  18.             "loss_dice"1.0
  19.             "loss_bce"0.8
  20.         },
  21.     ),
  22.     # metrics
  23.     IoUMetricsCallback(
  24.         mode='binary'
  25.         input_key='mask'
  26.     )
  27.     
  28. ]
  29. runner = dl.SupervisedRunner(input_key="features", input_target_key="mask")
  30. runner.train(
  31.     model=model,
  32.     criterion=criterion,
  33.     optimizer=optimizer,
  34.     scheduler=scheduler,
  35.     loaders=loaders,
  36.     callbacks=callbacks,
  37.     logdir='../logs/xray_test_log',
  38.     num_epochs=100,
  39.     main_metric="loss",
  40.     minimize_metric=True,
  41.     verbose=True,
  42. )

如果我们用不同的编码器对Unet和Unet++进行验证,我们可以看到每个训练模型的验证质量,并总结如下:

6e8588a137fda22c362c0637b5525f6f.png

Unet和Unet++验证集分数

我们注意到的第一件事是,在所有编码器中,Unet++的性能似乎都比Unet好。当然,有时这种差异并不是很大,我们不能说它们在统计上是否完全不同 —— 我们需要在多个folds上训练,看看分数分布,单点不能证明任何事情。第二,resnest200e显示了最高的质量,同时仍然有合理的参数数量。有趣的是,如果我们看看https://paperswithcode.com/task/semantic-segmentation,我们会发现resnest200在一些基准测试中也是SOTA。

好的,但是让我们用Unet++和Unet使用resnest200e编码器来比较不同的预测。

026693bded587daf12725d0377cb478b.png

Unet和Unet++使用resnest200e编码器的预测。左图显示了两种模型的预测差异

在某些个别情况下,Unet++实际上比Unet更糟糕。但总的来说似乎更好一些。

一般来说,对于分割网络来说,这个数据集看起来是一个容易的任务。让我们在一个更难的任务上测试Unet++。为此,我使用PanNuke数据集,这是一个带标注的组织学数据集(205,343个标记核,19种不同的组织类型,5个核类)。数据已经被分割成3个folds。

20f0b5eaa04bb85cdbc62e34f4749f06.png

PanNuke样本的例子

我们可以使用类似的代码在这个数据集上训练Unet++模型,如下所示:

验证集上的Unet++得分

我们在这里看到了相同的模式 - resnest200e编码器似乎比其他的性能更好。我们可以用两个不同的模型(最好的是resnest200e编码器,最差的是regnety_002)来可视化一些例子。

resnest200e和regnety_002的预测

我们可以肯定地说,这个数据集是一项更难的任务 —— 不仅mask不够精确,而且个别的核被分配到错误的类别。然而,使用resnest200e编码器的Unet++仍然表现很好。

总结

这不是一个全面语义分割的指导,这更多的是一个想法,使用什么来获得一个坚实的基线。有很多模型、FPN,DeepLabV3, Linknet与Unet有很大的不同,有许多Unet-like架构,例如,使用双编码器的Unet,MAnet,PraNet,U²-net — 有很多的型号供你选择,其中一些可能在你的任务上表现的比较好,但是,一个坚实的基线可以帮助你从正确的方向上开始。

英文原文:https://towardsdatascience.com/the-best-approach-to-semantic-segmentation-of-biomedical-imagesbbe4fd78733f             

 
 
 
 
 
 

4e08d48fcd6bce8b935c3d66f56df3a0.jpeg

 
 
 
 
 
 
 
 
  1. 往期精彩回顾
  2. 适合初学者入门人工智能的路线及资料下载(图文+视频)机器学习入门系列下载机器学习及深度学习笔记等资料打印《统计学习方法》的代码复现专辑机器学习交流qq群955171419,加入微信群请扫码
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/346239
推荐阅读
相关标签
  

闽ICP备14008679号