当前位置:   article > 正文

MS-Model【1】:nnU-Net

nnu-net


前言

本文提出的 nnU-netno-new U-net),是在 2D & 3D 经典 U-net 的基础上, 稳健而又自适应的框架。nnU-net 移去了冗余的部分,着重于剩下的对模型表现和泛化能力起作用的部分

原论文链接:
nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation

论文复现参考:
MS-Train【1】:nnUNet


1. Abstract & Introduction

1.1. Abstract

U-Net 凭借其直接和成功的架构,迅速发展成为医学图像分割领域的一个常用基准。然而,U-Net 对新问题的适应性包括关于确切的架构、预处理、训练和推理的几个自由度。这些选择并不是相互独立的,而是对整体性能有很大影响。本文介绍了 nnU-Netno-new-Net),它是指在二维和三维虚构 U-Net 的基础上建立的一个稳健和自适应的框架。本文论证了在许多提议的网络设计中去掉多余的 bellswhistles,转而关注其余的方面,这些方面决定了一个方法的性能和可推广性。

1.2. Introduction

本文提出了 nnU-Netno-new-Net)框架。它基于一组相对简单的 U-Net 模型,只包含对原始U-Net 的微小修改。本文省略了最近提出的扩展,例如使用 residual connectionsdense connectionsattention mechanismsnnU-Net 可以自动将其架构适应给定的图像几何,更重要的是,它彻底定义了围绕他们的所有其他步骤。这些步骤包括:

  • 预处理,比如 resampling 和 normalization
  • 训练,比如损失函数、优化器的设置和数据扩充
  • 推断,比如基于图像块的策略、TTA(test-time augmentation)集成和模型集成
  • 后处理,比如增强单连通域

2. Methods

2.1. Network architectures

U-Net 是一个成功的 encoder-decoder 网络:

  • encoder:工作原理与传统的分类 CNN 类似,以减少空间信息为代价连续聚集语义信息
  • decoder:因为在分割中,语义和空间信息对网络的成功至关重要,所以需要使用 decoder 以恢复丢失的空间信息
    • decoder 从 U 的底部接收语义信息,并将其与通过跳过连接直接从编码器获得的高分辨率特征图重新组合

医学图像通常包括第三维,本文提出了一套包括由 2D U-Net3D U-NetU-Net cascade 组成的基本 U-Net 架构:

  • 2D 和 3D U-Net 以全分辨率生成分割
  • cascade 首先生成低分辨率的分割,然后再对其进行细化
  • 与原始 U-Net 的比较:
    • 和原始U-Net类似,在 encoder 部分,本文在池化层之间使用简单的卷积层,在 decoder 部分,本文使用转置卷积
    • 与原始 U-Net 不同的是,本文使用的激活函数是 leaky ReLU 而不是 ReLU(-1e-2);用 instance normalization 替换更加流行的 batch normalization

2.1.1. 2D U-Net

当数据是各向异性的时候,传统的 3D 分割方法就会很差,所以在这里给出了 2D U-Net 的网络架构

网络特征:

  • 使用全卷积神经网络
    • 全卷积神经网络就是卷积取代了全连接层,全连接层必须固定图像大小而卷积不用,所以这个策略保证使用者可以输入任意尺寸的图片,而且输出也是图片,所以这是一个端到端的网络
  • 左边的网络是收缩路径:使用卷积和 maxpooling.
  • 右边的网络是扩张路径:使用上采样产生的特征图与左侧收缩路径对应层产生的特征图进行 concatenate 操作
  • 最后再经过两次反卷积操作,生成特征图,再用两个 1 × 1 1 \times 1 1×1 的卷积做分类得到最后的两张 heatmap
    • 例如第一张表示的是第一类的得分,第二张表示第二类的得分
    • 然后作为 softmax 函数的输入,算出概率比较大的 softmax 类,选择它作为输入给交叉熵进行反向传播训练

在这里插入图片描述

更多关于 U2-Net 网络结构的讲解,可以参考我的另一篇 blog:SS-Model【6】:U2-Net

2.1.2. 3D U-Net

3D 网络的效果好,但是太占用 GPU 显存。一般情况下,可以使用小一点的图像块去训练,但是当面对比较大的图像如肝等,这种基于块的方法就会阻碍训练。这是因为受限于感受野的大小,网络结构不能收集足够的上下文信息去正确的识别大图像的特征

网络内容:

  • 核心
    • 训练过程只要求一部分 2D slices,去生成密集的立体分割
  • 两种实现方法
    • 在一个稀疏标注的数据集上训练并在此数据集上预测其他未标注的地方
    • 在多个稀疏标注的数据集上训练,然后泛化到新的数据

