当前位置:   article > 正文

yolov7的网络结构代码详解_from utils.autoanchor import check_anchor_order

from utils.autoanchor import check_anchor_order

YOLOv7 的核心创新主要体现在网络结构上:主干中引入了 E-ELAN 模块,并增加了辅助(aux)检测头。从目前的效果来看,新的网络结构确实带来了一定提升。下面对主要代码进行详解,对应 yolo.py。

  1. import argparse
  2. import logging
  3. import sys
  4. from copy import deepcopy
  5. sys.path.append('./') # to run '$ python *.py' files in subdirectories
  6. logger = logging.getLogger(__name__)
  7. import torch
  8. from models.common import *
  9. from models.experimental import *
  10. from utils.autoanchor import check_anchor_order
  11. from utils.general import make_divisible, check_file, set_logging
  12. from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
  13. select_device, copy_attr
  14. from utils.loss import SigmoidBin
  15. try:
  16. import thop # for FLOPS computation
  17. except ImportError:
  18. thop = None
  class Detect(nn.Module):
      """Anchor-based detection head: one 1x1 convolution per input feature map.

      Each level predicts ``na`` anchors per grid cell and ``no = nc + 5``
      values per anchor. ``convert`` reads those values as 4 box numbers,
      1 confidence and ``nc`` class scores.

      ``stride`` is not set here; it must be assigned by whoever builds the
      model before the head is run in inference mode (``forward`` indexes it).
      """

      stride = None  # strides computed during build
      export = False  # onnx export
      end2end = False
      include_nms = False

      def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
          """Build the head.

          Args:
              nc: number of classes.
              anchors: one flat sequence ``(w0, h0, w1, h1, ...)`` per
                  detection level; must be non-empty.
              ch: input channel count of each level's feature map.
          """
          super(Detect, self).__init__()
          self.nc = nc  # number of classes
          self.no = nc + 5  # number of outputs per anchor
          self.nl = len(anchors)  # number of detection layers
          self.na = len(anchors[0]) // 2  # number of anchors
          self.grid = [torch.zeros(1)] * self.nl  # init grid (placeholders, rebuilt lazily)
          a = torch.tensor(anchors).float().view(self.nl, -1, 2)
          self.register_buffer('anchors', a)  # shape(nl,na,2)
          self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
          self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv

      def forward(self, x):
          """Run the head on a list of feature maps (the list is modified in place).

          Args:
              x: list of ``nl`` tensors shaped ``(bs, ch[i], ny, nx)``.

          Returns:
              * training (or ``export``): the list ``x`` with every entry
                reshaped to ``(bs, na, ny, nx, no)``;
              * ``end2end``: decoded predictions ``(bs, N, no)``;
              * ``include_nms``: ``((boxes, scores),)`` from ``convert``;
              * otherwise: ``(decoded predictions, x)``.
          """
          # x = x.copy()  # for profiling
          z = []  # inference output
          # NOTE: with export=True this leaves the module in training mode.
          self.training |= self.export
          for i in range(self.nl):
              x[i] = self.m[i](x[i])  # conv
              bs, _, ny, nx = x[i].shape  # e.g. x(bs,255,20,20) to x(bs,3,20,20,85)
              x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

              if not self.training:  # inference
                  cached = self.grid[i]
                  # Rebuild the cell-offset grid when the map size changed OR when the
                  # cached grid lives on another device (module moved after first use).
                  if cached.shape[2:4] != x[i].shape[2:4] or cached.device != x[i].device:
                      self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                  y = x[i].sigmoid()
                  if not torch.onnx.is_in_onnx_export():
                      y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                      y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                  else:
                      # During ONNX export the result is assembled with cat instead of
                      # in-place slice assignment.
                      xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                      wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data  # wh
                      y = torch.cat((xy, wh, y[..., 4:]), -1)
                  z.append(y.view(bs, -1, self.no))

          if self.training:
              out = x
          elif self.end2end:
              out = torch.cat(z, 1)
          elif self.include_nms:
              z = self.convert(z)
              out = (z, )
          else:
              out = (torch.cat(z, 1), x)

          return out

      @staticmethod
      def _make_grid(nx=20, ny=20):
          """Return a float tensor ``(1, 1, ny, nx, 2)`` holding ``(x, y)`` cell indices.

          Built from ``arange`` + ``expand`` rather than ``torch.meshgrid`` so the
          result does not depend on meshgrid's ``indexing`` default and no
          deprecation warning is emitted on recent PyTorch versions.
          """
          xs = torch.arange(nx).view(1, nx).expand(ny, nx)  # column index of every cell
          ys = torch.arange(ny).view(ny, 1).expand(ny, nx)  # row index of every cell
          return torch.stack((xs, ys), 2).view((1, 1, ny, nx, 2)).float()

      def convert(self, z):
          """Turn per-level decoded predictions into ``(boxes, scores)``.

          Args:
              z: list of tensors shaped ``(bs, n_i, no)``.

          Returns:
              ``box`` ``(bs, N, 4)`` as ``(cx - w/2, cy - h/2, cx + w/2, cy + h/2)``
              and ``score`` ``(bs, N, nc)`` = class scores multiplied by confidence.
          """
          z = torch.cat(z, 1)
          conf = z[:, :, 4:5]
          score = z[:, :, 5:] * conf
          # Columns of the matrix map (cx, cy, w, h) to the two box corners.
          convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
                                        dtype=torch.float32,
                                        device=z.device)
          box = z[:, :, :4] @ convert_matrix
          return (box, score)
  class IDetect(nn.Module):
      """Detection head with ``ImplicitA`` / ``ImplicitM`` layers around each output conv.

      ``forward`` applies ``ia[i]`` before and ``im[i]`` after the 1x1 conv;
      ``fuseforward`` runs the conv alone and additionally supports the
      ``end2end`` / ``include_nms`` output modes.

      NOTE(review): ``_make_grid`` and ``convert`` are called below but are not
      defined in the visible part of this listing, and ``ImplicitA`` /
      ``ImplicitM`` come from outside it - confirm they exist on the full class /
      in the imported modules.
      """

      stride = None  # strides computed during build
      export = False  # onnx export
      end2end = False
      include_nms = False

      def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
          """Build the head.

          Args:
              nc: number of classes.
              anchors: one flat sequence ``(w0, h0, w1, h1, ...)`` per
                  detection level; must be non-empty.
              ch: input channel count of each level's feature map.
          """
          super(IDetect, self).__init__()
          self.nc = nc  # number of classes
          self.no = nc + 5  # number of outputs per anchor
          self.nl = len(anchors)  # number of detection layers (e.g. 3)
          self.na = len(anchors[0]) // 2  # number of anchors (e.g. 6 // 2 = 3)
          self.grid = [torch.zeros(1)] * self.nl  # init grid (placeholders, rebuilt lazily)
          # e.g. (3, 6) -> (3, 3, 2):
          # [[[10,13], [16,30], [33,23]], [[30,61], [62,45], [59,119]], [[116,90], [156,198], [373,326]]]
          a = torch.tensor(anchors).float().view(self.nl, -1, 2)
          self.register_buffer('anchors', a)  # shape(nl,na,2)
          self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
          self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv

          self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
          self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)

      def forward(self, x):
          """Run the un-fused head on a list of feature maps (modified in place).

          Args:
              x: list of ``nl`` tensors shaped ``(bs, ch[i], ny, nx)``.

          Returns:
              The reshaped list ``x`` (entries ``(bs, na, ny, nx, no)``) in
              training mode, otherwise ``(decoded predictions (bs, N, no), x)``.
          """
          # x = x.copy()  # for profiling
          z = []  # inference output
          # NOTE: with export=True this leaves the module in training mode.
          self.training |= self.export
          for i in range(self.nl):
              x[i] = self.m[i](self.ia[i](x[i]))  # conv
              x[i] = self.im[i](x[i])
              bs, _, ny, nx = x[i].shape  # e.g. x(bs,255,20,20) to x(bs,3,20,20,85)
              x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

              if not self.training:  # inference
                  cached = self.grid[i]
                  # Rebuild the grid when the map size changed OR when the cached
                  # grid lives on another device (module moved after first use).
                  if cached.shape[2:4] != x[i].shape[2:4] or cached.device != x[i].device:
                      self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                  y = x[i].sigmoid()
                  y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                  y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                  z.append(y.view(bs, -1, self.no))

          return x if self.training else (torch.cat(z, 1), x)

      def fuseforward(self, x):
          """Run the head using only the output convs (no ``ia`` / ``im`` layers).

          Args:
              x: list of ``nl`` tensors shaped ``(bs, ch[i], ny, nx)``; modified in place.

          Returns:
              * training (or ``export``): the list ``x`` with every entry
                reshaped to ``(bs, na, ny, nx, no)``;
              * ``end2end``: decoded predictions ``(bs, N, no)``;
              * ``include_nms``: ``(self.convert(z),)``;
              * otherwise: ``(decoded predictions, x)``.
          """
          # x = x.copy()  # for profiling
          z = []  # inference output
          self.training |= self.export
          for i in range(self.nl):
              x[i] = self.m[i](x[i])  # conv
              bs, _, ny, nx = x[i].shape  # e.g. x(bs,255,20,20) to x(bs,3,20,20,85)
              x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

              if not self.training:  # inference
                  cached = self.grid[i]
                  if cached.shape[2:4] != x[i].shape[2:4] or cached.device != x[i].device:
                      self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                  y = x[i].sigmoid()
                  if not torch.onnx.is_in_onnx_export():
                      y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                      y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                  else:
                      # During ONNX export the result is assembled with cat instead of
                      # in-place slice assignment.
                      xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                      wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data  # wh
                      y = torch.cat((xy, wh, y[..., 4:]), -1)
                  z.append(y.view(bs, -1, self.no))

          if self.training:
              out = x
          elif self.end2end:
              out = torch.cat(z, 1)
          elif self.include_nms:
              z = self.convert(z)
              out = (z, )
          else:
              out = (torch.cat(z, 1), x)

          return out
  145. def fuse(self):
  146. print("IDetect.fuse")
  147. # fuse ImplicitA and Convolution
  148. for i in range(len(self.m)):
  149. c1,c2,_,_ = self.m[i].weight.shape
  150. c1_,c2_, _,_ = self.ia[i].implicit.shape
  151. self.m[i].bias
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/295651
推荐阅读
相关标签
  

闽ICP备14008679号