当前位置:   article > 正文

论文阅读 (79):TransMIL: Transformer based Correlated Multiple Instance Learning for Whole Slide Image

transmil

1 概述

1.1 题目

2021:用于WSI分类的Transformer相关多示例 (TransMIL: Transformer based correlated multiple instance learning for whole slide image classification)

1.2 动机

WSI–MIL方法通常基于独立同分布假设,这忽略了不同实例之间的相关性。为了处理这个问题,提出了一个称为相关多示例的新框架。基于该框架,部署了一个基于Transformer的MIL (TransMIL),其能够同时探索形态和空间信息。

TransMIL可视化效果好、可解释性强,能够高效处理不平衡/平滑和二/多分类问题。实验验证了其性能及展示了收敛速度。

图1:决策过程:(a) 独立同分布假设下的注意力机制;(b) 相关多示例框架下的自注意力机制

1.3 代码

Torchhttps://github.com/szc19990412/TransMIL

1.4 附件

https://proceedings.neurips.cc/paper/2021/file/10c272d06794d3e5785d5e7c5356e9ff-Supplemental.pdf

1.5 引用

@article{Shao:2021:21362147,
author		={Zhu Chen Shao and Hao Bian and Yang Chen and Yi Feng Wang and Jian Zhang and Xiang Yang Ji and Yong Bing Zhang},
title		={{TransMIL}: {T}ransformer based correlated multiple instance learning for whole slide image classification},
journal		={Advances in Neural Information Processing Systems},
volume		={34},
pages		={2136--2147},
year		={2021}
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

2 方法

2.1 相关多示例

问题定义:以二分类MIL为例,给定包 X i = { x i , 1 , x i , 2 , … , x i , n } \mathbf{X}_i=\{\boldsymbol{x}_{i,1},\boldsymbol{x}_{i,2},\dots,\boldsymbol{x}_{i,n}\} Xi={xi,1,xi,2,,xi,n},对于 i = 1 , … , b i=1,\dots,b i=1,,b,这样的定义表明包类实例相互依赖且有一定顺序。实例标签 { y i , 1 , y i , 2 , … , y i , n } \{y_{i,1},y_{i,2},\dots,y_{i,n}\} {yi,1,yi,2,,yi,n}是未知的,包标签 Y i ∈ { 0 , 1 } Y_i\in\{0,1\} Yi{0,1}是已知的。一个MIL二分类器可以被定义为:
Y i = { 0 , iff ∑ y i , j = 0 y i , j ∈ { 0 , 1 } , j = 1 … n 1 , otherwise (1) \tag{1} Y_i= \left\{

0,iffyi,j=0yi,j{0,1},j=1n1,otherwise
\right. Yi={0,1,iffyi,j=0yi,j{0,1},j=1notherwise(1) Y ^ i = S ( X i ) (2) \tag{2} \hat{Y}_i=S(\mathbf{X}_i) Y^i=S(Xi)(2)其中 S S S评分函数 Y ^ i \hat{Y}_i Y^i表示预测、 b b b是包的总数、 n n n是包中的实例数,其对不同的包是可变的。

与Attention-net相比,进一步引入实例之间的相关性。定理1推理给出了 S ( X ) S(\mathbf{X}) S(X)的任意形式,定理2说明了相关多示例的一些优势。

定理1 假设 S : X → R S:\mathcal{X}\to\mathbb{R} S:XR是一个关于Hausdirff距离 d H ( ⋅ , ⋅ ) d_H(\cdot,\cdot) dH(,)的连续集合函数,对于任意的可逆图 P : X → R n P:\mathcal{X}\to\mathbb{R}^n P:XRn,存在函数 σ \sigma σ g g g,使得对于任意的 X ∈ X \mathbf{X}\in\mathcal{X} XX有:
∣ S ( X ) − g ( P X ∈ X { σ ( x ∈ X ) } ) ∣ < ϵ (3) \tag{3} |S(\mathbf{X})-g(P_{\mathbf{X}\in\mathcal{X}}\{\sigma(\boldsymbol{x}\in\mathbf{X})\})|<\epsilon S(X)g(PXX{σ(xX)})<ϵ(3)即一个Hausdorff连续函数 S ( X ) S(\mathbf{X}) S(X)能够被 g ( ⋅ ) g(\cdot) g()中的一个函数任意近似。

推理 基于定理1,对于任意的 X \mathbf{X} X有:
∣ S ( X ) − g ( P X ∈ X { f ( x ) + h ( x ) : x ∈ X } ) ∣ < ϵ (7) \tag{7} |S(\mathbf{X})-g(P_{\mathbf{X}\in\mathcal{X}}\{f(\boldsymbol{x})+h(\boldsymbol{x}):\boldsymbol{x}\in\mathbf{X}\})|<\epsilon S(X)g(PXX{f(x)+h(x):xX})<ϵ(7)

定理2:包中的实例可以通过随机变量 Θ 1 , Θ 2 , … , Θ n \Theta_1,\Theta_2,\dots,\Theta_n Θ1,Θ2,,Θn表示,在相关假设下包的信息熵可以被表示为 H ( Θ 1 , Θ 2 , … , Θ n ) H(\Theta_1,\Theta_2,\dots,\Theta_n) H(Θ1,Θ2,,Θn),包在独立同分布 (i.i.d.) 假设下的信息熵可以被表示为 ∑ t = 1 n H ( Θ t ) \sum_{t=1}^nH(\Theta_t) t=1nH(Θt),则有:
H ( Θ 1 , Θ 2 , … , Θ n ) = ∑ t = 2 n H ( Θ t ∣ , Θ 1 , … , Θ t − 1 ) + H ( Θ 1 ) ≤ ∑ t = 1 n H ( Θ t ) (8) \tag{8} H(\Theta_1,\Theta_2,\dots,\Theta_n)=\sum_{t=2}^nH(\Theta_t|,\Theta_1,\dots,\Theta_{t-1})+H(\Theta_1)\leq\sum_{t=1}^nH(\Theta_t) H(Θ1,Θ2,,Θn)=t=2nH(Θt,Θ1,,Θt1)+H(Θ1)t=1nH(Θt)(8)定理2证明了相关假设下有更小的信息熵,其可以减少不确定性和引入更多有用的信息。基于此,算法1展示了相关算法。图2展示了TransMIL与已有方法的主要区别。


  图2:池化矩阵 P \mathbf{P} P的差异:(a) 假设一个WSI中有5个实例, P ∈ R 5 × 5 \mathbf{P}\in\mathbb{R}^{5\times5} PR5×5是相应的池化矩阵,对角线表示和自己的注意力权重,其余的为与其他实例的;(b-d) 均忽略了相关信息,因此 P \mathbf{P} P是对角矩阵;(b) 第一个实例通过最大池化选择,因此只有一个非零值;© 平均池化下对角线的值相等;(d) 注意力的引入使得对角线上的值出现变化;(e) 得益于相关假设,非对角线上的值表明了实例之间的相关性

2.2 Transformer应用到相关MIL

Transformer使用自注意力机制来建模一个序列中的所有token的相关性,并添加位置信息来增加序列顺序信息的有用性。因此,使用函数 h h h来编码所有实例的空间信息,以及 P \mathbf{P} P使用自注意力来汇聚信息的Transformer是很有用的。

Transformer MIL 给定一个包的集合 { X 1 , X 2 , … , X b } \{\mathbf{X}_1,\mathbf{X}_2,\dots,\mathbf{X}_b\} {X1,X2,,Xb}及相应标签 Y i Y_i Yi,目的是习得一种映射 X → T → Y \mathbb{X}\to\mathbb{T}\to\mathcal{Y} XTY,其中 X \mathbb{X} X是包空间、 Y \mathbb{Y} Y是Transformer空间。以及 Y \mathcal{Y} Y是标签空间。

X → T \mathbb{X\to T} XT被定义为:
X i 0 = [ x i , c l a s s ; f ( x i , 1 ) ; f ( x i , 2 ) ; …   ; f ( x i , n ) ] + E p o s , X i 0 , E p o s ∈ R ( n + 1 ) × d (15) \tag{15} \mathbf{X}_i^0=[\boldsymbol{x}_{i,class};f(\boldsymbol{x}_{i,1});f(\boldsymbol{x}_{i,2});\dots;f(\boldsymbol{x}_{i,n})]+\mathbf{E}_{pos},\qquad\mathbf{X}_i^0,\mathbf{E}_{pos}\in\mathbb{R}^{(n+1)\times d} Xi0=[xi,class;f(xi,1);f(xi,2);;f(xi,n)]+Epos,Xi0,EposR(n+1)×d(15) Q ℓ = X i ℓ − 1 W Q , K ℓ = X i ℓ − 1 W K , V ℓ = X i ℓ − 1 W V , ℓ = 1 … L (16) \tag{16} \mathbf{Q}^\ell=\mathbf{X}_{i}^{\ell-1}\mathbf{W}_Q,\quad\mathbf{K}^\ell=\mathbf{X}_{i}^{\ell-1}\mathbf{W}_K,\quad\mathbf{V}^\ell=\mathbf{X}_{i}^{\ell-1}\mathbf{W}_V,\qquad\ell=1\dots L Q=Xi1WQ,K=Xi1WK,V=Xi1WV,=1L(16) h e a d = SA ( Q ℓ , K ℓ , V ℓ ) = softmax ( Q ℓ ( K ℓ ) T d q ) V ℓ , ℓ = 1 … L (17) \tag{17} \mathbf{head}=\text{SA}(\mathbf{Q}^\ell,\mathbf{K}^\ell,\mathbf{V}^\ell)=\text{softmax}\left(\frac{\mathbf{Q}^\ell(\mathbf{K}^\ell)^T}{\sqrt{d_q}}\right)\mathbf{V}^\ell,\qquad\ell=1\dots L head=SA(Q,K,V)=softmax(dq Q(K)T)V,=1L(17) MSA ( Q ℓ , K ℓ , V ℓ ) = Concat ( h e a d 1 , h e a d 2 , … , h e a d h ) W O , ℓ = 1 … L (18) \tag{18} \text{MSA}(\mathbf{Q}^\ell,\mathbf{K}^\ell,\mathbf{V}^\ell)=\text{Concat}(\mathbf{head}_1,\mathbf{head}_2,\dots,\mathbf{head}_h)\mathbf{W}^O,\qquad\ell=1\dots L MSA(Q,K,V)=Concat(head1,head2,,headh)WO,=1L(18) X i ℓ MSA(LN ( X i ℓ − 1 ) ) + X i ℓ − 1 , ℓ = 1 … L (19) \tag{19} \mathbf{X}_i^\ell\text{MSA(LN}(\mathbf{X}^{\ell-1}_i))+\mathbf{X}_i^{\ell-1},\qquad\ell=1\dots L XiMSA(LN(Xi1))+Xi1,=1L(19)其中SA表示自注意力、 L L L是MSA的数量、 h h h是每个MSK中头的数量,以及 L N LN LN是标准化层。

T → Y \mathbb{T}\to\mathcal{Y} TY被定义为:
Y i = MLP(LN ( ( X i L ) ( 0 ) ) ) (20) \tag{20} Y_i=\text{MLP(LN}((\mathbf{X}_i^L)^{(0)})) Yi=MLP(LN((XiL)(0)))(20)其中 ( X i L ) ( 0 ) (\mathbf{X}_i^L)^{(0)} (XiL)(0)表示类别token。 T → Y \mathbb{T}\to\mathcal{Y} TY可以通过类别token或者全局池化完成。然而,目前直接在WSI中使用Transformer相对困难,因此包中的实例数量很多且变化巨大。因此接下来注重如何高效地部署Transformer。

2.3 TransMIL用于弱监督WSI分类

为了更好地描述 X → T \mathbb{X\to T} XT,设计了包含两个Transformer层的TPT模块和一个位置编码层,其中Transformer层用于汇聚形态信息,金字塔位置编码生成器 (Pyramid position encoding generator, PPEG) 用于编码空间信息。TransMIL的总体架构如图3


  图3:TransMIL架构。每个WSI被裁剪为多个区块 (背景被抛弃),并通过ResNet50嵌入为特征向量,然后传递给TPT处理:1) 序列平方;2) 序列的相关性建模;3) 条件位置编码和信息融合;4)深度特征汇聚;5) T → Y \mathbb{T}\to\mathcal{Y} TY映射