网络特征:

  • 网络结构的前半部分(analysis path)包含如下卷积操作
    • 每一层神经网络都包含了两个 3 × 3 × 3 3 \times 3 \times 3 3×3×3 的卷积
    • Batch Normalization(为了让网络能更好的收敛)
    • ReLU
    • 下采样: 2 × 2 × 2 2 \times 2 \times 2 2×2×2max_polling,步长 stride = 2
      • 通过在最大池化之前将通道数量加倍来避免瓶颈
  • 网络结构的合成路径(synthesis path)则执行下面的操作
    • 上采样: 2 × 2 × 2 2 \times 2 \times 2 2×2×2,步长 stride = 2
    • 两个正常的卷积操作: 3 × 3 × 3 3 \times 3 \times 3 3×3×3
    • Batch Normalization
    • ReLU
    • 把在 analysis path 上相对应的网络层的结果作为 decoder 的部分输入,这样子做的原因跟 U-Net 博文 中提到的一样,是为了能采集到特征分析中保留下来的高像素特征信息,以便图像可以更好的合成
    • 在最后一层, 1 × 1 × 1 1 \times 1 \times 1 1×1×1 卷积将输出通道的数量减少到 3 个标签
    • 加权 softmax 损失函数将未标记像素的权重设置为零可以仅从已标记像素学习。降低了频繁出现的背景的权重,增加了内管的权重,以达到小管和背景体素对损失的均衡影响。

在这里插入图片描述

2.1.3. U-Net cascade

为了解决 3D U-Net 在大图像尺寸数据集上的缺陷,本文提出了级联模型:

  • 第一级 3D U-Net 在下采样的图像上进行训练,然后将结果上采样到原始的体素spacing。
  • 将上采样的结果作为一个额外的输入通道(one-hot 编码)送入第二级 3D U-Net,并使用基于图像块的策略在全分辨率的图像上进行训练

在这里插入图片描述

2.2. Dynamic adaptation of network topologies

由于输入图像大小的不同,输入图像块大小和每个轴池化操作的数量(同样也是卷积层的数量)必须能够自适应每个数据集去考虑充足的空间信息聚合。除了自适应图像几何,还需要考虑显存的使用。

指导方针是动态平衡 batch size 和网络容量:

  • 将 patch size 初始化为图像大小的中位数
    • 迭代地减少 patch size,同时调整网络拓扑架构(网络深度、池化操作数量、池化操作位置、feature map 的尺寸、卷积核尺寸)
    • 直到网络可以在给定 GPU 的限制下,且 batch 至少是 2 的情况下,可以被 train

网络初始配置:

  • 2D U-Net
    • 图像大小 = 256 × 256 256 \times 256 256×256,batch size = 42,最高层的特征图谱数量 = 30(每个下采样特征图谱数量翻倍)
    • 自动将这些参数调整为每个数据集的中值平面大小(这里使用面内间距最小的平面,对应于最高的分辨率),以便网络有效地训练整个切片
    • 本文将网络配置为沿每个轴池化,直到该轴的特征图谱小于8(但最多不超过6个池化操作)
  • 3D U-Net
    • 图像大小 = 128 × 128 × 128 128 \times 128 \times 128 128×128×128,batch size = 2,最高层的特征图谱数量 = 30
    • 由于显存限制,不去增加图像大小超过 12 8 3 128^3 1283 体素,而是匹配输入图像和数据集中体素中值大小的比率
      • 如果数据集的形状中值比 12 8 3 128^3 1283 小,那就使用形状的中值作为输入的图像大小并且增加 batch size(目的是将体素的数量和 KaTeX parse error: Undefined control sequence: \tiems at position 5: 128 \̲t̲i̲e̲m̲s̲ ̲128 \times 128,batch size 为 2 的体素数量相等)。沿每个轴最多池化5次直到特征图谱大小为8

在这里插入图片描述

2.3. Preprocessing

nnU-Net 的预处理是在没有任何用户干预的情况下执行的

2.3.1. Cropping

所有数据都被裁剪到非零值区域

2.3.2. Resampling

CNN 本身并不理解体素间距。 在医学图像中,不同的扫描仪或不同的采集协议通常会产生具有不同体素间距的数据集

为了使我们的网络能够正确学习空间语义,所有患者都被重新采样到各自数据集的中值体素间距,其中三阶样条插值用于图像数据最近邻居插值用于相应的分割掩码

是否需要经过 U-Net cascade 模型,由以下方法确定:

  • 如果重采样数据的形状中值可以作为 3D U-Net 中的输入图像(batch size = 2)的 4 倍以上,则使用 U-Net cascade 模型,且数据集需要重新采样到较低的分辨率
    • 可以通过将体素间距增加 2 倍来完成(降低分辨率),直到满足上述标准
    • 如果数据集是各向异性的,则首先对较高分辨率的轴进行下采样,直到它们与低分辨率轴匹配,然后才同时对所有轴进行下采样

