当前位置:   article > 正文

VM-UNet:视觉Mamba UNet用来医学图像分割 论文及代码解读_vmunet

vmunet


2024.2.27更新代码介绍

期刊分析

期刊名: Arxiv
论文主页: VM-UNet
代码: Code

摘要

医学图像分割领域,基于 CNN 和基于 Transformer 的模型都得到了广泛的探索。然而,CNN 在远程建模能力方面表现出局限性,而 Transformer 则受到其二次计算复杂性的限制。 最近,以 Mamba 为代表的状态空间模型 (SSM) 已成为一种有前途的方法。它们不仅擅长对远程交互进行建模,而且还保持线性计算复杂性。在本文中,利用状态空间模型,我们提出了一种用于医学图像分割的 Ushape 架构模型,名为 Vision Mamba UNet (VM-UNet)。具体来说,引入视觉状态空间(VSS)块作为基础块来捕获广泛的上下文信息,并构建不对称的编码器-解码器结构。我们在 ISIC17、ISIC18 和 Synapse 数据集上进行了全面的实验,结果表明 VM-UNet 在医学图像分割任务中表现出竞争力。据我们所知,这是第一个基于纯SSM模型构建的医学图像分割模型。我们的目标是建立一个基线,并为未来开发更高效、更有效的基于 SSM 的分割系统提供有价值的见解。我们的代码可在 https://github.com/JCruan519/VM-UNet 上获取。


贡献

  1. 提出了 VM-UNet,标志着首次探索纯粹基于 SSM 的模型在医学图像分割中的潜在应用。
  2. 在三个数据集上进行综合实验,结果表明VM-UNet表现出相当的竞争力。
  3. 我们为医学图像分割任务中基于纯SSM的模型建立了基线,提供了宝贵的见解,为开发更高效、更有效的基于SSM的分割方法铺平了道路。

方法

整体框架

整体依旧是传统UNet的结构,将原先的卷积模块换成了视觉状态空间模块(VSS block),将原先的上下采样换成了Patch Merging和Patch Expanding模块,因此跟着代码一个一个模块攻克即可!
在这里插入图片描述

1. Vision Mamba UNet (VM-UNet)

1. Patch Embedding层首先将图像划分为4×4的非重叠patch,然后将图像通道数转化为C
2. 编码器由四个阶段组成,在前三个阶段的末尾应用补丁合并操作,以减少输入特征的高度和宽度,同时增加通道数。
3. 类似地,解码器分为四个阶段。在最后三个阶段的开始,利用补丁扩展操作来减少特征通道的数量并增加高度和宽度。
4. 在解码器之后,采用最终投影层来恢复特征的大小以匹配分割目标。
5. 对于跳跃连接,采用简单的加法操作,没有花哨的东西,因此没有引入任何额外的参数。
疑问: 为什么在选择使用对称性的结构时,在最后一阶段解码器中只使用了1个VSS block呢?

2. VSS block

1. 源自 VMamaba 的 VSS 块是 VM-UNet 的核心模块,如图 1 (b) 所示。经过层归一化后,输入被分为两个分支。在第一个分支中,输入通过线​​性层,然后通过激活函数。在第二个分支中,输入通过线​​性层、深度可分离卷积和激活函数进行处理,然后输入 2D 选择性扫描 (SS2D) 模块以进行进一步的特征提取。随后,使用层归一化对特征进行归一化,然后使用第一个分支的输出执行按元素生成以合并两个路径。最后,使用线性层混合特征,并将此结果与残差连接相结合以形成 VSS 块的输出。本文默认采用SiLU[14]作为激活函数。


