Paper: SAM-Med3D
Code: https://github.com/uni-medical/SAM-Med3D
Published: October 2023
The authors carried out an extensive collection and standardization of 3D medical image datasets, integrating 116 public and private 3D medical imaging datasets. After four rounds of data filtering and cleaning, they built the largest 3D medical image segmentation dataset to date, containing 21K 3D medical images (one per patient) and 131K 3D masks. As the table below shows, this is more than 10x larger than the previously largest 3D medical image segmentation datasets such as TotalSegmentator and BraTS21.
The dataset covers 27 modalities (CT plus 26 MRI sequences) and 7 anatomical regions. As shown in the figure below, it spans 247 distinct categories, including both organs and lesions.
The four-step data cleaning pipeline:
SAM-Med3D-turbo is the latest released checkpoint of the fine-tuned SAM-Med3D. Starting from SAM-Med3D, it was further fine-tuned on 44 datasets (listed below) to improve performance.
AMOS2022 ATM2022 AbdomenCT1K BTCV_Cervix BraTS2020 BraTS2021 BrainTumour Brain_PTM CAUSE07 CHAOS_Task_4 COSMOS2022 COVID19CTscans CTPelvic1k CT_ORG FLARE21 FLARE22 Heart_Seg_MRI ISLES_SISS ISLES_SPES KiPA22 KiTS KiTS2021 LAScarQS22_task1 LAScarQS22_task2 LITS MMWHS MSD_Colon MSD_HepaticVessel MSD_Liver MSD_Pancreas MSD_Prostate MSD_Spleen PROMISE12 Parse22 Promise09 Prostate_MRI_Segmentation_Dataset SLIVER07 STACOM_SLAWT SegThor Totalsegmentator_dataset VESSEL2012 VerSe19 VerSe20 WORD
The 3D architecture of SAM-Med3D is adapted from SAM: the original 2D components are converted into 3D counterparts, namely a 3D image encoder, a 3D prompt encoder, and a 3D mask decoder. 3D convolutions, 3D positional encoding (PE), and 3D layer norm are used throughout the 3D model.
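The 2D-to-3D conversion is largely mechanical. As one illustration, here is a minimal sketch of a channel-wise 3D layer norm in the style of SAM's LayerNorm2d; the name LayerNorm3d matches the identifier used by the prompt-encoder code below, but this is my own sketch, not necessarily the repository's exact implementation.

```python
import torch
import torch.nn as nn


class LayerNorm3d(nn.Module):
    """Channel-wise layer norm for (B, C, X, Y, Z) feature maps: a direct 3D
    extension of SAM's LayerNorm2d (sketch under stated assumptions)."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over the channel dimension only, keeping spatial dims intact.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
```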
In the 3D image encoder, a 3D convolutional patch-embedding block with kernel size (16, 16, 16) first produces patch embeddings, which are paired with a learnable 3D absolute positional encoding (PE). This encoding is obtained by naturally extending SAM's 2D PE with an additional dimension. The patch embeddings are then fed into 3D attention blocks. In the 3D attention module, a 3D relative PE is incorporated into SAM's multi-head self-attention (MHSA), allowing it to capture spatial detail directly.
```python
# Shared imports for the code snippets below.
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn


class PatchEmbed3D(nn.Module):
    """Image to Patch Embedding."""

    def __init__(
        self,
        kernel_size: Tuple[int, int, int] = (16, 16, 16),
        stride: Tuple[int, int, int] = (16, 16, 16),
        padding: Tuple[int, int, int] = (0, 0, 0),
        in_chans: int = 1,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C X Y Z -> B X Y Z C
        x = x.permute(0, 2, 3, 4, 1)
        return x
```
```python
class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int, int) or None): Input resolution for calculating the
                relative positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # Initialize relative positional embeddings for depth, height, and width.
            self.rel_pos_d = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[2] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, D, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, D * H * W, C)
        qkv = self.qkv(x).reshape(B, D * H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, D * H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, D * H * W, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            # add_decomposed_rel_pos is the repository's 3D decomposed relative-PE helper.
            attn = add_decomposed_rel_pos(
                attn, q, self.rel_pos_d, self.rel_pos_h, self.rel_pos_w, (D, H, W), (D, H, W)
            )

        attn = attn.softmax(dim=-1)
        x = (
            (attn @ v)
            .view(B, self.num_heads, D, H, W, -1)
            .permute(0, 2, 3, 4, 1, 5)
            .reshape(B, D, H, W, -1)
        )
        x = self.proj(x)
        return x
```
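To make the tensor shapes concrete, here is a small hypothetical usage sketch of the two blocks above. The 128^3 crop and the resulting 8x8x8 patch grid are assumptions for illustration; relative PE is disabled so the snippet does not need the repository's add_decomposed_rel_pos helper.

```python
# Minimal shape check for the 3D patch embedding and attention block (sketch only).
import torch

patch_embed = PatchEmbed3D(in_chans=1, embed_dim=768)
attn = Attention(dim=768, num_heads=8, use_rel_pos=False)  # relative PE off in this sketch

x = torch.randn(1, 1, 128, 128, 128)       # (B, C, X, Y, Z) CT/MRI crop
tokens = patch_embed(x)                     # (1, 8, 8, 8, 768): 16^3 patches -> 8x8x8 grid
pos_embed = torch.zeros(1, 8, 8, 8, 768)    # stands in for the learnable 3D absolute PE
out = attn(tokens + pos_embed)              # (1, 8, 8, 8, 768), same layout as the input
print(out.shape)
```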
In the prompt encoder, sparse prompts use 3D positional encoding to represent 3D spatial nuances, while dense prompts are processed with 3D convolutions.
```python
class PromptEncoder3D(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int, int],
        input_image_size: Tuple[int, int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int, int)): The spatial size of the
            image embedding, as (X, Y, Z).
          input_image_size (tuple(int, int, int)): The padded size of the image as
            input to the image encoder, as (X, Y, Z).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        # PositionEmbeddingRandom3D is the repository's 3D random Fourier PE;
        # LayerNorm3d is the channel-wise 3D layer norm sketched earlier.
        self.pe_layer = PositionEmbeddingRandom3D(embed_dim // 3)

        self.num_point_embeddings: int = 2  # pos/neg point
        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (
            image_embedding_size[0],
            image_embedding_size[1],
            image_embedding_size[2],
        )
        self.mask_downscaling = nn.Sequential(
            nn.Conv3d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm3d(mask_in_chans // 4),
            activation(),
            nn.Conv3d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm3d(mask_in_chans),
            activation(),
            nn.Conv3d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embed_X)x(embed_Y)x(embed_Z)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts (kept from SAM for API compatibility; box prompts are
        not used by SAM-Med3D, and only two point embeddings are defined here)."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_X)x(embed_Y)x(embed_Z)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand(
                bs,
                -1,
                self.image_embedding_size[0],
                self.image_embedding_size[1],
                self.image_embedding_size[2],
            )

        return sparse_embeddings, dense_embeddings
```
The 3D mask decoder is integrated with 3D upsampling, implemented with 3D transposed convolutions.
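The upsampling path itself is not shown in the decoder snippet below, so here is a minimal sketch of what a 3D output-upscaling stage looks like when SAM's 2D transposed convolutions are replaced by nn.ConvTranspose3d. The layer sizes are illustrative assumptions, not the repository's exact configuration; LayerNorm3d is the sketch from earlier. The TwoWayAttentionBlock3D that follows is the decoder's two-way attention block from the repository.

```python
import torch
import torch.nn as nn

# Sketch of a 3D output-upscaling head in the spirit of SAM's mask decoder,
# with ConvTranspose2d swapped for ConvTranspose3d (illustrative sizes only).
transformer_dim = 384
output_upscaling = nn.Sequential(
    nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
    LayerNorm3d(transformer_dim // 4),
    nn.GELU(),
    nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
    nn.GELU(),
)

src = torch.randn(1, transformer_dim, 8, 8, 8)  # image embedding after the two-way transformer
upscaled = output_upscaling(src)                 # (1, 48, 32, 32, 32): 4x spatial upsampling
print(upscaled.shape)
```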
```python
from torch import Tensor

# Note: Attention and MLPBlock3D below are the decoder-side helpers from the
# SAM-Med3D repository; this Attention takes separate q, k, v and an optional
# downsample_rate, unlike the image-encoder Attention shown earlier.


class TwoWayAttentionBlock3D(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock3D(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
```
Three training strategies were compared; the results show that training from scratch works best.
For both 2D slice segmentation and 3D volumetric segmentation, one point is randomly sampled from the foreground as the first prompt, and subsequent points are randomly sampled from the error region (see the sketch below). Notably, the 2D SAM methods (SAM, SAM-Med2D) infer slice by slice, whereas SAM-Med3D uses patch-based inference. This is consistent with state-of-the-art medical image segmentation methods such as nnU-Net and gives SAM-Med3D an advantage in inference time. Moreover, when applied to 3D medical images, 2D methods require an independent interaction on every slice, while the 3D method interacts only once globally on the volume. The 2D methods therefore perform N times as many interactions as the 3D method (where N is the number of slices containing the object, typically 10 to 200). Yet even with far more prompt points, the 2D methods' inherent lack of inter-slice interaction imposes a clear performance ceiling, especially on relatively complex 3D structures.
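As a concrete illustration of that prompt-simulation loop, here is a minimal hypothetical sketch of drawing the first click from the foreground and later clicks from the current error region. The function and variable names are mine, not the paper's or the repository's.

```python
# Sketch of interactive point sampling: first click from the foreground,
# later clicks from the error region (prediction XOR ground truth).
from typing import Optional

import torch


def sample_next_point(gt_mask: torch.Tensor, pred_mask: Optional[torch.Tensor]):
    """Return one (voxel_coordinate, label) prompt for a 3D binary mask pair."""
    if pred_mask is None:
        candidates = gt_mask.nonzero()  # first click: anywhere inside the target
    else:
        error = pred_mask.bool() ^ gt_mask.bool()
        candidates = error.nonzero() if error.any() else gt_mask.nonzero()
    idx = torch.randint(len(candidates), (1,)).item()
    point = candidates[idx]                     # (z, y, x) voxel coordinate
    label = int(gt_mask[tuple(point)].item())   # 1 = foreground click, 0 = background click
    return point, label
```

In the 3D setting this loop runs only a handful of times per volume, whereas a slice-by-slice 2D method has to repeat it for every slice containing the object.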
For evaluation, 13 public benchmark datasets covering a range of clinical scenarios were selected, plus 2 additional datasets from the MICCAI 2023 challenges, to validate the performance of the different models. This validation set spans seven important anatomical structures, such as thoracic and abdominal organs, brain structures, and bones. It also includes five lesion types of strong clinical interest, together with a range of volumetric imaging modalities including CT, US (ultrasound), and eight MRI sequences. In addition, it contains challenging, previously unseen targets, for a total of 153 distinct targets across different categories. The validation set has three parts:
Figure 5: Visualization of SAM, SAM-Med2D, and SAM-Med3D across different anatomical structures with varying numbers of prompt points. Both axial and coronal/sagittal slices are shown to fully illustrate the 3D results.
Figure 6: Visualization of SAM, SAM-Med2D, and SAM-Med3D across different modalities with varying numbers of prompt points. Both axial and coronal/sagittal slices are shown to fully illustrate the 3D results.