当前位置:   article > 正文

论文阅读:MamMIL: Multiple Instance Learning for Whole Slide Images with State Space Models

论文阅读:MamMIL: Multiple Instance Learning for Whole Slide Images with State Space Models

论文介绍

这是一篇发表在arXiv上的一篇论文,主要讲述的是新的基础模型框架Mamba出来之后,将其应用在WSI分类中的工作。
论文地址:https://arxiv.org/pdf/2403.05160.pdf
代码:暂未开源
在这里插入图片描述

摘要

最近,通过使用整个全视野数字切片图像 (WSI) 将 Transformer 与多实例学习 (MIL) 框架相结合,作为癌症诊断黄金标准的病理诊断取得了卓越的性能。然而,WSIs的千兆像素性质对Transformer中的二次复杂度的自注意机制在MIL中的应用提出了很大的挑战。现有的研究通常使用线性注意来提高计算效率,但不可避免地会带来性能瓶颈。为了应对这一挑战,我们提出了一种 MamMIL 框架,通过首次将选择性结构化状态空间模型(即 Mamba)与MIL结合进行 WSI 分类,从而能够在保持线性复杂度的同时对实例相关性进行建模。具体来说,为了解决 Mamba 只能进行单向一维(1D)序列建模的问题,我们创新性地引入了一个双向状态空间模型和一个 2D 上下文感知块,以使 MamMIL 能够学习具有 2D 空间关系的双向实例依赖关系。在两个数据集上的实验表明,与基于Transformer的最先进的MIL框架相比,MamMIL可以在更小的内存占用下实现更先进的分类性能。如果论文被接受,代码将开源。

引言

病理作为癌症诊断的金标准,自全视野数字切片图像(whole slide images, WSI)出现以来,经历了从人工观察到数字化分析的新飞跃[1,7]。随着深度学习的发展,一个新的跨学科领域——计算病理学应运而生,并引起了现代医疗领域的高度关注[24]。在计算病理学中,基于深度学习的模型被开发出来对wsi进行分类进行自动诊断,这在很大程度上减轻了病理学家的工作负担,减轻了诊断过程中的主观性[13,4]。

然而,WSIs由几十亿像素组成,这极大地挑战了计算病理学。一方面,WSIs中的巨大像素阻碍了精细注释的获取,这不仅需要专门的知识,而且非常耗时和昂贵。因此,通常只能获得WSI级别的标签[20]。另一方面,由于图形处理单元 (GPU) 中的内存限制,将庞大的 WSI 输入深度学习模型通常是不可行的。

为了解决上述问题,研究人员最近更加关注多实例学习(MIL)来计算病理学[9]。MIL将每个WSI视为一个包,由从WSI中分离出来的小补丁组成,称为实例。在 MIL 中,特征学习不再在千兆像素 WSI 上执行。相反,对每个小实例进行特征提取,从而解决了 GPU 内存限制的问题。在实例特征提取之后,MIL 聚合所有实例特征以获得袋子特征,该特征可以由 WSI 级别的标签监督,从而实现模型训练过程。

由于实例聚合直接影响包特征的可辨别性,从而决定了模型的性能,因此许多 MIL 框架都专注于聚合过程。例如,Ilse等人[14]提出了ABMIL,它利用门控注意机制进行实例聚合。为了避免错误注意力导致的性能下降,Shi 等人 [23] 提出了 LossAttn 框架。Lu等人[19]提出了实例聚类约束的CLAM,以缓解模型对大规模数据的要求。为了缓解训练WSIs数量有限带来的过拟合问题,Zhang等人[27]提出了一种伪包策略,设计了DTFDMIL。为了解决负实例和正实例不平衡引起的决策边界转移问题,Li等人[15]提出了DSMIL框架。

然而,上述所有研究都假设实例是相互独立的,即遵循独立且相同的分布。相反,组织间的相互作用是肿瘤进展的关键。出于这个原因,Shao等人[22]提出了TransMIL,它利用Transformer[25]和自注意力来建模成对的实例依赖关系。然而,直接将具有数万个实例的千兆像素wsi提供给Transformer通常会导致内存不足问题,这是由于自注意力的二次复杂性,从而阻碍了模型训练过程。为此,基于transformer的MIL方法通常使用近似为线性注意力,例如Nyströmformer[26],作为自注意力。然而,研究表明,与自注意力相比,线性注意的表现不可避免地受到限制[6]。

