当前位置:   article > 正文

多模态融合的高分遥感图像语义分割方法_se-unet

se-unet

多模态融合的高分遥感图像语义分割方法(python)

论文地址:http://www.cnki.com.cn/Article/CJFDTotal-ZNZK202004012.htm

1、SE-UNet 网络模型

(原文此处为 SE-UNet 网络模型结构图,图片未能保留)

2、SE-UNet的具体设计方案

(原文此处为 SE-UNet 具体设计方案示意图,图片未能保留)

3、SE-UNet 的 PyTorch 复现

import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.nn import functional as F
import torch

class SEBlock(nn.Module):
    """Squeeze-and-Excitation gate that returns per-channel weights.

    Unlike a textbook SE block, ``forward`` does NOT rescale its input: it
    returns the sigmoid channel weights of shape ``(N, C, 1, 1)`` so that the
    caller can apply them to a *different* tensor (e.g. gate RGB features with
    weights computed from DSM features).

    Args:
        ch_in: Number of input channels ``C``.
        reduction: Channel reduction ratio of the bottleneck MLP. The default
            of 2 gives the ``C -> C // 2 -> C`` layout.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """

    def __init__(self, ch_in, reduction=2):
        super(SEBlock, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        ch_in = int(ch_in)
        # Keep at least one hidden unit: with C // reduction == 0 the bottleneck
        # would be empty and the gate would no longer depend on the input.
        hidden = max(1, ch_in // reduction)
        self.relu = nn.ReLU(inplace=False)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))  # (N, C, H, W) -> (N, C, 1, 1)
        self.fc1 = nn.Linear(in_features=ch_in, out_features=hidden)
        self.fc2 = nn.Linear(in_features=hidden, out_features=ch_in)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Compute channel attention weights for ``x``.

        Args:
            x: Feature map, assumed ``(N, C, H, W)`` with ``C == ch_in``.

        Returns:
            Tensor of shape ``(N, C, 1, 1)`` with values in ``(0, 1)``.
        """
        # Squeeze: global average pool to one scalar per channel -> (N, C).
        w = self.global_pool(x).flatten(1)
        # Excitation: bottleneck MLP followed by a sigmoid gate.
        w = self.sigmoid(self.fc2(self.relu(self.fc1(w))))
        # Reshape so the weights broadcast over H and W of the gated tensor.
        return w.view(w.size(0), w.size(1), 1, 1)
        
class DoubleConv(nn.Module):
    """Two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages.

    With a 3x3 kernel, stride 1 and padding 1 the spatial size is unchanged;
    only the channel count goes from ``in_ch`` to ``out_ch``. All layers live
    in ``self.conv`` (an ``nn.Sequential``), indexed ``conv.0`` .. ``conv.5``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),  # batch norm after every convolution
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        """Run both conv stages on ``input`` and return the result."""
        features = self.conv(input)
        return features