在这里插入图片描述
2. SS2D 由三个部分组成:扫描扩展操作(Scan Expanding)、S6 块和扫描合并操作(Scan Merging)这里不要和和Patch merging 和 Patch Expanding 弄混。如图2(a)所示,扫描扩展操作沿四个不同方向(左上到右下、右下到左上、右上到左下、左下到右上)展开输入图像)进入序列。然后,这些序列由 S6 块进行处理以进行特征提取,确保来自各个方向的信息得到彻底扫描,从而捕获不同的特征。随后,如图2(b)所示,扫描合并操作对来自四个方向的序列进行求和并合并,将输出图像恢复到与输入相同的大小。 S6 模块源自 Mamba [16],通过根据输入调整 SSM 的参数,在 S4 [17] 之上引入了一种选择机制。这使得模型能够区分并保留相关信息,同时过滤掉不相关的信息。 S6 块的伪代码如算法 1 所示。
3. SS2D中的核心S6
在这里插入图片描述
首先使用nn.linear()和x生成 Δ \Delta Δ,B,C;然后借助 Δ \Delta Δ,A,B生成 A ˉ \bar{A} Aˉ B ˉ \bar{B} Bˉ,最后借助公式计算出输出值y,具体解析见公式注释。

3. Loss function

在这里插入图片描述

1. VM-UNet的引入旨在验证纯基于SSM的模型在医学图像分割任务中的应用潜力。因此,我们专门利用最基本的二元交叉熵和骰子损失(BceDice 损失)以及交叉熵和骰子损失(CeDice 损失)分别作为二元和多类分割任务的损失函数,如方程 5 所示和 6.
2.
3.

实验

在本节中,我们在 VM-UNet 上针对皮肤病变和器官分割任务进行全面的实验。具体来说,我们评估了 VM-UNet 在 ISIC17、ISIC18 和 Synapse 数据集上的医学图像分割任务上的性能。 对于这两个数据集,我们提供了对多个指标的详细评估,包括并集平均交集 (mIoU)、Dice 相似系数 (DSC)、准确性 (Acc)、灵敏度 (Sen) 和特异性 (Spe)。

继之前的工作[28,9]之后,我们将 ISIC17 和 ISIC18 数据集中的图像大小调整为 256×256,将 Synapse 数据集中的图像大小调整为 224×224。为了防止过度拟合,采用了数据增强技术,包括随机翻转和随机旋转。 ISIC17和ISIC18数据集采用BceDice损失函数,Synapse数据集采用CeDice损失函数。

我们将批量大小设置为 32,并使用 AdamW [23] 优化器,初始学习率为 1e-3。 CosineAnnealingLR [22] 用作调度器,最多 50 次迭代,最小学习率为 1e-5。训练周期设置为 300。对于 VM-UNet,我们使用 VMamba-S [20] 的权重初始化编码器和解码器的权重,VMamba-S [20] 在 ImageNet-1k 上预训练。所有实验均在单个 NVIDIA RTX A6000 GPU 上进行。

1.对比实验

在这里插入图片描述
在这里插入图片描述
将 VM-UNet 与一些最先进的模型进行比较,实验结果如表 1 和表 2 所示。对于 ISIC17 和 ISIC18 数据集,我们的 VM-UNet 在 mIoU、DSC 和 Acc 方面优于其他模型指标。对于Synapse数据集,VM-UNet也取得了有竞争力的性能。例如,我们的模型在 DSC 和 HD95 指标上超过了 Swin-UNet(第一个纯基于 Transformer 的模型)1.95% 和 2.34mm。结果证明了基于SSM的模型在医学图像分割任务中的优越性。


找来了ISIC2018数据集上结果榜单如下:
在这里插入图片描述

VM-UNet 距离SOTA还有一点距离!!

2.消融实验

在这里插入图片描述

在本节中,我们使用 ISIC17 和 ISIC18 数据集对 VMUNet 的初始化进行消融实验。我们分别使用 VMamba-T 和 VMamba-S 的预训练权重初始化 VM-UNet1。实验结果如表 3 所示,表明更有效的预训练权重显着增强了 VM-UNet 的下游性能,表明 VM-UNet很大程度上受预训练权重的影响。

可以很明显看出,大一点的模型预训练权重就是厉害!