2.3.1 使用TPT对长实例序列建模

序列来自于WSI的特征向量。TPT的的处理过程如算法2


大多数情况下,用于视觉任务的Transformer中的softmax是按行处理的。而标准的自注意力机制需要计算每一对toekn之间的相似性得分,太慢太耗内存。为了处理WSI中的长实例序列问题,TPT中的softmax使用Nystrom方法。近似自注意力机制 S ^ \hat{\mathbf{S}} S^被定义为:
S ^ = softmax ( Q K ~ T d q ) ( softmax ( Q ~ K ~ T d q ) ) + softmax ( Q ~ K T d q ) (9) \tag{9} \hat{S}=\text{softmax}\left(\frac{\mathbf{Q}\tilde{\mathbf{K}}^T}{\sqrt{d_q}}\right)\left(\text{softmax}\left(\frac{\tilde{\mathbf{Q}}\tilde{\mathbf{K}}^T}{\sqrt{d_q}}\right)\right)^+\text{softmax}\left(\frac{\tilde{\mathbf{Q}}{\mathbf{K}}^T}{\sqrt{d_q}}\right) S^=softmax(dq QK~T)(softmax(dq Q~K~T))+softmax(dq Q~KT)(9)其中 Q ~ \tilde{\mathbf{Q}} Q~ K ~ \tilde{\mathbf{K}} K~是从 Q \mathbf{Q} Q K \mathbf{K} K中的 n n n维序列中选择的 m m m个landmark,以及 A + \mathbf{A}^+ A+ A \mathbf{A} A的Moore-Penrose伪逆。最终的时间复杂度将从 O ( n 2 ) O(n^2) O(n2)降为 O ( n ) O(n) O(n)。由此,TPT可以满足包中实例很多的情况。