为了解决基于transformer的MIL方法中的问题,一些研究人员将状态空间模型(ssm)应用到MIL框架中,因为它们能够对具有线性或近线性复杂性的长序列进行建模。例如,Fillioux等人[8]使用了一种名为S4D的SSM[11]进行MIL,然而,S4D的参数是输入不变的,这阻碍了模型关注最具判别性的实例,从而导致性能较差[10]。此外,为了获得满意的结果,S4D必须使用复值参数,这不可避免地增加了GPU的内存占用。

最近,出现了一种新的SSM,即Mamba[10]。与之前受到性能限制和内存占用增加的困扰的SSM相比,Mamba将SSM的参数扩展为具有真实值的依赖于内容的方式。现有研究证实,在自然语言处理和视觉任务方面,Mamba可以达到与Transformer相当或更好的性能[10,28,18,21]。然而,没有研究利用Mamba进行WSI分类或将其与MIL框架相结合。以上原因促使我们将Mamba应用于MIL进行WSI的自动病理诊断。

然而,在将Mamba应用于MIL进行WSI分类时存在一些挑战。首先,Mamba是建立一维(1D)序列。当将二维WSI转为一维序列作为输入时,不可避免地会出现二维空间信息的丢失。其次,Mamba使用 “扫描” 策略以单向的方式计算隐藏状态。虽然单向扫描对具有时间序列特征的序列建模是可行的,但对于具有双向对依赖关系的WSI来说,单向扫描可能效率低下。为了充分利用Mamba的长序列建模能力,同时解决上述挑战,我们提出了一个MIL框架,称为MamMIL。具体而言,本文的主要贡献如下:

  • 作为与Transformer性能相当的线性复杂度模型,本文首次将Mamba应用于WSI分类任务的MIL框架,其中每个包由数万个实例组成。
  • 为了解决Mamba的单向建模问题,我们引入了双向SSM块。此外,采用二维上下文感知块来避免一维序列中二维空间信息的丢失。
  • 在两个数据集上的实验证明,在获得比基于Transformer的MIL框架更低的GPU内存占用的同时,Mamba优于SOTA方法。我们的工作为未来的MIL研究提供了新的架构和方向。

方法

框架总览

图1给出了Mamba的概述,主要包括实例特征提取阶段和实例特征聚合阶段。在第一阶段,使用滑动窗口将WSI分割成小的、不重叠的patches作为实例。随后,采用预训练的ResNet50[12]提取实例的特征。在使用可训练的线性层和ReLU激活函数之后,实例特征构成了一个1D序列,同时在最后一个位置附加了一个类token。然后,该序列被送入第二阶段,该阶段由一系列堆叠的MIL-SSM模块组成。每个MIL-SSM模块由一个双向SSM (Bi-SSM)模块和一个2D上下文感知模块(2D-CAB)组成。最后,将类token用作WSI分类的包特征。

在这里插入图片描述

实例特征提取

将一张WSI表示为 X X X, 我们首先将这张WSI划分为M个patches(实例) { p i } i = 1 M \{p_i\}_{i=1}^M {pi}i=1M,其中 p i ∈ R H × W × 3 p_i \in \mathbb{R}^{H \times W \times 3} piRH×W×3。在这里, H H H W W W分别表示为每个实例的高和宽,且 H = W = 256 H=W=256 H=W=256。接下来,我们利用一个在ImageNet[5]上进行预训练的后面有一个可训练的线性层和ReLU激活函数的ResNet50,假设为 ε \varepsilon ε,来提取每个patch的特征,表示为 x i = ε ( p i ) ∈ R 512 x_i = \varepsilon(p_i) \in \mathbb{R}^{512} xi=ε(pi)R512。在提取完特征之后,一张WSI因此可以用一些一维的实例特征来表示,例如 X = { x i } i = 1 M X=\{x_i\}_{i=1}^M X={xi}i=1M。最后,我们在序列的末尾添加了一个可训练的类别标记(token) x c l s x_{cls} xcls,例如, X = [ x 1 , x 2 , ⋯   , x M , x C L S ] X=[x_1,x_2,\cdots,x_M, x_{CLS}] X=[x1,x2,,xM,xCLS]。这些序列将被送入到下一阶段的聚合中。

实例特征聚合

作为对实例依赖关系进行建模以获取包特征的关键阶段,MamMIL中的实例特征聚合阶段由 L L L个堆叠的MIL-SSM模块组成,每个MIL-SSM模块由一个Bi-SSM模块和一个带有残差连接的2D-CAB模块组成。接下来,将描述这两个模块。