总体看下来:

  1. 作为mamba模型最重要的突破点,线性计算复杂度没有做实验证明,真的可惜!
  2. 作者选择在对称结构上使用不对称的VSS block数量,也是经过实验得到的(对称好肯定就用对称的了)。因此感觉将mamba模型放到分割任务中,只超过一些经典模型,对于SOTA模型还有一些距离,因此需要做好攻克的准备。
  3. 作者公开了代码,而且代码结构简单好用,太棒了!!

可借鉴参考

  1. 提供了一个baseline

  2. 提供了mamba模型的介绍流程

  3. 阅读U-Mamba 2024
    Ma, J., Li, F., Wang, B.: U-mamba: Enhancing long-range dependency for biomedical image segmentation. arXiv preprint arXiv:2401.04722 (2024)

  4. 阅读SegMamba 2024
    Xing, Z., Ye, T., Yang, Y., Liu, G., Zhu, L.: Segmamba: Long-range sequential modeling mamba for 3d medical image segmentation. arXiv preprint arXiv:2401.13560 (2024)

代码使用

当前mamba_ssm包只发布了Linux版本,因此在Win系统下无法运行

  1. Linux下安装mamba_ssm

pip install mamba_ssm

  1. 复制代码使用
import time
import math
from functools import partial
from typing import Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except:
    pass

# an alternative for mamba_ssm (in which causal_conv1d is needed)
# try:
#     from selective_scan import selective_scan_fn as selective_scan_fn_v1
#     from selective_scan import selective_scan_ref as selective_scan_ref_v1
# except:
#     pass

DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"


def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
    
    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] 
    """
    import numpy as np
    
    # fvcore.nn.jit_handles
    def get_flops_einsum(input_shapes, equation):
        np_arrs = [np.zeros(s) for s in input_shapes]
        optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
        for line in optim.split("\n"):
            if "optimized flop" in line.lower():
                # divided by 2 because we count MAC (multiply-add counted as one flop)
                flop = float(np.floor(float(line.split(":")[-1]) / 2))
                return flop
    

    assert not with_complex

    flops = 0 # below code flops = 0
    if False:
        ...
        """
        dtype_in = u.dtype
        u = u.float()
        delta = delta.float()
        if delta_bias is not None:
            delta = delta + delta_bias[..., None].float()
        if delta_softplus:
            delta = F.softplus(delta)
        batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
        is_variable_B = B.dim() >= 3
        is_variable_C = C.dim() >= 3
        if A.is_complex():
            if is_variable_B:
                B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
            if is_variable_C:
                C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
        else:
            B = B.float()
            C = C.float()
        x = A.new_zeros((batch, dim, dstate))
        ys = []
        """

    flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
    if with_Group:
        flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
    else:
        flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
    if False:
        ...
        """
        deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
        if not is_variable_B:
            deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
        else:
            if B.dim() == 3:
                deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
            else:
                B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
                deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
        if is_variable_C and C.dim() == 4:
            C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
        last_state = None
        """
    
    in_for_flops = B * D * N   
    if with_Group:
        in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
    else:
        in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
    flops += L * in_for_flops 
    if False:
        ...
        """
        for i in range(u.shape[2]):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            if not is_variable_C:
                y = torch.einsum('bdn,dn->bd', x, C)
            else:
                if C.dim() == 3:
                    y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
                else:
                    y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
            if i == u.shape[2] - 1:
                last_state = x
            if y.is_complex():
                y = y.real * 2
            ys.append(y)
        y = torch.stack(ys, dim=2) # (batch dim L)
        """

    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    if False:
        ...
        """
        out = y if D is None else y + u * rearrange(D, "d -> d 1")
        if z is not None:
            out = out * F.silu(z)
        out = out.to(dtype=dtype_in)
        """
    
    return flops


# 只使用了一次,将数据使用4*4的卷积和LN产生patch,然后转化为b*w*h*c
class PatchEmbed2D(nn.Module):
    r""" Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x