2.3.2 PPEG位置编码

在 WSI 中,由于载玻片和组织的可变大小,相应序列中的标记数量通常会有所不同。有研究表明,添加零填充可以为卷积提供绝对位置信息。受此启发设计了PPEG模块,相应的伪代码如算法3

PPEG模块有以下优势

  1. 同一层使用了不同大小的卷积核,可以编码不同粒度的位置信息,以扩展PPEG的能力;
  2. 可以获取序列中token的全局信息和上下文信息,这能够丰富每个token的特征。

3 实验及结果

3.1 数据集

  • CAMELYON16:用于乳腺癌转移检测的公开数据集,包含270个训练集和130个测试集。预处理后有大约350万个区块,量级为 × 20 \times20 ×20每个包平均有8800个
  • TCGA-NSCLC:包含两个子类,TGCA-LUSC和TCGA-LUAD,共993个诊断WSI,包含444种情形的507个LUAD和452种情况的486个LUSC。预处理后,每个包量级在 × 20 \times20 ×20的区块平均为15371。
  • TCGA-RCC:包含三个子类,THCA-KICH、TCGA-KIRC,以及TCGA-KIRP,共884个WSI,三个子类的情形数分别为99、483,以及264,幻灯片数则111、489,以及284.预处理后平均为14627。

