当前位置:   article > 正文

【论文阅读】RS-Mamba for Large Remote Sensing Image Dense Prediction(附Code)

【论文阅读】RS-Mamba for Large Remote Sensing Image Dense Prediction(附Code)

论文作者提出了RS-Mamba(RSM)用于高分辨率遥感图像遥感的密集预测任务。RSM设计用于模拟具有线性复杂性的遥感图像的全局特征,使其能够有效地处理大型VHR图像。它采用全向选择性扫描模块,从多个方向对图像进行全局建模,从多个方向捕捉大的空间特征。

论文链接:https://arxiv.org/abs/2404.02668

code链接:https://github.com/walking-shadow/Official_Remote_Sensing_Mamba

2D全向扫描机制是本研究的主要创新点。作者考虑到遥感影像地物多方向的特点,在VMamba2D双向扫描机制的基础上增加了斜向扫描机制。

 以下是作者针对该部分进行改进的代码:

  1. def antidiagonal_gather(tensor):
  2. # 取出矩阵所有反斜向的元素并拼接
  3. B, C, H, W = tensor.size()
  4. shift = torch.arange(H, device=tensor.device).unsqueeze(1) # 创建一个列向量[H, 1]
  5. index = (torch.arange(W, device=tensor.device) - shift) % W # 利用广播创建索引矩阵[H, W]
  6. # 扩展索引以适应B和C维度
  7. expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
  8. # 使用gather进行索引选择
  9. return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W)
  10. def diagonal_gather(tensor):
  11. # 取出矩阵所有反斜向的元素并拼接
  12. B, C, H, W = tensor.size()
  13. shift = torch.arange(H, device=tensor.device).unsqueeze(1) # 创建一个列向量[H, 1]
  14. index = (shift + torch.arange(W, device=tensor.device)) % W # 利用广播创建索引矩阵[H, W]
  15. # 扩展索引以适应B和C维度
  16. expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
  17. # 使用gather进行索引选择
  18. return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W)
  19. def diagonal_scatter(tensor_flat, original_shape):
  20. # 把斜向元素拼接起来的一维向量还原为最初的矩阵形式
  21. B, C, H, W = original_shape
  22. shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1) # 创建一个列向量[H, 1]
  23. index = (shift + torch.arange(W, device=tensor_flat.device)) % W # 利用广播创建索引矩阵[H, W]
  24. # 扩展索引以适应B和C维度
  25. expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
  26. # 创建一个空的张量来存储反向散布的结果
  27. result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
  28. # 将平铺的张量重新变形为[B, C, H, W],考虑到需要使用transpose将H和W调换
  29. tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
  30. # 使用scatter_根据expanded_index将元素放回原位
  31. result_tensor.scatter_(3, expanded_index, tensor_reshaped)
  32. return result_tensor
  33. def antidiagonal_scatter(tensor_flat, original_shape):
  34. # 把反斜向元素拼接起来的一维向量还原为最初的矩阵形式
  35. B, C, H, W = original_shape
  36. shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1) # 创建一个列向量[H, 1]
  37. index = (torch.arange(W, device=tensor_flat.device) - shift) % W # 利用广播创建索引矩阵[H, W]
  38. expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
  39. # 初始化一个与原始张量形状相同、元素全为0的张量
  40. result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
  41. # 将平铺的张量重新变形为[B, C, W, H],因为操作是沿最后一个维度收集的,需要调整形状并交换维度
  42. tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
  43. # 使用scatter_将元素根据索引放回原位
  44. result_tensor.scatter_(3, expanded_index, tensor_reshaped)
  45. return result_tensor
  46. class CrossScan(torch.autograd.Function):
  47. # ZSJ 这里是把图像按照特定方向展平的地方,改变扫描方向可以在这里修改
  48. @staticmethod
  49. def forward(ctx, x: torch.Tensor):
  50. B, C, H, W = x.shape
  51. ctx.shape = (B, C, H, W)
  52. # xs = x.new_empty((B, 4, C, H * W))
  53. xs = x.new_empty((B, 8, C, H * W))
  54. # 添加横向和竖向的扫描
  55. xs[:, 0] = x.flatten(2, 3)
  56. xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
  57. xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
  58. # 提供斜向和反斜向的扫描
  59. xs[:, 4] = diagonal_gather(x)
  60. xs[:, 5] = antidiagonal_gather(x)
  61. xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])
  62. return xs
  63. @staticmethod
  64. def backward(ctx, ys: torch.Tensor):
  65. # out: (b, k, d, l)
  66. B, C, H, W = ctx.shape
  67. L = H * W
  68. # 把横向和竖向的反向部分再反向回来,并和原来的横向和竖向相加
  69. # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
  70. y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
  71. # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
  72. # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
  73. y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
  74. y_rb = y_rb.view(B, -1, H, W)
  75. # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
  76. y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L)
  77. # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
  78. y_da = diagonal_scatter(y_da[:, 0], (B,C,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,C,H,W))
  79. y_res = y_rb + y_da
  80. # return y.view(B, -1, H, W)
  81. return y_res
  82. class CrossMerge(torch.autograd.Function):
  83. @staticmethod
  84. def forward(ctx, ys: torch.Tensor):
  85. B, K, D, H, W = ys.shape
  86. ctx.shape = (H, W)
  87. ys = ys.view(B, K, D, -1)
  88. # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
  89. # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
  90. y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
  91. # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
  92. y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
  93. y_rb = y_rb.view(B, -1, H, W)
  94. # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
  95. y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1)
  96. # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
  97. y_da = diagonal_scatter(y_da[:, 0], (B,D,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,D,H,W))
  98. y_res = y_rb + y_da
  99. return y_res.view(B, D, -1)
  100. # return y
  101. @staticmethod
  102. def backward(ctx, x: torch.Tensor):
  103. # B, D, L = x.shape
  104. # out: (b, k, d, l)
  105. H, W = ctx.shape
  106. B, C, L = x.shape
  107. # xs = x.new_empty((B, 4, C, L))
  108. xs = x.new_empty((B, 8, C, L))
  109. # 横向和竖向扫描
  110. xs[:, 0] = x
  111. xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
  112. xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
  113. # xs = xs.view(B, 4, C, H, W)
  114. # 提供斜向和反斜向的扫描
  115. xs[:, 4] = diagonal_gather(x.view(B,C,H,W))
  116. xs[:, 5] = antidiagonal_gather(x.view(B,C,H,W))
  117. xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])
  118. # return xs
  119. return xs.view(B, 8, C, H, W)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/444647
推荐阅读
相关标签
  

闽ICP备14008679号