class PatchMerging2D(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        B, H, W, C = x.shape

        SHAPE_FIX = [-1, -1]
        if (W % 2 != 0) or (H % 2 != 0):
            print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)
            SHAPE_FIX[0] = H // 2
            SHAPE_FIX[1] = W // 2

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C

        if SHAPE_FIX[0] > 0:
            x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
        
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, H//2, W//2, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x
    

class PatchExpand2D(nn.Module):
    def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim*2
        self.dim_scale = dim_scale
        self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False)
        self.norm = norm_layer(self.dim // dim_scale)

    def forward(self, x):
        B, H, W, C = x.shape
        x = self.expand(x)

        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale)
        x= self.norm(x)

        return x
    

class Final_PatchExpand2D(nn.Module):
    def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim  # 96
        self.dim_scale = dim_scale  # 4
                          #        96             384
        self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False)
        #                          24  
        self.norm = norm_layer(self.dim // dim_scale)

    def forward(self, x):
        B, H, W, C = x.shape
        x = self.expand(x)

        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale)
        x= self.norm(x)

        return x


class SS2D(nn.Module):
    def __init__(
        self,
        d_model,   # 96
        d_state=16,
        d_conv=3,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
    ):
        super().__init__()
        self.d_model = d_model  # 96
        self.d_state = d_state  # 16
        self.d_conv = d_conv   # 3
        self.expand = expand   # 2
        self.d_inner = int(self.expand * self.d_model)  # 192
        self.dt_rank = math.ceil(self.d_model / 16)  # 6

        #                           96                 384
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,  # 192
            out_channels=self.d_inner,  # 192
            kernel_size=d_conv,  # 3
            padding=(d_conv - 1) // 2,    # 1
            bias=conv_bias,  
            groups=self.d_inner,  # 192  
        )
        self.act = nn.SiLU()

        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False), 
        )
        # 4*38*192的数据 初始化x的数据
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner)
        del self.x_proj

        # 初始化dt的数据吧
        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor),
        )
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner)
        del self.dt_projs
        # 初始化A和D
        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N)
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N)

        # ss2d
        self.forward_core = self.forward_corev0

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True
        
        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    def forward_corev0(self, x: torch.Tensor):
        self.selective_scan = selective_scan_fn
        
        B, C, H, W = x.shape
        L = H * W
        K = 4

        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)

        xs = xs.float().view(B, -1, L) # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1) # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)

        out_y = self.selective_scan(
            xs, dts, 
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward(self, x: torch.Tensor):
        B, H, W, C = x.shape

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1) # (b, h, w, d)  # x走的是ss2d的路径

        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x)) # (b, d, h, w)  
        y1, y2, y3, y4 = self.forward_core(x)
        assert y1.dtype == torch.float32
        y = y1 + y2 + y3 + y4
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)
        y = y * F.silu(z)   # 这里的z忘记了一个Linear吧
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out


class VSSBlock(nn.Module):
    def __init__(
        self,
        hidden_dim: int = 0,  # 96
        drop_path: float = 0,  # 0.2
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),  # nn.LN
        attn_drop_rate: float = 0,  # 0
        d_state: int = 16,
    ):
        super().__init__()
        self.ln_1 = norm_layer(hidden_dim)# 96             0.2                   16
        self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state)
        self.drop_path = DropPath(drop_path)

    def forward(self, input: torch.Tensor):
        # print(input.shape, "传入模块的大小")
        x = input + self.drop_path(self.self_attention(self.ln_1(input)))
        return x


class VSSLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(  # 以第一个为例
        self, 
        dim,  # # 96
        depth,  # 2
        d_state=16,
        drop = 0.,
        attn_drop=0.,
        drop_path=0.,   # 每一个模块都有一个drop
        norm_layer=nn.LayerNorm, 
        downsample=None,  # PatchMergin2D
        use_checkpoint=False,  
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,   # 96
                drop_path=drop_path[i],  # 0.2
                norm_layer=norm_layer,  # nn.LN
                attn_drop_rate=attn_drop, # 0 
                d_state=d_state,  # 16
            )
            for i in range(depth)])
        
        if True: # is this really applied? Yes, but been overriden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_() # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
                x = blk(x)
        
        if self.downsample is not None:
            x = self.downsample(x)

        return x
    


