当前位置:   article > 正文

pytorch实现常用的一些即插即用模块(长期更新)_torch_dwconv

torch_dwconv

1. 深度可分离卷积(Depthwise Separable Convolution)

  1. #coding:utf-8
  2. import torch.nn as nn
  3. class DWConv(nn.Module):
  4. def __init__(self, in_plane, out_plane):
  5. super(DWConv, self).__init__()
  6. self.depth_conv = nn.Conv2d(in_channels=in_plane,
  7. out_channels=in_plane,
  8. kernel_size=3,
  9. stride=1,
  10. padding=1,
  11. groups=in_plane)
  12. self.point_conv = nn.Conv2d(in_channels=in_plane,
  13. out_channels=out_plane,
  14. kernel_size=1,
  15. stride=1,
  16. padding=0,
  17. groups=1)
  18. def forward(self, x):
  19. x = self.depth_conv(x)
  20. x = self.point_conv(x)
  21. return x
  22. def deubg_dw():
  23. import torch
  24. DW_model = DWConv(3, 32)
  25. x = torch.rand((32, 3, 320, 320))
  26. out = DW_model(x)
  27. print(out.shape)
  28. if __name__ == '__main__':
  29. deubg_dw()

2. DBNet 论文中的 DBHead

  1. #coding:utf-8
  2. import torch
  3. from torch import nn
  4. class DBHead(nn.Module):
  5. def __init__(self, in_channels, out_channels, k=50):
  6. super().__init__()
  7. self.k = k
  8. self.binarize = nn.Sequential(
  9. nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
  10. nn.BatchNorm2d(in_channels // 4),
  11. nn.ReLU(inplace=True),
  12. nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
  13. nn.BatchNorm2d(in_channels // 4),
  14. nn.ReLU(inplace=True),
  15. nn.ConvTranspose2d(in_channels // 4, 1, 2, 2),
  16. nn.Sigmoid())
  17. self.binarize.apply(self.weights_init)
  18. self.thresh = self._init_thresh(in_channels)
  19. self.thresh.apply(self.weights_init)
  20. def forward(self, x):
  21. shrink_maps = self.binarize(x)
  22. threshold_maps = self.thresh(x)
  23. if self.training:#从父类继承的变量, train的时候默认是true, eval的时候会变为false
  24. binary_maps = self.step_function(shrink_maps, threshold_maps)
  25. y = torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
  26. else:
  27. y = torch.cat((shrink_maps, threshold_maps), dim=1)
  28. return y
  29. def weights_init(self, m):
  30. classname = m.__class__.__name__
  31. if classname.find('Conv') != -1:
  32. nn.init.kaiming_normal_(m.weight.data)
  33. elif classname.find('BatchNorm') != -1:
  34. m.weight.data.fill_(1.)
  35. m.bias.data.fill_(1e-4)
  36. def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
  37. in_channels = inner_channels
  38. if serial:
  39. in_channels += 1
  40. self.thresh = nn.Sequential(
  41. nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
  42. nn.BatchNorm2d(inner_channels // 4),
  43. nn.ReLU(inplace=True),
  44. self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
  45. nn.BatchNorm2d(inner_channels // 4),
  46. nn.ReLU(inplace=True),
  47. self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
  48. nn.Sigmoid())
  49. return self.thresh
  50. def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
  51. if smooth:
  52. inter_out_channels = out_channels
  53. if out_channels == 1:
  54. inter_out_channels = in_channels
  55. module_list = [
  56. nn.Upsample(scale_factor=2, mode='nearest'),
  57. nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
  58. if out_channels == 1:
  59. module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True))
  60. return nn.Sequential(module_list)
  61. else:
  62. return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
  63. def step_function(self, x, y):
  64. return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
  65. def debug_main():
  66. x = torch.rand((8, 256, 160, 160))
  67. head_model = DBHead(in_channels=256, out_channels=2)
  68. head_model.train()
  69. y = head_model(x)
  70. print('==y.shape:', y.shape)
  71. head_model.eval()
  72. y = head_model(x)
  73. print('==y.shape:', y.shape)
  74. if __name__ == '__main__':
  75. debug_main()

3. SENet 中的通道注意力(SE attention)

目的是对不同通道进行加权:先 squeeze,即通过 global average pooling 将 h*w*c 的特征压缩成 1*1*c 的特征;再经过两层线性层;最后通过 sigmoid 输出每个通道的权重,并加权到对应的通道上。

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class SELayer(nn.Module):
  5. def __init__(self, channel, reduction=16):
  6. super(SELayer, self).__init__()
  7. self.avg_pool = nn.AdaptiveAvgPool2d(1) # 压缩空间
  8. self.fc = nn.Sequential(
  9. nn.Linear(channel, channel // reduction, bias=False),
  10. nn.ReLU(inplace=True),
  11. nn.Linear(channel // reduction, channel, bias=False),
  12. nn.Sigmoid()
  13. )
  14. def forward(self, x):
  15. b, c, _, _ = x.size()
  16. y = self.avg_pool(x).view(b, c)
  17. y = self.fc(y).view(b, c, 1, 1)
  18. return x * y
  19. def debug_attention():
  20. attention_module = SELayer(channel=128, reduction=16)
  21. # B,C,H,W
  22. x = torch.rand((2, 128, 100, 100))
  23. out = attention_module(x)
  24. print('==out.shape:', out.shape)
  25. if __name__ == '__main__':
  26. debug_attention()

4. CV 中的 self-attention

(1). feature map 通过 1*1 卷积获得 q、k、v 三个向量,q 与 k 的转置相乘得到 attention 矩阵,进行 softmax 归一化到 0 到 1 之间,再作用于 v,得到每个像素的加权结果.

(2).softmax

(3).加权求和

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class Self_Attn(nn.Module):
  5. """ Self attention Layer"""
  6. def __init__(self, in_dim):
  7. super(Self_Attn, self).__init__()
  8. self.chanel_in = in_dim
  9. self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
  10. self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
  11. self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
  12. self.gamma = nn.Parameter(torch.zeros(1))
  13. self.softmax = nn.Softmax(dim=-1)
  14. def forward(self, x):
  15. """
  16. inputs :
  17. x : input feature maps( B * C * W * H)
  18. returns :
  19. out : self attention value + input feature
  20. attention: B * N * N (N is Width*Height)
  21. """
  22. m_batchsize, C, width, height = x.size()
  23. proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B*N*C
  24. proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B*C*N
  25. energy = torch.bmm(proj_query, proj_key) # batch的matmul B*N*N
  26. attention = self.softmax(energy) # B * (N) * (N)
  27. proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B * C * N
  28. out = torch.bmm(proj_value, attention.permute(0, 2, 1)) # B*C*N
  29. out = out.view(m_batchsize, C, width, height) # B*C*H*W
  30. out = self.gamma * out + x
  31. return out, attention
  32. def debug_attention():
  33. attention_module = Self_Attn(in_dim=128)
  34. #B,C,H,W
  35. x = torch.rand((2, 128, 100, 100))
  36. attention_module(x)
  37. if __name__ == '__main__':
  38. debug_attention()

5. SPP(Spatial Pyramid Pooling)多窗口池化

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class SPP(nn.Module):
  5. """
  6. Spatial Pyramid Pooling
  7. """
  8. def __init__(self):
  9. super(SPP, self).__init__()
  10. def forward(self, x):
  11. x_1 = F.max_pool2d(x, kernel_size=5, stride=1, padding=2)
  12. x_2 = F.max_pool2d(x, kernel_size=9, stride=1, padding=4)
  13. x_3 = F.max_pool2d(x, kernel_size=13, stride=1, padding=6)
  14. x = torch.cat([x, x_1, x_2, x_3], dim=1)
  15. return x
  16. def debug_spp():
  17. x = torch.rand((8,3,256,256))
  18. spp = SPP()
  19. x = spp(x)
  20. print('==x.shape:', x.shape)
  21. if __name__ == '__main__':
  22. debug_spp()

6.RetinaFPN

  1. # coding: utf-8
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. class RetinaFPN(nn.Module):
  6. def __init__(self,
  7. C3_inplanes,
  8. C4_inplanes,
  9. C5_inplanes,
  10. planes,
  11. use_p5=False):
  12. super(RetinaFPN, self).__init__()
  13. self.use_p5 = use_p5
  14. self.P3_1 = nn.Conv2d(C3_inplanes,
  15. planes,
  16. kernel_size=1,
  17. stride=1,
  18. padding=0)
  19. self.P3_2 = nn.Conv2d(planes,
  20. planes,
  21. kernel_size=3,
  22. stride=1,
  23. padding=1)
  24. self.P4_1 = nn.Conv2d(C4_inplanes,
  25. planes,
  26. kernel_size=1,
  27. stride=1,
  28. padding=0)
  29. self.P4_2 = nn.Conv2d(planes,
  30. planes,
  31. kernel_size=3,
  32. stride=1,
  33. padding=1)
  34. self.P5_1 = nn.Conv2d(C5_inplanes,
  35. planes,
  36. kernel_size=1,
  37. stride=1,
  38. padding=0)
  39. self.P5_2 = nn.Conv2d(planes,
  40. planes,
  41. kernel_size=3,
  42. stride=1,
  43. padding=1)
  44. if self.use_p5:
  45. self.P6 = nn.Conv2d(planes,
  46. planes,
  47. kernel_size=3,
  48. stride=2,
  49. padding=1)
  50. else:
  51. self.P6 = nn.Conv2d(C5_inplanes,
  52. planes,
  53. kernel_size=3,
  54. stride=2,
  55. padding=1)
  56. self.P7 = nn.Sequential(
  57. nn.ReLU(),
  58. nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1))
  59. def forward(self, inputs):
  60. [C3, C4, C5] = inputs
  61. P5 = self.P5_1(C5)
  62. P4 = self.P4_1(C4)
  63. P4 = F.interpolate(P5, size=(P4.shape[2], P4.shape[3]),
  64. mode='nearest') + P4
  65. P3 = self.P3_1(C3)
  66. P3 = F.interpolate(P4, size=(P3.shape[2], P3.shape[3]),
  67. mode='nearest') + P3
  68. P5 = self.P5_2(P5)
  69. P4 = self.P4_2(P4)
  70. P3 = self.P3_2(P3)
  71. if self.use_p5:
  72. P6 = self.P6(P5)
  73. else:
  74. P6 = self.P6(C5)
  75. del C3, C4, C5
  76. P7 = self.P7(P6)
  77. return [P3, P4, P5, P6, P7]
  78. if __name__ == '__main__':
  79. image_h, image_w = 640, 640
  80. fpn = RetinaFPN(512, 1024, 2048, 256)
  81. C3, C4, C5 = torch.randn(3, 512, 80, 80), torch.randn(3, 1024, 40, 40), torch.randn(3, 2048, 20, 20)
  82. [P3, P4, P5, P6, P7] = fpn([C3, C4, C5])
  83. print("P3", P3.shape)
  84. print("P4", P4.shape)
  85. print("P5", P5.shape)
  86. print("P6", P6.shape)
  87. print("P7", P7.shape)

7.Focus

  1. import torch
  2. import torch.nn as nn
  3. def autopad(k, p=None): # kernel, padding
  4. # Pad to 'same'
  5. if p is None:
  6. p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
  7. # print('==p:', p)
  8. return p
  9. class Conv(nn.Module):
  10. # Standard convolution
  11. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  12. super(Conv, self).__init__()
  13. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
  14. self.bn = nn.BatchNorm2d(c2)
  15. self.act = nn.Hardswish() if act else nn.Identity()
  16. def forward(self, x):
  17. return self.act(self.bn(self.conv(x)))
  18. def fuseforward(self, x):
  19. return self.act(self.conv(x))
  20. class Focus(nn.Module):
  21. # Focus wh information into c-space
  22. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  23. super(Focus, self).__init__()
  24. self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
  25. def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
  26. return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
  27. def debug_focus():
  28. model = Focus(c1=3, c2=24)
  29. img = torch.rand((8, 3, 124, 124))
  30. print('==img.shape', img.shape)
  31. out = model(img)
  32. print('===out.shape', out.shape)
  33. debug_focus()

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/290957
推荐阅读
相关标签
  

闽ICP备14008679号