
CANet Code


CANet consists of three parts: an encoder, a co-attention fusion module, and a decoder. Let's start with the most important part, the co-attention fusion module, which is built from a PCAM (position co-attention) module and a CCAM (channel co-attention) module:

 

```python
import torch
from torch.nn import Module, Conv2d, Parameter, Softmax


class PCAM_Module(Module):
    """ Position co-attention module """
    # Ref from SAGAN
    def __init__(self, in_dim):
        super(PCAM_Module, self).__init__()
        self.chanel_in = in_dim
        self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = Parameter(torch.zeros(1))
        self.softmax = Softmax(dim=-1)

    def forward(self, x, y):
        """
        inputs :
            x : RGB feature maps (B x C x H x W)
            y : depth feature maps (B x C x H x W)
        returns :
            out : attention value + input feature
            attention : B x (HxW) x (HxW)
        """
        m_batchsize, C, height, width = x.size()
        # Q: (b,c,h,w) -> conv -> (b,c/8,h*w) -> permute -> (b,h*w,c/8)
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        # K: (b,c,h,w) -> conv -> (b,c/8,h*w)
        proj_key = self.key_conv(y).view(m_batchsize, -1, width*height)
        # Q*K: (b,h*w,c/8) * (b,c/8,h*w) = (b,h*w,h*w)
        energy = torch.bmm(proj_query, proj_key)
        # softmax gives the attention map, (b,h*w,h*w)
        attention = self.softmax(energy)
        # V: (b,c,h,w) -> (b,c,h*w)
        proj_value = self.value_conv(y).view(m_batchsize, -1, width*height)
        # V * attention^T: (b,c,h*w) * (b,h*w,h*w) = (b,c,h*w)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        # (b,c,h*w) -> (b,c,h,w)
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma*out + x
        return out


class CCAM_Module(Module):
    """ Channel co-attention module """
    def __init__(self, in_dim):
        super(CCAM_Module, self).__init__()
        self.chanel_in = in_dim
        self.gamma = Parameter(torch.zeros(1))
        self.softmax = Softmax(dim=-1)

    def forward(self, x, y):
        """
        inputs :
            x : RGB feature maps (B x C x H x W)
            y : depth feature maps (B x C x H x W)
        returns :
            out : attention value + input feature
            attention : B x C x C
        """
        m_batchsize, C, height, width = x.size()
        # Q: (b,c,h,w) -> (b,c,n)
        proj_query = x.view(m_batchsize, C, -1)
        # K: (b,c,h,w) -> (b,c,n) -> (b,n,c)
        proj_key = y.view(m_batchsize, C, -1).permute(0, 2, 1)
        # Q*K: (b,c,n) * (b,n,c) = (b,c,c)
        energy = torch.bmm(proj_query, proj_key)
        # take each row's maximum (values only, broadcast back to energy's
        # shape) and subtract energy from it, stabilizing the softmax
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        # attention map, (b,c,c)
        attention = self.softmax(energy_new)
        # V: (b,c,h*w)
        proj_value = y.view(m_batchsize, C, -1)
        # (b,c,c) * (b,c,h*w) = (b,c,h*w)
        out = torch.bmm(attention, proj_value)
        # (b,c,h*w) -> (b,c,h,w)
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma*out + x
        return out
```
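
To sanity-check the shapes, here is a minimal smoke test (hypothetical, not part of the original repo); the 512-channel, 15x20 maps match what the encoder's conv_5a/conv_5c reductions produce later in the post:

```python
import torch

# hypothetical smoke test for the two co-attention modules
rgb_feat = torch.randn(2, 512, 15, 20)    # stand-in for the reduced RGB features
depth_feat = torch.randn(2, 512, 15, 20)  # stand-in for the reduced depth features

pcam = PCAM_Module(512)
ccam = CCAM_Module(512)

print(pcam(rgb_feat, depth_feat).shape)  # torch.Size([2, 512, 15, 20])
print(ccam(rgb_feat, depth_feat).shape)  # torch.Size([2, 512, 15, 20])
```

Note that gamma starts at zero, so at initialization both modules simply pass x through unchanged; the attention term is blended in as gamma is learned.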

The two attention feature maps output above, together with the convolutional fusion feature map, are fed into the fusion layer:

```python
import torch
import torch.nn.functional as F
from torch.nn import Module, Conv2d, ReLU


class FusionLayer(Module):
    def __init__(self, in_channels, groups=1, radix=2, reduction_factor=4, norm_layer=None):
        super(FusionLayer, self).__init__()
        inter_channels = max(in_channels//reduction_factor, 32)  # 256 here (at least 32)
        self.radix = radix  # 2
        self.cardinality = groups
        self.use_bn = norm_layer is not None
        self.relu = ReLU(inplace=True)
        self.fc1_p = Conv2d(in_channels, inter_channels, 1, groups=self.cardinality)  # 1024 -> 256
        self.fc1_c = Conv2d(in_channels, inter_channels, 1, groups=self.cardinality)  # 1024 -> 256
        if self.use_bn:
            self.bn1_p = norm_layer(inter_channels)
            self.bn1_c = norm_layer(inter_channels)
        self.fc2_p = Conv2d(inter_channels, in_channels*radix, 1, groups=self.cardinality)  # 256 -> 2048
        self.fc2_c = Conv2d(inter_channels, in_channels*radix, 1, groups=self.cardinality)  # 256 -> 2048
        self.rsoftmax = rSoftMax(radix, groups)

    def forward(self, x, y, z):
        """
        :param x: convolution fusion features, (b,2048,h,w)
        :param y: position attention features, (b,1024,h,w)
        :param z: channel attention features, (b,1024,h,w)
        :return: fused features, (b,2048,h,w)
        """
        assert self.radix == 2, "Error radix size!"
        # (b,2048,h,w)
        batch, rchannel = x.shape[:2]  # b, 2048
        if self.radix > 1:
            splited = torch.split(x, rchannel//self.radix, dim=1)  # two chunks of (b,1024,h,w)
            gap_1 = splited[0]  # (b,1024,h,w)
            gap_2 = splited[1]  # (b,1024,h,w)
        else:
            gap_1 = x
            gap_2 = x
        assert gap_1.shape[1] == y.shape[1], "Error!"
        assert gap_2.shape[1] == z.shape[1], "Error!"
        gap_p = sum([gap_1, y])
        gap_c = sum([gap_2, z])
        gap_p = F.adaptive_avg_pool2d(gap_p, 1)  # (b,1024,h,w) -> (b,1024,1,1)
        gap_c = F.adaptive_avg_pool2d(gap_c, 1)  # (b,1024,h,w) -> (b,1024,1,1)
        gap_p = self.fc1_p(gap_p)  # (b,256,1,1)
        gap_c = self.fc1_c(gap_c)  # (b,256,1,1)
        if self.use_bn:
            gap_p = self.bn1_p(gap_p)
            gap_c = self.bn1_c(gap_c)
        gap_p = self.relu(gap_p)
        gap_c = self.relu(gap_c)
        atten_p = self.fc2_p(gap_p)  # (b,256,1,1) -> (b,2048,1,1)
        atten_c = self.fc2_c(gap_c)  # (b,256,1,1) -> (b,2048,1,1)
        atten_p = self.rsoftmax(atten_p).view(batch, -1, 1, 1)  # (b,2048) -> (b,2048,1,1)
        atten_c = self.rsoftmax(atten_c).view(batch, -1, 1, 1)  # (b,2048) -> (b,2048,1,1)
        if self.radix > 1:
            attens_p = torch.split(atten_p, rchannel//self.radix, dim=1)  # tuple of two (b,1024,1,1)
            attens_c = torch.split(atten_c, rchannel//self.radix, dim=1)  # tuple of two (b,1024,1,1)
            splited_p = (gap_1, y)  # ((b,1024,h,w), (b,1024,h,w))
            splited_c = (gap_2, z)  # the posted snippet reused (gap_1, y) here, which contradicts the asserts above; (gap_2, z) matches the intent
            out_p = sum([att * split for (att, split) in zip(attens_p, splited_p)])  # (b,1024,h,w)
            out_c = sum([att * split for (att, split) in zip(attens_c, splited_c)])  # (b,1024,h,w)
        else:
            out_p = atten_p * y
            out_c = atten_c * z
        if self.radix > 1:
            out = torch.cat([out_p, out_c], 1)  # (b,2048,h,w)
        else:
            out = sum([out_p, out_c])
        return out.contiguous()
```
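
The fusion layer calls rSoftMax, which this post does not show. It comes from ResNeSt's split-attention block; a sketch of the standard ResNeSt implementation, which this code presumably reuses, looks like this:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class rSoftMax(nn.Module):
    """ResNeSt-style radix softmax: normalizes the attention logits
    across the radix splits (per cardinal group) when radix > 1,
    and falls back to a sigmoid when radix == 1."""
    def __init__(self, radix, cardinality):
        super(rSoftMax, self).__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            # (b, radix*cardinality*c, 1, 1) -> (b, radix, cardinality, c)
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)   # softmax over the radix axis
            x = x.reshape(batch, -1)  # flatten back to (b, radix*cardinality*c)
        else:
            x = torch.sigmoid(x)
        return x
```

With radix=2, each channel's pair of attention weights sums to 1, which is why each branch can blend its two inputs with convex weights before the position and channel halves are concatenated.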

A few points to clarify before looking at the overall CANet module:

1: The backbone is ResNet-50 or ResNet-101, chosen by the `backbone` argument (note the constructor below defaults to 'ResNet-101').

2: The decoder uses TransBasicBlock for upsampling.

First define the basic layers, then extract features from the RGB and depth inputs separately:

```python
import math
import torch.nn as nn

# assumes Bottleneck, TransBasicBlock, PCAM_Module, CCAM_Module and
# FusionLayer are defined as shown elsewhere in this post


class ACNet(nn.Module):
    def __init__(self, num_class=37, backbone='ResNet-101', pretrained=False, pcca5=False):
        super(ACNet, self).__init__()
        self.pcca5 = pcca5
        self.backbone = backbone
        if self.backbone == 'ResNet-50':
            layers = [3, 4, 6, 3]
        else:
            layers = [3, 4, 23, 3]
        block = Bottleneck
        transblock = TransBasicBlock
        # RGB image branch
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)  # use PSPNet extractors
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # depth image branch
        self.inplanes = 64
        self.conv1_d = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,
                                 bias=False)
        self.bn1_d = nn.BatchNorm2d(64)
        self.relu_d = nn.ReLU(inplace=True)
        self.maxpool_d = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_d = self._make_layer(block, 64, layers[0])
        self.layer2_d = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3_d = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4_d = self._make_layer(block, 512, layers[3], stride=2)
        """
        # merge branch
        self.atten_rgb_0 = self.channel_attention(64)
        self.atten_depth_0 = self.channel_attention(64)
        self.maxpool_m = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.atten_rgb_1 = self.channel_attention(64*4)
        self.atten_depth_1 = self.channel_attention(64*4)
        # self.conv_2 = nn.Conv2d(64*4, 64*4, kernel_size=1)  # todo: use cat + conv to reduce the channel count
        self.atten_rgb_2 = self.channel_attention(128*4)
        self.atten_depth_2 = self.channel_attention(128*4)
        self.atten_rgb_3 = self.channel_attention(256*4)
        self.atten_depth_3 = self.channel_attention(256*4)
        self.atten_rgb_4 = self.channel_attention(512*4)
        self.atten_depth_4 = self.channel_attention(512*4)
        """
        # merge branch
        self.inplanes = 64
        self.layer1_m = self._make_layer(block, 64, layers[0])
        self.layer2_m = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3_m = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4_m = self._make_layer(block, 512, layers[3], stride=2)
        # agant module
        self.agant0 = self._make_agant_layer(64, 64)
        self.agant1 = self._make_agant_layer(64*4, 64)
        self.agant2 = self._make_agant_layer(128*4, 128)
        self.agant3 = self._make_agant_layer(256*4, 256)
        self.agant4 = self._make_agant_layer(512*4, 512)
        # transpose layers
        self.inplanes = 512
        self.deconv1 = self._make_transpose(transblock, 256, 6, stride=2)
        self.deconv2 = self._make_transpose(transblock, 128, 4, stride=2)
        self.deconv3 = self._make_transpose(transblock, 64, 3, stride=2)
        self.deconv4 = self._make_transpose(transblock, 64, 3, stride=2)
        # final block
        self.inplanes = 64
        self.final_conv = self._make_transpose(transblock, 64, 3)
        self.final_deconv = nn.ConvTranspose2d(self.inplanes, num_class, kernel_size=2,
                                               stride=2, padding=0, bias=True)
        # side outputs, only used during training
        self.out5_conv = nn.Conv2d(256, num_class, kernel_size=1, stride=1, bias=True)
        self.out4_conv = nn.Conv2d(128, num_class, kernel_size=1, stride=1, bias=True)
        self.out3_conv = nn.Conv2d(64, num_class, kernel_size=1, stride=1, bias=True)
        self.out2_conv = nn.Conv2d(64, num_class, kernel_size=1, stride=1, bias=True)
        if self.pcca5:
            self.conv_5a = nn.Sequential(nn.Conv2d(2048, 512, kernel_size=3, stride=1, padding=1, bias=False),
                                         nn.BatchNorm2d(512),
                                         nn.ReLU())
            self.conv_5c = nn.Sequential(nn.Conv2d(2048, 512, kernel_size=3, stride=1, padding=1, bias=False),
                                         nn.BatchNorm2d(512),
                                         nn.ReLU())
            self.pca_5 = PCAM_Module(512)
            self.cca_5 = CCAM_Module(512)
            """
            self.pconv_5 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
                                         BatchNorm2d(512),
                                         nn.ReLU())
            self.cconv_5 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
                                         BatchNorm2d(512),
                                         nn.ReLU())
            self.pconv_out = nn.Sequential(nn.Conv2d(512, 2048, kernel_size=3, stride=1, padding=1, bias=False),
                                           BatchNorm2d(2048),
                                           nn.ReLU(),
                                           nn.Dropout2d(0.1, False))
            self.cconv_out = nn.Sequential(nn.Conv2d(512, 2048, kernel_size=3, stride=1, padding=1, bias=False),
                                           BatchNorm2d(2048),
                                           nn.ReLU(),
                                           nn.Dropout2d(0.1, False))
            self.alpha = Parameter(torch.ones(1))
            self.beta = Parameter(torch.ones(1))
            """
            self.pconv_5 = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=1, stride=1, padding=0, bias=False),
                                         nn.BatchNorm2d(1024),
                                         nn.ReLU())
            self.cconv_5 = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=1, stride=1, padding=0, bias=False),
                                         nn.BatchNorm2d(1024),
                                         nn.ReLU())
            self.split_conv = FusionLayer(in_channels=1024, groups=1, radix=2, reduction_factor=4, norm_layer=nn.BatchNorm2d)
        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        if pretrained:
            self._load_resnet_pretrained()
```

The constructor calls the _make_layer function, the block class (Bottleneck), the _make_agant_layer function, and the _make_transpose function; each is shown below.

1: The _make_layer function builds one encoder stage. It passes the input width (self.inplanes), the output planes, the stride, and, when the shapes change, a downsample branch to the block class, and returns an nn.Sequential containing `blocks` residual blocks.

```python
def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
    downsample = None
    if stride != 1 or self.inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, planes * block.expansion,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * block.expansion),
        )
    layers = []
    layers.append(block(self.inplanes, planes, stride, downsample))
    self.inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(block(self.inplanes, planes, dilation=dilation))
    return nn.Sequential(*layers)
```
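
As a concrete trace (hypothetical numbers, following the ResNet-50 setting above), building layer1 with planes=64 and blocks=3 goes like this:

```python
# self.inplanes = 64 entering _make_layer(Bottleneck, 64, 3):
#   64 != 64 * Bottleneck.expansion (256), so a 1x1-conv downsample branch is built
#   block 1: Bottleneck(64, 64, stride=1, downsample)  -> 256 output channels
#   self.inplanes = 64 * 4 = 256
#   block 2: Bottleneck(256, 64)                       -> 256 output channels
#   block 3: Bottleneck(256, 64)                       -> 256 output channels
# result: an nn.Sequential of 3 bottlenecks, with self.inplanes left at 256
# so the next call, _make_layer(block, 128, ...), chains correctly.
```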

2: The block class, Bottleneck, is a standard residual bottleneck block; the channel count expands from the input inplanes to planes * 4 at the output.

```python
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
                               padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
```

3: The _make_agant_layer function projects the 4x-expanded stage outputs back down to the base width with a 1x1 convolution, BatchNorm, and ReLU.

```python
def _make_agant_layer(self, inplanes, planes):
    layers = nn.Sequential(
        nn.Conv2d(inplanes, planes, kernel_size=1,
                  stride=1, padding=0, bias=False),
        nn.BatchNorm2d(planes),
        nn.ReLU(inplace=True)
    )
    return layers
```

4: The _make_transpose function builds a decoder stage: it stacks `blocks - 1` stride-1 blocks first and appends the strided block (which upsamples via nn.ConvTranspose2d) last, returning an nn.Sequential. The block here is TransBasicBlock; since the post does not include it, a sketch follows the snippet below.

```python
def _make_transpose(self, block, planes, blocks, stride=1):
    upsample = None
    if stride != 1:
        upsample = nn.Sequential(
            nn.ConvTranspose2d(self.inplanes, planes,
                               kernel_size=2, stride=stride,
                               padding=0, bias=False),
            nn.BatchNorm2d(planes),
        )
    elif self.inplanes != planes:
        upsample = nn.Sequential(
            nn.Conv2d(self.inplanes, planes,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes),
        )
    layers = []
    for i in range(1, blocks):
        layers.append(block(self.inplanes, self.inplanes))
    layers.append(block(self.inplanes, planes, stride, upsample))
    self.inplanes = planes
    return nn.Sequential(*layers)
```
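
TransBasicBlock itself is not included in this post. The sketch below follows the RedNet implementation that ACNet-style decoders typically reuse, so treat it as an assumption about the missing code rather than the definitive version (conv3x3 is the usual 3x3 convolution helper):

```python
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1):
    # standard 3x3 convolution with padding
    return nn.Conv2d(in_planes, out_planes, kernel_size=3,
                     stride=stride, padding=1, bias=False)


class TransBasicBlock(nn.Module):
    """Decoder residual block; when stride > 1 the second conv becomes a
    transposed conv that doubles the spatial resolution. Sketch of the
    RedNet-style implementation this code presumably follows."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, upsample=None, **kwargs):
        super(TransBasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        if upsample is not None and stride != 1:
            self.conv2 = nn.ConvTranspose2d(inplanes, planes,
                                            kernel_size=3, stride=stride,
                                            padding=1, output_padding=1,
                                            bias=False)
        else:
            self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.upsample = upsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.upsample is not None:
            residual = self.upsample(x)
        out += residual
        out = self.relu(out)
        return out
```

With stride != 1 the transposed convolution uses output_padding=1, so the block exactly doubles the spatial size while the upsample branch keeps the residual shapes in agreement.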

Next, the encoder extracts features from the rgb and depth inputs and fuses them stage by stage:

```python
def encoder(self, rgb, depth):
    rgb = self.conv1(rgb)
    rgb = self.bn1(rgb)
    rgb = self.relu(rgb)
    depth = self.conv1_d(depth)
    depth = self.bn1_d(depth)
    depth = self.relu_d(depth)
    m0 = rgb + depth
    rgb = self.maxpool(rgb)
    depth = self.maxpool_d(depth)
    m = self.maxpool(m0)
    # block 1
    rgb = self.layer1(rgb)
    depth = self.layer1_d(depth)
    m = self.layer1_m(m)
    m1 = m + rgb + depth
    # block 2
    rgb = self.layer2(rgb)
    depth = self.layer2_d(depth)
    m = self.layer2_m(m1)
    m2 = m + rgb + depth
    # block 3
    rgb = self.layer3(rgb)
    depth = self.layer3_d(depth)
    m = self.layer3_m(m2)
    m3 = m + rgb + depth
    # block 4
    rgb = self.layer4(rgb)
    depth = self.layer4_d(depth)
    m = self.layer4_m(m3)
    if self.pcca5:
        rgb_down = self.conv_5a(rgb)
        depth_down = self.conv_5c(depth)
        attention_position = self.pca_5(rgb_down, depth_down)
        attention_channel = self.cca_5(rgb_down, depth_down)
        p_out = self.pconv_5(attention_position)
        c_out = self.cconv_5(attention_channel)
        m4 = self.split_conv(m, p_out, c_out)
        """
        smooth_p = self.pconv_5(attention_position)
        smooth_c = self.cconv_5(attention_channel)
        p_out = self.pconv_out(smooth_p)
        c_out = self.cconv_out(smooth_c)
        m4 = m + self.alpha * p_out + self.beta * c_out
        """
    else:
        m4 = m + rgb + depth
    return m0, m1, m2, m3, m4  # channel of m4 is 2048
```
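
To make the skip-connection shapes concrete, here is a trace for a hypothetical (b, 3, 480, 640) RGB input and (b, 1, 480, 640) depth input; the sizes follow directly from the stride-2 stem, the maxpool, and the strided stages:

```python
# hypothetical shape trace for a 480x640 input:
#   m0: (b,   64, 240, 320)  after the stride-2 stem convs
#   m1: (b,  256, 120, 160)  after maxpool + layer{1, 1_d, 1_m}
#   m2: (b,  512,  60,  80)  after layer2 (stride 2)
#   m3: (b, 1024,  30,  40)  after layer3 (stride 2)
#   m4: (b, 2048,  15,  20)  after layer4, or the FusionLayer output when pcca5=True
```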

Finally, the fused features go into the decoder. Note that the four side outputs (out2 through out5) are only computed in training mode, where they serve as intermediate supervision at each scale:

```python
def decoder(self, fuse0, fuse1, fuse2, fuse3, fuse4):
    agant4 = self.agant4(fuse4)
    # upsample 1
    x = self.deconv1(agant4)
    if self.training:
        out5 = self.out5_conv(x)
    x = x + self.agant3(fuse3)
    # upsample 2
    x = self.deconv2(x)
    if self.training:
        out4 = self.out4_conv(x)
    x = x + self.agant2(fuse2)
    # upsample 3
    x = self.deconv3(x)
    if self.training:
        out3 = self.out3_conv(x)
    x = x + self.agant1(fuse1)
    # upsample 4
    x = self.deconv4(x)
    if self.training:
        out2 = self.out2_conv(x)
    x = x + self.agant0(fuse0)
    # final
    x = self.final_conv(x)
    out = self.final_deconv(x)
    if self.training:
        return out, out2, out3, out4, out5
    return out
```

Feed the encoder outputs into the decoder and the whole model is assembled:

```python
def forward(self, rgb, depth, phase_checkpoint=False):
    fuses = self.encoder(rgb, depth)
    m = self.decoder(*fuses)
    return m
```
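
A minimal end-to-end smoke test (hypothetical; it assumes Bottleneck, TransBasicBlock, rSoftMax, the two co-attention modules, and FusionLayer are all defined as above, and leaves pretrained=False so no checkpoint loading is attempted):

```python
import torch

net = ACNet(num_class=37, backbone='ResNet-50', pcca5=True)
net.eval()  # eval mode: the decoder returns only the final map

rgb = torch.randn(1, 3, 480, 640)    # input sizes divisible by 32 keep the
depth = torch.randn(1, 1, 480, 640)  # down/up-sampling paths exactly aligned

with torch.no_grad():
    out = net(rgb, depth)
print(out.shape)  # expected: torch.Size([1, 37, 480, 640])
```

In training mode the same call would instead return the tuple (out, out2, out3, out4, out5) from the decoder, one logit map per scale.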
