赞
踩
2024.2.27更新代码介绍
期刊名:
Arxiv
论文主页:
VM-UNet
代码:
Code
在医学图像分割领域,基于 CNN 和基于 Transformer 的模型都得到了广泛的探索。然而,CNN 在远程建模能力方面表现出局限性,而 Transformer 则受到其二次计算复杂性的限制。 最近,
以 Mamba 为代表的状态空间模型 (SSM) 已成为一种有前途的方法。它们不仅擅长对远程交互进行建模,而且还保持线性计算复杂性。
在本文中,利用状态空间模型,我们提出了一种用于医学图像分割的 Ushape 架构模型,名为 Vision Mamba UNet (VM-UNet)。具体来说,引入视觉状态空间(VSS)块作为基础块来捕获广泛的上下文信息,并构建不对称的编码器-解码器结构。我们在 ISIC17、ISIC18 和 Synapse 数据集上进行了全面的实验,结果表明 VM-UNet 在医学图像分割任务中表现出竞争力。据我们所知,这是第一个基于纯SSM模型构建的医学图像分割模型。我们的目标是建立一个基线,并为未来开发更高效、更有效的基于 SSM 的分割系统提供有价值的见解。我们的代码可在 https://github.com/JCruan519/VM-UNet 上获取。
整体依旧是传统UNet的结构,将原先的卷积模块换成了视觉状态空间模块(VSS block),将原先的上下采样换成了Patch Merging和Patch Expanding模块,因此跟着代码一个一个模块攻克即可!
1.
Patch Embedding层首先将图像划分为4×4的非重叠patch,然后将图像通道数转化为C
2.
编码器由四个阶段组成,在前三个阶段的末尾应用补丁合并操作,以减少输入特征的高度和宽度,同时增加通道数。
3.
类似地,解码器分为四个阶段。在最后三个阶段的开始,利用补丁扩展操作来减少特征通道的数量并增加高度和宽度。
4.
在解码器之后,采用最终投影层来恢复特征的大小以匹配分割目标。
5.
对于跳跃连接,采用简单的加法操作,没有花哨的东西,因此没有引入任何额外的参数。
疑问: 为什么在选择使用对称性的结构时,在最后一阶段解码器中只使用了1个VSS block呢?
1.
源自 VMamaba 的 VSS 块是 VM-UNet 的核心模块,如图 1 (b) 所示。经过层归一化后,输入被分为两个分支。在第一个分支中,输入通过线性层,然后通过激活函数。在第二个分支中,输入通过线性层、深度可分离卷积和激活函数进行处理,然后输入 2D 选择性扫描 (SS2D) 模块以进行进一步的特征提取。随后,使用层归一化对特征进行归一化,然后使用第一个分支的输出执行按元素生成以合并两个路径。最后,使用线性层混合特征,并将此结果与残差连接相结合以形成 VSS 块的输出。本文默认采用SiLU[14]作为激活函数。
2.
SS2D 由三个部分组成:扫描扩展操作(Scan Expanding)、S6 块和扫描合并操作(Scan Merging)这里不要和和Patch merging 和 Patch Expanding 弄混
。如图2(a)所示,扫描扩展操作沿四个不同方向(左上到右下、右下到左上、右上到左下、左下到右上)展开输入图像)进入序列。然后,这些序列由 S6 块进行处理以进行特征提取,确保来自各个方向的信息得到彻底扫描,从而捕获不同的特征。随后,如图2(b)所示,扫描合并操作对来自四个方向的序列进行求和并合并,将输出图像恢复到与输入相同的大小。 S6 模块源自 Mamba [16],通过根据输入调整 SSM 的参数,在 S4 [17] 之上引入了一种选择机制。这使得模型能够区分并保留相关信息,同时过滤掉不相关的信息。 S6 块的伪代码如算法 1 所示。
3. SS2D中的核心S6
首先使用nn.linear()和x生成
Δ
\Delta
Δ,B,C;然后借助
Δ
\Delta
Δ,A,B生成
A
ˉ
\bar{A}
Aˉ,
B
ˉ
\bar{B}
Bˉ,最后借助公式计算出输出值y,具体解析见公式注释。
1.
VM-UNet的引入旨在验证纯基于SSM的模型在医学图像分割任务中的应用潜力。因此,我们专门利用最基本的二元交叉熵和骰子损失(BceDice 损失)以及交叉熵和骰子损失(CeDice 损失)分别作为二元和多类分割任务的损失函数,如方程 5 所示和 6.
2.
3.
在本节中,我们在 VM-UNet 上针对皮肤病变和器官分割任务进行全面的实验。具体来说,我们评估了 VM-UNet 在 ISIC17、ISIC18 和 Synapse 数据集上的医学图像分割任务上的性能。 对于这两个数据集,我们提供了对多个指标的详细评估,包括并集平均交集 (mIoU)、Dice 相似系数 (DSC)、准确性 (Acc)、灵敏度 (Sen) 和特异性 (Spe)。
继之前的工作[28,9]之后,我们将 ISIC17 和 ISIC18 数据集中的图像大小调整为 256×256,将 Synapse 数据集中的图像大小调整为 224×224。为了防止过度拟合,采用了数据增强技术,包括随机翻转和随机旋转。 ISIC17和ISIC18数据集采用BceDice损失函数,Synapse数据集采用CeDice损失函数。
我们将批量大小设置为 32,并使用 AdamW [23] 优化器,初始学习率为 1e-3。 CosineAnnealingLR [22] 用作调度器,最多 50 次迭代,最小学习率为 1e-5。训练周期设置为 300。对于 VM-UNet,我们使用 VMamba-S [20] 的权重初始化编码器和解码器的权重,VMamba-S [20] 在 ImageNet-1k 上预训练。所有实验均在单个 NVIDIA RTX A6000 GPU 上进行。
将 VM-UNet 与一些最先进的模型进行比较,实验结果如表 1 和表 2 所示。对于 ISIC17 和 ISIC18 数据集,我们的 VM-UNet 在 mIoU、DSC 和 Acc 方面优于其他模型指标。对于Synapse数据集,VM-UNet也取得了有竞争力的性能。例如,我们的模型在 DSC 和 HD95 指标上超过了 Swin-UNet(第一个纯基于 Transformer 的模型)1.95% 和 2.34mm。结果证明了基于SSM的模型在医学图像分割任务中的优越性。
找来了ISIC2018数据集上结果榜单如下:
VM-UNet 距离SOTA还有一点距离!!
在本节中,我们使用 ISIC17 和 ISIC18 数据集对 VMUNet 的初始化进行消融实验。我们分别使用 VMamba-T 和 VMamba-S 的预训练权重初始化 VM-UNet1。实验结果如表 3 所示,表明更有效的预训练权重显着增强了 VM-UNet 的下游性能,表明 VM-UNet很大程度上受预训练权重的影响。
可以很明显看出,大一点的模型预训练权重就是厉害!
总体看下来:
- 作为mamba模型最重要的突破点,线性计算复杂度没有做实验证明,真的可惜!
- 作者选择在对称结构上使用不对称的VSS block数量,也是经过实验得到的(对称好肯定就用对称的了)。因此感觉将mamba模型放到分割任务中,只超过一些经典模型,对于SOTA模型还有一些距离,因此需要做好攻克的准备。
- 作者公开了代码,而且代码结构简单好用,太棒了!!
提供了一个baseline
提供了mamba模型的介绍流程
阅读U-Mamba 2024
Ma, J., Li, F., Wang, B.: U-mamba: Enhancing long-range dependency for biomedical image segmentation. arXiv preprint arXiv:2401.04722 (2024)
阅读SegMamba 2024
Xing, Z., Ye, T., Yang, Y., Liu, G., Zhu, L.: Segmamba: Long-range sequential modeling mamba for 3d medical image segmentation. arXiv preprint arXiv:2401.13560 (2024)
当前mamba_ssm包只发布了Linux版本,因此在Win系统下无法运行
pip install mamba_ssm
import time import math from functools import partial from typing import Optional, Callable import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from einops import rearrange, repeat from timm.models.layers import DropPath, to_2tuple, trunc_normal_ try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref except: pass # an alternative for mamba_ssm (in which causal_conv1d is needed) # try: # from selective_scan import selective_scan_fn as selective_scan_fn_v1 # from selective_scan import selective_scan_ref as selective_scan_ref_v1 # except: # pass DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): """ u: r(B D L) delta: r(B D L) A: r(D N) B: r(B N L) C: r(B N L) D: r(D) z: r(B D L) delta_bias: r(D), fp32 ignores: [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] """ import numpy as np # fvcore.nn.jit_handles def get_flops_einsum(input_shapes, equation): np_arrs = [np.zeros(s) for s in input_shapes] optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] for line in optim.split("\n"): if "optimized flop" in line.lower(): # divided by 2 because we count MAC (multiply-add counted as one flop) flop = float(np.floor(float(line.split(":")[-1]) / 2)) return flop assert not with_complex flops = 0 # below code flops = 0 if False: ... """ dtype_in = u.dtype u = u.float() delta = delta.float() if delta_bias is not None: delta = delta + delta_bias[..., None].float() if delta_softplus: delta = F.softplus(delta) batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] is_variable_B = B.dim() >= 3 is_variable_C = C.dim() >= 3 if A.is_complex(): if is_variable_B: B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) if is_variable_C: C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) else: B = B.float() C = C.float() x = A.new_zeros((batch, dim, dstate)) ys = [] """ flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") if with_Group: flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") else: flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") if False: ... """ deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) if not is_variable_B: deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) else: if B.dim() == 3: deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) else: B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) if is_variable_C and C.dim() == 4: C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) last_state = None """ in_for_flops = B * D * N if with_Group: in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") else: in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") flops += L * in_for_flops if False: ... """ for i in range(u.shape[2]): x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: y = torch.einsum('bdn,dn->bd', x, C) else: if C.dim() == 3: y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) else: y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) if i == u.shape[2] - 1: last_state = x if y.is_complex(): y = y.real * 2 ys.append(y) y = torch.stack(ys, dim=2) # (batch dim L) """ if with_D: flops += B * D * L if with_Z: flops += B * D * L if False: ... """ out = y if D is None else y + u * rearrange(D, "d -> d 1") if z is not None: out = out * F.silu(z) out = out.to(dtype=dtype_in) """ return flops # 只使用了一次,将数据使用4*4的卷积和LN产生patch,然后转化为b*w*h*c class PatchEmbed2D(nn.Module): r""" Image to Patch Embedding Args: patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None """ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs): super().__init__() if isinstance(patch_size, int): patch_size = (patch_size, patch_size) self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): x = self.proj(x).permute(0, 2, 3, 1) if self.norm is not None: x = self.norm(x) return x class PatchMerging2D(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x): B, H, W, C = x.shape SHAPE_FIX = [-1, -1] if (W % 2 != 0) or (H % 2 != 0): print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True) SHAPE_FIX[0] = H // 2 SHAPE_FIX[1] = W // 2 x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C if SHAPE_FIX[0] > 0: x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, H//2, W//2, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) return x class PatchExpand2D(nn.Module): def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim*2 self.dim_scale = dim_scale self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) self.norm = norm_layer(self.dim // dim_scale) def forward(self, x): B, H, W, C = x.shape x = self.expand(x) x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) x= self.norm(x) return x class Final_PatchExpand2D(nn.Module): def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim # 96 self.dim_scale = dim_scale # 4 # 96 384 self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) # 24 self.norm = norm_layer(self.dim // dim_scale) def forward(self, x): B, H, W, C = x.shape x = self.expand(x) x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) x= self.norm(x) return x class SS2D(nn.Module): def __init__( self, d_model, # 96 d_state=16, d_conv=3, expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, dropout=0., conv_bias=True, bias=False, ): super().__init__() self.d_model = d_model # 96 self.d_state = d_state # 16 self.d_conv = d_conv # 3 self.expand = expand # 2 self.d_inner = int(self.expand * self.d_model) # 192 self.dt_rank = math.ceil(self.d_model / 16) # 6 # 96 384 self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias) self.conv2d = nn.Conv2d( in_channels=self.d_inner, # 192 out_channels=self.d_inner, # 192 kernel_size=d_conv, # 3 padding=(d_conv - 1) // 2, # 1 bias=conv_bias, groups=self.d_inner, # 192 ) self.act = nn.SiLU() self.x_proj = ( nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False), nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False), nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False), nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False), ) # 4*38*192的数据 初始化x的数据 self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner) del self.x_proj # 初始化dt的数据吧 self.dt_projs = ( self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor), self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor), self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor), self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor), ) self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank) self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner) del self.dt_projs # 初始化A和D self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N) self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N) # ss2d self.forward_core = self.forward_corev0 self.out_norm = nn.LayerNorm(self.d_inner) self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias) self.dropout = nn.Dropout(dropout) if dropout > 0. else None @staticmethod def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization dt_init_std = dt_rank**-0.5 * dt_scale if dt_init == "constant": nn.init.constant_(dt_proj.weight, dt_init_std) elif dt_init == "random": nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max dt = torch.exp( torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): dt_proj.bias.copy_(inv_dt) # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit dt_proj.bias._no_reinit = True return dt_proj @staticmethod def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): # S4D real initialization A = repeat( torch.arange(1, d_state + 1, dtype=torch.float32, device=device), "n -> d n", d=d_inner, ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 if copies > 1: A_log = repeat(A_log, "d n -> r d n", r=copies) if merge: A_log = A_log.flatten(0, 1) A_log = nn.Parameter(A_log) A_log._no_weight_decay = True return A_log @staticmethod def D_init(d_inner, copies=1, device=None, merge=True): # D "skip" parameter D = torch.ones(d_inner, device=device) if copies > 1: D = repeat(D, "n1 -> r n1", r=copies) if merge: D = D.flatten(0, 1) D = nn.Parameter(D) # Keep in fp32 D._no_weight_decay = True return D def forward_corev0(self, x: torch.Tensor): self.selective_scan = selective_scan_fn B, C, H, W = x.shape L = H * W K = 4 x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) xs = xs.float().view(B, -1, L) # (b, k * d, l) dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) Ds = self.Ds.float().view(-1) # (k * d) As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) out_y = self.selective_scan( xs, dts, As, Bs, Cs, Ds, z=None, delta_bias=dt_projs_bias, delta_softplus=True, return_last_state=False, ).view(B, K, -1, L) assert out_y.dtype == torch.float inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y def forward(self, x: torch.Tensor): B, H, W, C = x.shape xz = self.in_proj(x) x, z = xz.chunk(2, dim=-1) # (b, h, w, d) # x走的是ss2d的路径 x = x.permute(0, 3, 1, 2).contiguous() x = self.act(self.conv2d(x)) # (b, d, h, w) y1, y2, y3, y4 = self.forward_core(x) assert y1.dtype == torch.float32 y = y1 + y2 + y3 + y4 y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) y = self.out_norm(y) y = y * F.silu(z) # 这里的z忘记了一个Linear吧 out = self.out_proj(y) if self.dropout is not None: out = self.dropout(out) return out class VSSBlock(nn.Module): def __init__( self, hidden_dim: int = 0, # 96 drop_path: float = 0, # 0.2 norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), # nn.LN attn_drop_rate: float = 0, # 0 d_state: int = 16, ): super().__init__() self.ln_1 = norm_layer(hidden_dim)# 96 0.2 16 self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state) self.drop_path = DropPath(drop_path) def forward(self, input: torch.Tensor): # print(input.shape, "传入模块的大小") x = input + self.drop_path(self.self_attention(self.ln_1(input))) return x class VSSLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. depth (int): Number of blocks. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__( # 以第一个为例 self, dim, # # 96 depth, # 2 d_state=16, drop = 0., attn_drop=0., drop_path=0., # 每一个模块都有一个drop norm_layer=nn.LayerNorm, downsample=None, # PatchMergin2D use_checkpoint=False, ): super().__init__() self.dim = dim self.use_checkpoint = use_checkpoint self.blocks = nn.ModuleList([ VSSBlock( hidden_dim=dim, # 96 drop_path=drop_path[i], # 0.2 norm_layer=norm_layer, # nn.LN attn_drop_rate=attn_drop, # 0 d_state=d_state, # 16 ) for i in range(depth)]) if True: # is this really applied? Yes, but been overriden later in VSSM! def _init_weights(module: nn.Module): for name, p in module.named_parameters(): if name in ["out_proj.weight"]: p = p.clone().detach_() # fake init, just to keep the seed .... nn.init.kaiming_uniform_(p, a=math.sqrt(5)) self.apply(_init_weights) if downsample is not None: self.downsample = downsample(dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x): for blk in self.blocks: x = blk(x) if self.downsample is not None: x = self.downsample(x) return x class VSSLayer_up(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. depth (int): Number of blocks. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__( self, dim, depth, attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False, d_state=16, **kwargs, ): super().__init__() self.dim = dim self.use_checkpoint = use_checkpoint self.blocks = nn.ModuleList([ VSSBlock( hidden_dim=dim, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, attn_drop_rate=attn_drop, d_state=d_state, ) for i in range(depth)]) if True: # is this really applied? Yes, but been overriden later in VSSM! def _init_weights(module: nn.Module): for name, p in module.named_parameters(): if name in ["out_proj.weight"]: p = p.clone().detach_() # fake init, just to keep the seed .... nn.init.kaiming_uniform_(p, a=math.sqrt(5)) self.apply(_init_weights) if upsample is not None: self.upsample = upsample(dim=dim, norm_layer=norm_layer) else: self.upsample = None def forward(self, x): if self.upsample is not None: x = self.upsample(x) for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) return x class VSSM(nn.Module): def __init__(self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2], dims=[96, 192, 384, 768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, patch_norm=True, use_checkpoint=False): super().__init__() self.num_classes = num_classes # 1 self.num_layers = len(depths) # 4 if isinstance(dims, int): dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] self.embed_dim = dims[0] # 96 self.num_features = dims[-1] # 768 self.dims = dims # [96, 192, 384, 768] # 4*4+LN-> b*w*h*c self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim, norm_layer=norm_layer if patch_norm else None) self.pos_drop = nn.Dropout(p=drop_rate) # 生成对应的sum(depths)随机深度衰减数值 dpr是正序,dpr_decoder是倒序(用到了[start:end:-1] 反向步长) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1] self.layers = nn.ModuleList() for i_layer in range(self.num_layers): # 以第一个为例 num_layers = 4 layer = VSSLayer( dim=dims[i_layer], # 96 depth=depths[i_layer], # 2 d_state=d_state, # 16 drop=drop_rate, # 0 attn_drop=attn_drop_rate, # 0 drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # ,每一个模块传一个概率值 norm_layer=norm_layer, # nn.LN downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, ) self.layers.append(layer) self.layers_up = nn.ModuleList() for i_layer in range(self.num_layers): # 以第一个为例,num_layers=2 layer = VSSLayer_up( dim=dims_decoder[i_layer], # 768 depth=depths_decoder[i_layer], # 2 d_state=d_state, # 16 drop=drop_rate, # 0 attn_drop=attn_drop_rate, # 0 drop_path=dpr_decoder[sum(depths_decoder[:i_layer]):sum(depths_decoder[:i_layer + 1])], norm_layer=norm_layer, # nn.LN upsample=PatchExpand2D if (i_layer != 0) else None, use_checkpoint=use_checkpoint, ) self.layers_up.append(layer) # 输入 64*64*96 ->linear+LN b*256*256*24 96 nn.LN self.final_up = Final_PatchExpand2D(dim=dims_decoder[-1], dim_scale=4, norm_layer=norm_layer) # 维度变换 输出b*1*256*256 24 1 self.final_conv = nn.Conv2d(dims_decoder[-1]//4, num_classes, 1) self.apply(self._init_weights) def _init_weights(self, m: nn.Module): """ out_proj.weight which is previously initilized in VSSBlock, would be cleared in nn.Linear no fc.weight found in the any of the model parameters no nn.Embedding found in the any of the model parameters so the thing is, VSSBlock initialization is useless Conv2D is not intialized !!! """ if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def forward_features(self, x): skip_list = [] x = self.patch_embed(x) x = self.pos_drop(x) for layer in self.layers: skip_list.append(x) x = layer(x) return x, skip_list def forward_features_up(self, x, skip_list): for inx, layer_up in enumerate(self.layers_up): if inx == 0: x = layer_up(x) else: x = layer_up(x+skip_list[-inx]) return x def forward_final(self, x): # input 3*64*64*96 out=3 256 256 24 x = self.final_up(x) x = x.permute(0,3,1,2) # out=3 24 256 256 x = self.final_conv(x) return x def forward(self, x): x, skip_list = self.forward_features(x) x = self.forward_features_up(x, skip_list) x = self.forward_final(x) return x import torch from torch import nn class VMUNet(nn.Module): def __init__(self, input_channels=3, num_classes=1, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2], drop_path_rate=0.2, load_ckpt_path=None, ): super().__init__() self.load_ckpt_path = load_ckpt_path self.num_classes = num_classes self.vmunet = VSSM(in_chans=input_channels, # 3 num_classes=num_classes, # 1 depths=depths, # [2,2,9,2] depths_decoder=depths_decoder, # [2,9,2,2] drop_path_rate=drop_path_rate, # 0.2 ) def forward(self, x): if x.size()[1] == 1: x = x.repeat(1,3,1,1) logits = self.vmunet(x) if self.num_classes == 1: return torch.sigmoid(logits) else: return logits # 加载预训练权重,暂时用不到 def load_from(self): if self.load_ckpt_path is not None: model_dict = self.vmunet.state_dict() modelCheckpoint = torch.load(self.load_ckpt_path) pretrained_dict = modelCheckpoint['model'] # 过滤操作 new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} model_dict.update(new_dict) # 打印出来,更新了多少的参数 print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict))) self.vmunet.load_state_dict(model_dict) not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()] print('Not loaded keys:', not_loaded_keys) print("encoder loaded finished!") model_dict = self.vmunet.state_dict() modelCheckpoint = torch.load(self.load_ckpt_path) pretrained_odict = modelCheckpoint['model'] pretrained_dict = {} for k, v in pretrained_odict.items(): if 'layers.0' in k: new_k = k.replace('layers.0', 'layers_up.3') pretrained_dict[new_k] = v elif 'layers.1' in k: new_k = k.replace('layers.1', 'layers_up.2') pretrained_dict[new_k] = v elif 'layers.2' in k: new_k = k.replace('layers.2', 'layers_up.1') pretrained_dict[new_k] = v elif 'layers.3' in k: new_k = k.replace('layers.3', 'layers_up.0') pretrained_dict[new_k] = v # 过滤操作 new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} model_dict.update(new_dict) # 打印出来,更新了多少的参数 print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict))) self.vmunet.load_state_dict(model_dict) # 找到没有加载的键(keys) not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()] print('Not loaded keys:', not_loaded_keys) print("decoder loaded finished!") x = torch.randn(3, 3, 256, 256).to("cuda:0") net = VMUNet(3,1).to("cuda:0") print(net(x).shape)
直接复制使用,相较于原始代码,因为没安装成功selective_scan包,扫描代码发现其中只用到了一处,将其改成了mamba_scan中的函数了,同时将代码中没用到的部件删除,实测效果不错,在自己最近做的任务中,分数可以排进前五!!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。