class Unet(nn.Module):
    """SE-UNet: a U-Net over RGB whose encoder is gated by a parallel DSM branch.

    At each of the five encoder levels the DSM branch produces a feature map,
    an ``SEBlock`` turns it into per-channel weights of shape ``(N, C, 1, 1)``,
    and those weights multiply the RGB ``DoubleConv`` output of the same level.
    The gated RGB features feed both the next encoder level and the decoder
    skip connections; the DSM features are not gated themselves.

    Args:
        in_ch: Number of channels of the RGB input.
        out_ch: Number of channels of the output map.
        dsm_ch: Number of channels of the DSM input (default 1).

    Note:
        ``DSM`` and ``RGB`` must share ``N``, ``H`` and ``W``, and ``H`` and
        ``W`` must be divisible by 16 (four 2x poolings); otherwise the skip
        concatenations in ``forward`` fail with a size mismatch.
    """

    @staticmethod
    def _dsm_down(ch_in, ch_out):
        """DSM encoder stage: 3x3 conv -> BN -> ReLU -> 2x2 max-pool (halves H, W)."""
        # No ReLU after the pool: max-pooling ReLU outputs is already
        # non-negative, so a second ReLU would be an identity.
        return nn.Sequential(
            nn.Conv2d(ch_in, ch_out, 3, 1, 1, bias=False),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(2, 2),
        )

    def __init__(self, in_ch, out_ch, dsm_ch=1):
        super(Unet, self).__init__()
        # NOTE: keep this construction order. Parameter initialisation draws
        # from the global RNG in creation order, so reordering would change the
        # weights produced under a fixed torch.manual_seed.

        # RGB encoder: DoubleConv at each level, 2x max-pool between levels.
        self.pool = nn.MaxPool2d(2)  # NOTE(review): never used by forward()
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)

        # Decoder: transposed convs with kernel == stride == 2 double H and W
        # (an upsample layer would also work); each DoubleConv fuses the
        # upsampled map with the same-level skip connection.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)  # 1x1 output head

        # NOTE(review): the modules below are never used by forward(). They are
        # kept only so that existing state_dicts still load with strict=True;
        # the convolutions hold about 13.9M parameters that never get gradients.
        self.conv1_dilation = nn.Conv2d(2048, 256, 1, stride=1, padding=0, bias=False, dilation=1)
        self.conv2_dilation = nn.Conv2d(2048, 256, 2, stride=1, padding=2, bias=False, dilation=2)
        self.conv4_dilation = nn.Conv2d(2048, 256, 4, stride=1, padding=4, bias=False, dilation=4)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.upsample = nn.Upsample(scale_factor=7, mode='bicubic', align_corners=True)
        self.conv_c = nn.Conv2d(2816, 1024, 1, stride=1, padding=0, bias=False, dilation=1)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=True)

        # DSM branch: one full-resolution stage, then four conv + 2x pool stages
        # whose outputs match the RGB encoder's resolution and channel width.
        self.R1 = nn.Sequential(
            nn.Conv2d(dsm_ch, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
        )
        self.RP2 = self._dsm_down(64, 128)
        self.RP3 = self._dsm_down(128, 256)
        self.RP4 = self._dsm_down(256, 512)
        self.RP5 = self._dsm_down(512, 1024)

        # One SE gate per level: DSM features -> (N, C, 1, 1) weights for RGB.
        self.SE1 = SEBlock(64)
        self.SE2 = SEBlock(128)
        self.SE3 = SEBlock(256)
        self.SE4 = SEBlock(512)
        self.SE5 = SEBlock(1024)

    def forward(self, DSM, RGB):
        """Segment ``RGB`` using ``DSM`` as a channel-attention guide.

        Args:
            DSM: Elevation input, ``(N, dsm_ch, H, W)``.
            RGB: Image input, ``(N, in_ch, H, W)``.

        Returns:
            Sigmoid outputs of shape ``(N, out_ch, H, W)``.
        """
        # Encoder: the DSM branch runs alongside and gates each RGB level.
        d = self.R1(DSM)
        e1 = self.SE1(d) * self.conv1(RGB)               # (N, 64,   H,    W)
        d = self.RP2(d)
        e2 = self.SE2(d) * self.conv2(self.pool1(e1))    # (N, 128,  H/2,  W/2)
        d = self.RP3(d)
        e3 = self.SE3(d) * self.conv3(self.pool2(e2))    # (N, 256,  H/4,  W/4)
        d = self.RP4(d)
        e4 = self.SE4(d) * self.conv4(self.pool3(e3))    # (N, 512,  H/8,  W/8)
        d = self.RP5(d)
        e5 = self.SE5(d) * self.conv5(self.pool4(e4))    # (N, 1024, H/16, W/16)

        # Decoder: upsample, concatenate [upsampled, gated skip], fuse.
        x = self.conv6(torch.cat([self.up6(e5), e4], dim=1))  # (N, 512, H/8, W/8)
        x = self.conv7(torch.cat([self.up7(x), e3], dim=1))   # (N, 256, H/4, W/4)
        x = self.conv8(torch.cat([self.up8(x), e2], dim=1))   # (N, 128, H/2, W/2)
        x = self.conv9(torch.cat([self.up9(x), e1], dim=1))   # (N, 64,  H,   W)
        return torch.sigmoid(self.conv10(x))                  # (N, out_ch, H, W)

if __name__ == "__main__":
    # Smoke test: random DSM (1 channel) and RGB (3 channels) at 512x512.
    DSM = torch.randn(2, 1, 512, 512)
    RGB = torch.randn(2, 3, 512, 512)
    model = Unet(3, 3)
    # Only a forward/shape check, so skip autograd bookkeeping: without
    # no_grad() every intermediate activation would be retained for a
    # backward pass that never happens.
    with torch.no_grad():
        out_result = model(DSM, RGB)
    print(out_result)
    print(out_result.shape)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/343208
推荐阅读
相关标签
  

闽ICP备14008679号