2.3.3. Normalization

  • 对于 CT 图像,训练集中所有 segmentation mask 中的 value 会被收集,整体的数据集会先被 clip 到 [ 0.5 , 99.5 ] [0.5, 99.5] [0.5,99.5] 百分位,然后通过收集的数据的 mean 和标准差进行 z-score 正则化
    • 需要注意的是,如果因为裁剪减少了病例平均大小的 1/4 或更多,则标准化只在非零元素的 mask 内部进行,并且 mask 外的所有值设为 0
  • 对于 MRI 图像以及其他图像,直接进行 z-score 标准化

2.4. Training Procedure

所有模型都从头开始训练,并在训练集上使用五折交叉验证进行评估

2.4.1. Loss Function

结合 dice 和交叉熵损失来训练网络:
L t o t a l = L d i c e + L C E \mathcal{L}_{total} = \mathcal{L}_{dice} + \mathcal{L}_{CE} Ltotal=Ldice+LCE

  • 对于在全训练集上训练的 3D U-Net(如果不需要 cascade,则是 U-Net cascade 的第一阶段和 3D U-Net),计算 batch 里每个样本的 dice 损失,并计算 batch 中的平均值
  • 对于所有其他网络,将 batch 中的样本(samples)解释为伪体积(volume),并计算批次中所有体素的 dice 损失

对于目前大多数的图像分割任务来说,使用最多评价指标就是 dice 相似系数 (Dice Similarity Coefficient) 。Dice 系数是计算两个样本之间的相似度,即考察两个样本之间重叠的范围,范围通常在0-1之间。

  • 若为1,则证明两个样本完全重合
  • 若为0,则证明两个样本没有相同的像素

计算方法如下:
D i c e ( P , T ) = 2 T P F P + 2 T P + F N Dice(P, T) = \frac{2TP}{FP + 2TP + FN} Dice(P,T)=FP+2TP+FN2TP

其中:

  • TP (True Positive) 为判定为正样本,事实上也是正样本
  • TN (True Negative) 为判定为负样本,事实上也是负样本
  • FP (False Positive ) 为判定为正样本,事实上为负样本
  • FN (False Negative) 为判定为负样本,事实上为正样本

分母即:

  • FP + TP = 所有分类为阳性的样本
  • TP + FN = 真阳 + 假阴 = 所有真的是阳性的样本

dice 的损失函数为:
L d i c e = − 2 ∣ K ∣ ∑ k ∈ K ∑ i ∈ I u i k v i k ∑ i ∈ I u i k + ∑ i ∈ I v i k \mathcal{L}_{dice} = - \frac{2}{|K|} \displaystyle\sum_{k \in K} \frac{\sum_{i \in I}u_i^k v_i^k}{\sum_{i \in I}u_i^k + \sum_{i \in I}v_i^k} Ldice=K2kKiIuik+iIvikiIuikvik

参数含义:

  • u 是网络的 softmax 输出
  • v 是 ground Truth 的 one hot 编码
  • k 为类别数
  • uv 都具有形状 I × K I \times K I×K

2.4.2. Learning Rate

优化器:Adam,初始学习率 3e-4,每个 epoch 有 250 个 batch
学习率调整策略:计算训练集和验证集的指数滑动平均 loss,如果训练集的指数滑动平均 loss 在近 30 个 epoch 内减少不够 5e-3,则学习率衰减 5 倍
训练停止条件:当学习率大于 10-6 且验证集的指数滑动平均 loss 在近 60 个 epoch 内减少不到 5e-3,则终止训练

2.4.3. Data Augmentation

从有限的训练数据训练大型神经网络时,必须特别注意防止过度拟合。 本文在训练期间动态应用了如下所示的多种增强技术来解决这个问题:

  • random rotations
  • random scaling
  • random elastic deformations
  • gamma correction augmentation
  • mirroring

需要注意的是,如果 3D U-Net 的输入图像块尺寸的最大边长是最短边长的两倍以上,这种情况对每个 2 维面做数据增广,然后逐个切片地将其应用于每个样本

U-Net cascade 的第二级接受前一级的输出作为输入的一部分,为了防止强 co-adaptation,应用随机形态学操作(腐蚀、膨胀、开运算、闭运算)去随机移除掉这些分割结果的连通域

2.4.4. Patch Sampling

为了增强网络训练的稳定性,强制每个 batch 中超过 1/3 的样本包含至少一个随机选择的前景


总结

本文提出了用于医疗领域的 nnU-Net 分割框架,该框架直接围绕原始 U-Net 架构构建,并动态调整自身以适应任何给定数据集的细节。 基于本文的假设,即非架构修改可能比最近提出的一些架构修改更强大,该框架的本质是自适应预处理、训练方案和推理的彻底设计。 适应新分割任务所需的所有设计选择均以全自动方式完成,无需手动交互。

参考资料

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/78107
推荐阅读
相关标签
  

闽ICP备14008679号