Bi-SSM模块 在 MamMIL 中,Bi-SSM模块是一个关键组件,它利用 Mamba快速挖掘大量具有线性复杂度的实例之间的判别依赖关系。Mamba的主要目的是通过隐藏状态 { h i } i = 1 M \{h_i\}_{i=1}^M {hi}i=1M,学习输入的实例序列 X = { x i } i = 1 M X = \{x_i\}_{i=1}^M X={xi}i=1M到输出序列 { y i } i = 1 M \{y_i\}_{i=1}^M {yi}i=1M的映射,其建模为
在这里插入图片描述
在实际中,在具有离散输入和权重的深度学习模型中应用等式(1),需要进行离散化。通常,像在Mamba中一样,通过带有时间步长 Δ \Delta Δ的零阶保持规则将 A ˉ \bar{A} Aˉ B ˉ \bar{B} Bˉ C ˉ \bar{C} Cˉ离散化。具体为:
在这里插入图片描述
其中 A , B , C A,B,C A,B,C Δ \Delta Δ是可学习参数。为了增强上下文感知能力,Mamba根据三个可学习的线性层 l B l_B lB l C l_C lC l Δ l_{\Delta} lΔ,将参数 B B B C C C Δ \Delta Δ与输入序列 X X X相关联。公式为:
在这里插入图片描述
其中, P Δ \mathbf{P_\Delta} PΔ Δ \Delta Δ的可学习参数。

但从公式(1)中可以看出,隐藏态 h i h_i hi只与之前的隐藏态和当前输入有关,因此 h i h_i hi的计算是单向的 “扫描” 方式。然而,WSI中任何实例之间都可能存在依赖关系。为了解决这个问题,受ViM[28]的启发,我们构建了两个SSMs来同时对前向和后向序列方向建模,并构建了Bi-SSM块。对于前向SSM,我们直接将 X X X输入到SSM并得到输出 Y Y Y。对于后向SSM,我们首先翻转实例特征序列,同时将类标记固定在最后一个位置以构造 X ′ X^{'} X。然后,将 X ′ X^{'} X送入另一个SSM以获得输出 Y ′ Y^{'} Y。最后,对 Y ′ Y^{'} Y的实例特征部分进行还原,得到后向SSM的输出。