class VSSLayer_up(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self, 
        dim, 
        depth, 
        attn_drop=0.,
        drop_path=0., 
        norm_layer=nn.LayerNorm, 
        upsample=None, 
        use_checkpoint=False, 
        d_state=16,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
            )
            for i in range(depth)])
        
        if True: # is this really applied? Yes, but been overriden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_() # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            self.apply(_init_weights)

        if upsample is not None:
            self.upsample = upsample(dim=dim, norm_layer=norm_layer)
        else:
            self.upsample = None


    def forward(self, x):
        if self.upsample is not None:
            x = self.upsample(x)
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        return x
    


class VSSM(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2],
                 dims=[96, 192, 384, 768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False):
        super().__init__()
        self.num_classes = num_classes # 1
        self.num_layers = len(depths)  # 4
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.embed_dim = dims[0]   # 96
        self.num_features = dims[-1]  # 768 
        self.dims = dims   # [96, 192, 384, 768]

        # 4*4+LN-> b*w*h*c
        self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
            norm_layer=norm_layer if patch_norm else None)
        
        
        self.pos_drop = nn.Dropout(p=drop_rate)

        # 生成对应的sum(depths)随机深度衰减数值 dpr是正序,dpr_decoder是倒序(用到了[start:end:-1] 反向步长)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1]

        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):  # 以第一个为例 num_layers = 4
            layer = VSSLayer(
                dim=dims[i_layer],  # 96
                depth=depths[i_layer],  # 2
                d_state=d_state,  # 16
                drop=drop_rate,   # 0
                attn_drop=attn_drop_rate, # 0
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # ,每一个模块传一个概率值
                norm_layer=norm_layer, # nn.LN
                downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            
            self.layers.append(layer)

        self.layers_up = nn.ModuleList()
        for i_layer in range(self.num_layers):  # 以第一个为例,num_layers=2
            layer = VSSLayer_up(
                dim=dims_decoder[i_layer],  # 768
                depth=depths_decoder[i_layer], # 2
                d_state=d_state,  # 16
                drop=drop_rate,  # 0
                attn_drop=attn_drop_rate, # 0
                drop_path=dpr_decoder[sum(depths_decoder[:i_layer]):sum(depths_decoder[:i_layer + 1])],
                norm_layer=norm_layer, # nn.LN
                upsample=PatchExpand2D if (i_layer != 0) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers_up.append(layer)

        #  输入 64*64*96 ->linear+LN b*256*256*24                          96                             nn.LN
        self.final_up = Final_PatchExpand2D(dim=dims_decoder[-1], dim_scale=4, norm_layer=norm_layer)
        #     维度变换 输出b*1*256*256         24                 1 
        self.final_conv = nn.Conv2d(dims_decoder[-1]//4, num_classes, 1)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """
        out_proj.weight which is previously initilized in VSSBlock, would be cleared in nn.Linear
        no fc.weight found in the any of the model parameters
        no nn.Embedding found in the any of the model parameters
        so the thing is, VSSBlock initialization is useless
        
        Conv2D is not intialized !!!
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
   
    def forward_features(self, x):
        skip_list = []
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            skip_list.append(x)
            x = layer(x)
        return x, skip_list
    
    def forward_features_up(self, x, skip_list):
        for inx, layer_up in enumerate(self.layers_up):
            if inx == 0:
                x = layer_up(x)
            else:
                x = layer_up(x+skip_list[-inx])

        return x
    
    def forward_final(self, x):
        # input 3*64*64*96   out=3 256 256 24
        x = self.final_up(x)   
        x = x.permute(0,3,1,2)
        # out=3 24 256 256 
        x = self.final_conv(x)
        return x

    def forward(self, x):
        x, skip_list = self.forward_features(x)
        x = self.forward_features_up(x, skip_list)
        x = self.forward_final(x)  
        return x


import torch
from torch import nn


class VMUNet(nn.Module):
    def __init__(self, 
                 input_channels=3, 
                 num_classes=1,
                 depths=[2, 2, 9, 2], 
                 depths_decoder=[2, 9, 2, 2],
                 drop_path_rate=0.2,
                 load_ckpt_path=None,
                ):
        super().__init__()

        self.load_ckpt_path = load_ckpt_path
        self.num_classes = num_classes

        self.vmunet = VSSM(in_chans=input_channels, # 3
                           num_classes=num_classes,  # 1
                           depths=depths,  # [2,2,9,2]
                           depths_decoder=depths_decoder, # [2,9,2,2]
                           drop_path_rate=drop_path_rate, # 0.2
                        )
    
    def forward(self, x):
        if x.size()[1] == 1:
            x = x.repeat(1,3,1,1)
        logits = self.vmunet(x)
        if self.num_classes == 1: return torch.sigmoid(logits)
        else: return logits
    
    # 加载预训练权重,暂时用不到
    def load_from(self):
        if self.load_ckpt_path is not None:
            model_dict = self.vmunet.state_dict()
            modelCheckpoint = torch.load(self.load_ckpt_path)
            pretrained_dict = modelCheckpoint['model']
            # 过滤操作
            new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
            model_dict.update(new_dict)
            # 打印出来,更新了多少的参数
            print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict)))
            self.vmunet.load_state_dict(model_dict)

            not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()]
            print('Not loaded keys:', not_loaded_keys)
            print("encoder loaded finished!")

            model_dict = self.vmunet.state_dict()
            modelCheckpoint = torch.load(self.load_ckpt_path)
            pretrained_odict = modelCheckpoint['model']
            pretrained_dict = {}
            for k, v in pretrained_odict.items():
                if 'layers.0' in k: 
                    new_k = k.replace('layers.0', 'layers_up.3')
                    pretrained_dict[new_k] = v
                elif 'layers.1' in k: 
                    new_k = k.replace('layers.1', 'layers_up.2')
                    pretrained_dict[new_k] = v
                elif 'layers.2' in k: 
                    new_k = k.replace('layers.2', 'layers_up.1')
                    pretrained_dict[new_k] = v
                elif 'layers.3' in k: 
                    new_k = k.replace('layers.3', 'layers_up.0')
                    pretrained_dict[new_k] = v
            # 过滤操作
            new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
            model_dict.update(new_dict)
            # 打印出来,更新了多少的参数
            print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict)))
            self.vmunet.load_state_dict(model_dict)
            
            # 找到没有加载的键(keys)
            not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()]
            print('Not loaded keys:', not_loaded_keys)
            print("decoder loaded finished!")