3.2 实验设置和度量指标

  1. 每个WSI的裁剪为 256 × 256 256\times256 256×256的无交叠区块,饱和度 < 15 <15 <15的背景将被抛弃
  2. CAMELYON16的训练集划分10%作为验证集;
  3. TCGA划分时,首先确保训练和测试集中不存在来自一名患者的不同幻灯片,然后训练:验证:测试= 60 : 15 : 25 60:15:25 60:15:25
  4. 准确率和AUC作为评估指标,其中准确率附加计算阈值 0.5 0.5 0.5
  5. CAMELYON16使用测试AUC;
  6. TCGA-NSCLC使用平均AUC;
  7. TCGA-RCC使用macro-averaged AUC;
  8. TCGA使用4折交叉验证。

3.3 实现细节

  1. 交叉熵损失;
  2. Lookahead优化器,学习率 = 2 e − 4 =2e-4 =2e4,权重衰减 = 1 e − 5 =1e-5 =1e5
  3. 批次大小 = 1 =1 =1
  4. 每个区块通过ResNet50嵌入为1024维向量,在训练时通过全连接层降维到 512 512 512
  5. 包的最终嵌入为 H i ∈ R n × 512 \mathbf{H}_i\in\mathbb{R}^{n\times512} HiRn×512
  6. softmax用于标准化每个类别的预测得分。

3.4 基准线

  1. 注意力网络ABMIL和PT-MTA;
  2. 非定位注意力DSMIL;
  3. 单注意力CLAM-SB;
  4. 多头CLAM-MB;
  5. 循环神经网络MIL-RNN。

3.5 结果

分类


消融实验

可视化

收敛性

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/656660
推荐阅读
相关标签
  

闽ICP备14008679号