为了融合两个SSM的输出,我们采用了一种门控机制。对于第 l l l个MIL-SSM模块的输入序列 X ( l ) X^{(l)} X(l),我们首先对 X ( l ) X^{(l)} X(l)进行随机打乱,同时固定类标记的位置以减轻过拟合。然后,将打乱后的 X ( l ) X^{(l)} X(l)送到两个线性层中,得到输出 X ˉ \bar{X} Xˉ Z ˉ \bar{Z} Zˉ。对 X ˉ \bar{X} Xˉ翻转后得到 X ˉ ′ \bar{X}^{'} Xˉ后,分别将 X ˉ \bar{X} Xˉ X ˉ ′ \bar{X}^{'} Xˉ送入一维卷积得到 X X X X ′ X^{'} X,作为前向或后向SSM的输入。同时, Z ˉ \bar{Z} Zˉ由SiLU函数激活,并通过元素乘法对两个SSM输出的平均值进行门控。最后,门控序列被重新打乱,作为Bi-SSM的输出。

2D上下文感知模块 由于Bi-SSM是在一维序列上执行的,因此仍然无法感知实例中的二维空间关系。为了解决这个问题,我们引入了基于金字塔结构的2D-CAB。具体来说,我们首先去除类标记,并将剩余的1D特征序列重塑为2D平方特征映射。如果实例的数量不可整除,我们首先循环填充序列到最近的完全平方。然后,分别用3 × 3、5 × 5和7 × 7深度卷积提取二维空间关系。输出 X ′ ′ X^{''} X′′是通过对带有残差连接的卷积结果求和得到的。最后, X ′ ′ X^{''} X′′被还原成1D序列。通过移除填充并复制回类标记,我们可以获得一个包含2D空间信息作为2D-CAB输出的1D实例序列。

WSI分类与损失函数 在使用 L L L个MIL-SSM模块进行判别性实例依赖挖掘和实例特征聚合之后,我们利用类标记作为WSI分类的包特征,并通过Softmax函数激活线性投影。在实验中,我们使用交叉熵损失进行模型优化。

实验结果

数据集

使用了两个公共WSI数据集Camelyon16[16,2]和BRACS[3]。Camelyon16是一个用于乳腺癌微转移检测的数据集,正式分为270个训练WSI和129个测试WSI。我们进一步以2:1的比例将官方训练集的WSI划分为训练集和验证集,并且官方的测试集中的WSI不进行改动,直接测试。BRCAS是一个乳腺癌亚型数据集,其中每个WSI被分类为良性、非典型和恶性肿瘤。BRCAS正式分为395个wsi的训练集、65个wsi的验证集和87个wsi的测试集。用官方划分进行评估。

实验设置

所有的实验都是在RTX 3090 GPU上完成的。代码由Pytorch 2.1.1实现。在分割WSI时,采用灰度阈值法丢弃背景patch。使用RAdam[17]优化器,固定学习率为1e−4,权重衰减为0.05。每个模型训练250个epoch。使用验证AUC最高的模型进行测试。每个实验的种子固定为4。与之前的大多数研究一样,AUC被用作主要的评价指标。此外,我们报告了准确率和F1-Score,分类阈值为0.5。由于Camelyon16的训练集较小,为了避免过拟合,我们将MIL-SSM模块的个数取 L = 1 L=1 L=1。对于BRCAS,我们取 L = 2 L=2 L=2

结果比较

分类性能比较 MamMIL与其他SOTA方法的比较结果见表1。除了BRCAS的准确性略低于ABMIL外,所提出的MamMIL在所有指标上都超过了所有SOTA方法。此外,我们可以看到,对于BRCAS数据集,MamMIL在AUC上优于SOTA方法,并且比Camelyon16数据集的优势更大。这可能是因为BRCAS数据集中的训练WSI比Camelyon16多。因此,在MamMIL中针对大规模数据提出和优化的SSM块不太可能过拟合。这一现象也表明,如果在大规模数据集上训练,MamMIL可能会表现出更好的性能。
在这里插入图片描述

GPU内存占用比较 除了Fillioux等[8]提出的方法外,现有的MIL框架要么假设WSI中的实例遵循独立分布假设并使用局部注意力进行实例聚合(例如ABMIL和CLAM),要么使用全局自关注(例如TransMIL)对实例依赖关系进行建模。前一种方法需要少量的GPU内存,但它们的性能通常是有限的。后一种方法通常性能更好,但需要更大的GPU内存。相比之下,MamMIL解决了这一冲突,这要归功于Mamba对具有线性复杂性的长序列建模的能力。如图2所示,与TransMIL相比,在Camelyon16和BRCAS上,MamMIL占用的GPU内存分别减少了65.5%和71.7%。此外,MamMIL的内存占用也比基于S4D的方法要小[8],因为在基于S4D的方法中使用了复值参数,而Mamba只需要实值权重。综上所述,Mamba的GPU内存占用与独立分布假设的框架相当,相差不超过1.6 GB。
在这里插入图片描述

消融研究

提出的模块的影响 为了验证所提出的组件的有效性,我们进行了消融研究,结果列于表2。正如我们所看到的,所有被提出的模块都可以提高模型的性能。其中shuffle操作对MamMIL的性能影响最大。这可能是因为shuffle操作可以防止模型在有限的训练数据中记忆固定的模式,从而避免过拟合。此外,同时使用前向和后向SSM可以提高性能,因为它解决了SSM只能单向建模序列的问题。2D-CAB也可以提高模型性能,证明了使MamMIL能够感知实例的2D空间关系的有效性。

类标记位置的影响 考虑到SSM中的隐藏状态是对先前输入实例的压缩[10],为了将所有实例信息融合到类标记中,我们将其添加到特征序列的最后一个位置。进行消融研究以验证其有效性。如表2所示,如果将类标记放在第一个位置,则模型几乎无法训练,因为SSM从开始到结束依次扫描和聚合特征。在中间添加类标记或不使用类标记(使用平均池化来获得包特征)的性能对于两个数据集是可比较的,但两者在AUC中的性能都比我们提出的方法弱。总之,将类标记放在最后一个位置可以在两个数据集中获得更好的整体性能。

在这里插入图片描述

结论

考虑到自注意力的二次方复杂度,当前基于Transformer的MIL框架被迫使用线性注意力进行千兆像素WSI分类,这不可避免地会影响模型的性能。为此,我们提出了MamMIL,它是第一个将Mamba引入MIL的方法,并实现了具有线性复杂性的高效WSI分类。为了解决Mamba只能单向建模序列且不能感知2D空间关系的缺点,我们进一步引入了Bi-SSM模块和2D-CAB模块。实验证明,与基于Transformer的现有方法相比,MamMIL在消耗更少GPU内存的情况下实现了SOTA性能。由于效率高,未来将在大规模数据集上进行训练,并有望表现出更好的性能。

总结(个人观点)

这篇论文将Mamba引入到MIL框架中,从结果上看,获得了不错的表现,并且GPU消耗也比较小。但是感觉这只是一篇占名的文章,即只是把Vim中的双向SSM(Bi-SSM)模块引入到MIL,而文中的2D-CAB模块,则与TransMIL中的PPEG模块非常相似。

更多关于多实例WSI分类的文章阅读,请关注公众号:
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/293443
推荐阅读
  

闽ICP备14008679号