x = torch.randn(3, 3, 256, 256).to("cuda:0")
net = VMUNet(3,1).to("cuda:0")
print(net(x).shape)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389
  • 390
  • 391
  • 392
  • 393
  • 394
  • 395
  • 396
  • 397
  • 398
  • 399
  • 400
  • 401
  • 402
  • 403
  • 404
  • 405
  • 406
  • 407
  • 408
  • 409
  • 410
  • 411
  • 412
  • 413
  • 414
  • 415
  • 416
  • 417
  • 418
  • 419
  • 420
  • 421
  • 422
  • 423
  • 424
  • 425
  • 426
  • 427
  • 428
  • 429
  • 430
  • 431
  • 432
  • 433
  • 434
  • 435
  • 436
  • 437
  • 438
  • 439
  • 440
  • 441
  • 442
  • 443
  • 444
  • 445
  • 446
  • 447
  • 448
  • 449
  • 450
  • 451
  • 452
  • 453
  • 454
  • 455
  • 456
  • 457
  • 458
  • 459
  • 460
  • 461
  • 462
  • 463
  • 464
  • 465
  • 466
  • 467
  • 468
  • 469
  • 470
  • 471
  • 472
  • 473
  • 474
  • 475
  • 476
  • 477
  • 478
  • 479
  • 480
  • 481
  • 482
  • 483
  • 484
  • 485
  • 486
  • 487
  • 488
  • 489
  • 490
  • 491
  • 492
  • 493
  • 494
  • 495
  • 496
  • 497
  • 498
  • 499
  • 500
  • 501
  • 502
  • 503
  • 504
  • 505
  • 506
  • 507
  • 508
  • 509
  • 510
  • 511
  • 512
  • 513
  • 514
  • 515
  • 516
  • 517
  • 518
  • 519
  • 520
  • 521
  • 522
  • 523
  • 524
  • 525
  • 526
  • 527
  • 528
  • 529
  • 530
  • 531
  • 532
  • 533
  • 534
  • 535
  • 536
  • 537
  • 538
  • 539
  • 540
  • 541
  • 542
  • 543
  • 544
  • 545
  • 546
  • 547
  • 548
  • 549
  • 550
  • 551
  • 552
  • 553
  • 554
  • 555
  • 556
  • 557
  • 558
  • 559
  • 560
  • 561
  • 562
  • 563
  • 564
  • 565
  • 566
  • 567
  • 568
  • 569
  • 570
  • 571
  • 572
  • 573
  • 574
  • 575
  • 576
  • 577
  • 578
  • 579
  • 580
  • 581
  • 582
  • 583
  • 584
  • 585
  • 586
  • 587
  • 588
  • 589
  • 590
  • 591
  • 592
  • 593
  • 594
  • 595
  • 596
  • 597
  • 598
  • 599
  • 600
  • 601
  • 602
  • 603
  • 604
  • 605
  • 606
  • 607
  • 608
  • 609
  • 610
  • 611
  • 612
  • 613
  • 614
  • 615
  • 616
  • 617
  • 618
  • 619
  • 620
  • 621
  • 622
  • 623
  • 624
  • 625
  • 626
  • 627
  • 628
  • 629
  • 630
  • 631
  • 632
  • 633
  • 634
  • 635
  • 636
  • 637
  • 638
  • 639
  • 640
  • 641
  • 642
  • 643
  • 644
  • 645
  • 646
  • 647
  • 648
  • 649
  • 650
  • 651
  • 652
  • 653
  • 654
  • 655
  • 656
  • 657
  • 658
  • 659
  • 660
  • 661
  • 662
  • 663
  • 664
  • 665
  • 666
  • 667
  • 668
  • 669
  • 670
  • 671
  • 672
  • 673
  • 674
  • 675
  • 676
  • 677
  • 678
  • 679
  • 680
  • 681
  • 682
  • 683
  • 684
  • 685
  • 686
  • 687
  • 688
  • 689
  • 690
  • 691
  • 692
  • 693
  • 694
  • 695
  • 696
  • 697
  • 698
  • 699
  • 700
  • 701
  • 702
  • 703
  • 704
  • 705
  • 706
  • 707
  • 708
  • 709
  • 710
  • 711
  • 712
  • 713
  • 714
  • 715
  • 716
  • 717
  • 718
  • 719
  • 720
  • 721
  • 722
  • 723
  • 724
  • 725
  • 726
  • 727
  • 728
  • 729
  • 730
  • 731
  • 732
  • 733
  • 734
  • 735
  • 736
  • 737
  • 738
  • 739
  • 740
  • 741
  • 742
  • 743
  • 744
  • 745
  • 746
  • 747
  • 748
  • 749
  • 750
  • 751
  • 752
  • 753
  • 754
  • 755
  • 756
  • 757
  • 758
  • 759
  • 760
  • 761
  • 762
  • 763
  • 764
  • 765
  • 766
  • 767
  • 768
  • 769
  • 770
  • 771
  • 772
  • 773
  • 774
  • 775

直接复制使用,相较于原始代码,因为没安装成功selective_scan包,扫描代码发现其中只用到了一处,将其改成了mamba_scan中的函数了,同时将代码中没用到的部件删除,实测效果不错,在自己最近做的任务中,分数可以排进前五!!
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/493000
推荐阅读
相关标签
  

闽ICP备14008679号