当前位置:   article > 正文

RBF多尺度融合卷积,保持尺寸不变,融合各个维度Pytorch_多尺度卷积python

多尺度卷积python
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class BasicConv2d(nn.Module):
  5. """
  6. 这是一个基础的卷积模块,可进行参数设置,膨胀卷积和其他参数
  7. """
  8. def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
  9. super(BasicConv2d, self).__init__()
  10. self.conv = nn.Conv2d(in_planes, out_planes,
  11. kernel_size=kernel_size, stride=stride,
  12. padding=padding, dilation=dilation, bias=False)
  13. self.bn = nn.BatchNorm2d(out_planes)
  14. self.relu = nn.ReLU(inplace=True)
  15. def forward(self, x):
  16. x = self.conv(x)
  17. x = self.bn(x)
  18. return x
  19. class RFB_modified(nn.Module):
  20. def __init__(self, in_channel, out_channel):
  21. super(RFB_modified, self).__init__()
  22. self.relu = nn.ReLU(True)
  23. self.branch0 = nn.Sequential(
  24. BasicConv2d(in_channel, out_channel, 1),
  25. )#通道变换
  26. self.branch1 = nn.Sequential(
  27. BasicConv2d(in_channel, out_channel, 1),
  28. BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
  29. BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
  30. BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
  31. )#第一分支不变尺寸卷积
  32. self.branch2 = nn.Sequential(
  33. BasicConv2d(in_channel, out_channel, 1),
  34. BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
  35. BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
  36. BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
  37. )##第二分支不变尺寸卷积
  38. self.branch3 = nn.Sequential(
  39. BasicConv2d(in_channel, out_channel, 1),
  40. BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
  41. BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
  42. BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
  43. )#第三分支不变尺寸卷积
  44. self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1)
  45. #多尺度拼接以后进行将通道
  46. self.conv_res = BasicConv2d(in_channel, out_channel, 1)
  47. #通道压缩
  48. def forward(self, x):
  49. x0 = self.branch0(x)
  50. x1 = self.branch1(x)
  51. print(x1.shape)
  52. x2 = self.branch2(x)
  53. print(x2.shape)
  54. x3 = self.branch3(x)
  55. print(x3.shape)
  56. x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))#多尺度拼接
  57. print(x_cat.shape)
  58. x = self.relu(x_cat + self.conv_res(x))
  59. return x
  60. if __name__ == '__main__':
  61. ras = RFB_modified(1,1).cuda()
  62. input_tensor = torch.randn(1, 1, 352, 352).cuda()
  63. out = ras(input_tensor)
  64. print(out)
  65. print(out.shape)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/905999
推荐阅读
相关标签
  

闽ICP备14008679号