当前位置:   article > 正文

YOLOV5代码yolo.py文件解读

yolo.py

YOLOV5源码的下载:

git clone https://github.com/ultralytics/yolov5.git

YOLOv5 代码 models/yolo.py 文件解读(以下为带注释的完整源码):

  1. import argparse
  2. import logging
  3. import sys
  4. from copy import deepcopy
  5. from pathlib import Path
  6. import math
  7. sys.path.append('./') # to run '$ python *.py' files in subdirectories
  8. logger = logging.getLogger(__name__)
  9. import torch
  10. import torch.nn as nn
  11. from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
  12. from models.experimental import MixConv2d, CrossConv, C3
  13. from utils.general import check_anchor_order, make_divisible, check_file, set_logging
  14. from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
  15. select_device, copy_attr
  16. class Detect(nn.Module):
  17. stride = None # strides computed during build
  18. export = False # onnx export
  19. def __init__(self, nc=80, anchors=(), ch=()): # detection layer
  20. super(Detect, self).__init__()
  21. self.nc = nc # number of classes
  22. self.no = nc + 5 # number of outputs per anchor. VOC: 20+5=25
  23. self.nl = len(anchors) # number of detection layers = 3
  24. self.na = len(anchors[0]) // 2 # number of anchors =3
  25. self.grid = [torch.zeros(1)] * self.nl # init grid
  26. a = torch.tensor(anchors).float().view(self.nl, -1, 2)
  27. # 模型中需要保存下来的参数包括两种: 一种是反向传播需要被optimizer更新的,称之为 parameter;
  28. # 一种是反向传播不需要被optimizer更新,称之为 buffer。
  29. # 第二种参数我们需要创建tensor, 然后将tensor通过register_buffer()进行注册,
  30. # 可以通过model.buffers() 返回,注册完后参数也会自动保存到OrderDict中去。
  31. # 注意:buffer的更新在forward中,optim.step只能更新nn.parameter类型的参数
  32. self.register_buffer('anchors', a) # shape(nl,na,2)
  33. self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
  34. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv 1*1卷积
  35. def forward(self, x):
  36. # x = x.copy() # for profiling
  37. z = [] # inference output
  38. self.training |= self.export
  39. for i in range(self.nl):
  40. x[i] = self.m[i](x[i]) # conv
  41. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  42. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  43. if not self.training: # inference
  44. if self.grid[i].shape[2:4] != x[i].shape[2:4]:
  45. self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
  46. y = x[i].sigmoid()
  47. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
  48. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  49. z.append(y.view(bs, -1, self.no)) # 预测框坐标信息
  50. return x if self.training else (torch.cat(z, 1), x) # 预测框坐标, obj, cls
  51. @staticmethod
  52. def _make_grid(nx=20, ny=20):
  53. # 划分为单元网格
  54. yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  55. return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  56. # 网络模型类
  57. class Model(nn.Module):
  58. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
  59. super(Model, self).__init__()
  60. if isinstance(cfg, dict):
  61. self.yaml = cfg # model dict
  62. else: # is *.yaml
  63. import yaml # for torch hub
  64. self.yaml_file = Path(cfg).name
  65. with open(cfg) as f:
  66. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  67. # Define model
  68. if nc and nc != self.yaml['nc']:
  69. print('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
  70. self.yaml['nc'] = nc # override yaml value
  71. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
  72. # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
  73. # Build strides, anchors
  74. m = self.model[-1] # Detect()
  75. if isinstance(m, Detect):
  76. s = 128 # 2x min stride
  77. # m.stride = [8,16,32]
  78. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  79. # anchor大小计算, 例如 [10, 13] --> [1.25, 1.625]
  80. m.anchors /= m.stride.view(-1, 1, 1)
  81. check_anchor_order(m) # 检查anchor顺序和stride顺序是否一致
  82. self.stride = m.stride
  83. self._initialize_biases() # 初始化偏置 only run once
  84. # print('Strides: %s' % m.stride.tolist())
  85. # Init weights, biases
  86. initialize_weights(self) # 初始化权重
  87. self.info()
  88. print('')
  89. def forward(self, x, augment=False, profile=False):
  90. if augment: # TTA (Test Time Augmentation)
  91. img_size = x.shape[-2:] # height, width
  92. s = [1, 0.83, 0.67] # scales
  93. f = [None, 3, None] # flips (2-ud, 3-lr)
  94. y = [] # outputs
  95. for si, fi in zip(s, f):
  96. xi = scale_img(x.flip(fi) if fi else x, si) # 改变图像尺寸
  97. yi = self.forward_once(xi)[0] # forward
  98. # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  99. yi[..., :4] /= si # de-scale
  100. if fi == 2:
  101. yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
  102. elif fi == 3:
  103. yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
  104. y.append(yi)
  105. return torch.cat(y, 1), None # augmented inference, train
  106. else:
  107. return self.forward_once(x, profile) # single-scale inference, train
  108. def forward_once(self, x, profile=False):
  109. y, dt = [], [] # outputs
  110. for m in self.model:
  111. if m.f != -1: # if not from previous layer
  112. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  113. if profile:
  114. try:
  115. import thop # THOP: PyTorch-OpCounter 估算PyTorch模型的FLOPs模块
  116. o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPS
  117. except:
  118. o = 0
  119. t = time_synchronized()
  120. for _ in range(10):
  121. _ = m(x)
  122. dt.append((time_synchronized() - t) * 100)
  123. print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
  124. x = m(x) # 执行网络组件操作
  125. y.append(x if m.i in self.save else None) # save output
  126. if profile:
  127. print('%.1fms total' % sum(dt))
  128. return x
  129. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  130. # https://arxiv.org/abs/1708.02002 section 3.3
  131. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  132. m = self.model[-1] # Detect() module
  133. for mi, s in zip(m.m, m.stride): # from
  134. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  135. b[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  136. b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  137. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  138. def _print_biases(self):
  139. m = self.model[-1] # Detect() module
  140. for mi in m.m: # from
  141. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  142. print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  143. # def _print_weights(self):
  144. # for m in self.model.modules():
  145. # if type(m) is Bottleneck:
  146. # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  147. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  148. print('Fusing layers... ')
  149. for m in self.model.modules():
  150. if type(m) is Conv and hasattr(m, 'bn'):
  151. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
  152. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  153. delattr(m, 'bn') # remove batchnorm
  154. m.forward = m.fuseforward # update forward
  155. self.info()
  156. return self
  157. def nms(self, mode=True): # add or remove NMS module
  158. present = type(self.model[-1]) is NMS # last layer is NMS
  159. if mode and not present:
  160. print('Adding NMS... ')
  161. m = NMS() # module
  162. m.f = -1 # from
  163. m.i = self.model[-1].i + 1 # index
  164. self.model.add_module(name='%s' % m.i, module=m) # add
  165. self.eval()
  166. elif not mode and present:
  167. print('Removing NMS... ')
  168. self.model = self.model[:-1] # remove
  169. return self
  170. def autoshape(self): # add autoShape module
  171. print('Adding autoShape... ')
  172. m = autoShape(self) # wrap model
  173. copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
  174. return m
  175. def info(self, verbose=False): # print model information
  176. model_info(self, verbose)
  177. # 解析网络模型配置文件并构建模型
  178. def parse_model(d, ch): # model_dict, input_channels(3)
  179. logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  180. #将模型结构的depth_multiple,width_multiple提取出,赋值给gd (yolov5s: 0.33),gw (yolov5s:0.50)
  181. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  182. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors =3
  183. no = na * (nc + 5) # number of outputs = anchors * (classes + 5); VOC : 75
  184. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out = 3
  185. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  186. m = eval(m) if isinstance(m, str) else m # eval strings
  187. for j, a in enumerate(args):
  188. try:
  189. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  190. except:
  191. pass
  192. # 控制深度的代码
  193. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  194. if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
  195. c1, c2 = ch[f], args[0]
  196. # Normal
  197. # if i > 0 and args[0] != no: # channel expansion factor
  198. # ex = 1.75 # exponential (default 2.0)
  199. # e = math.log(c2 / ch[1]) / math.log(2)
  200. # c2 = int(ch[1] * ex ** e)
  201. # if m != Focus:
  202. # 控制宽度(卷积核个数)的代码
  203. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  204. # Experimental
  205. # if i > 0 and args[0] != no: # channel expansion factor
  206. # ex = 1 + gw # exponential (default 2.0)
  207. # ch1 = 32 # ch[1]
  208. # e = math.log(c2 / ch1) / math.log(2) # level 1-n
  209. # c2 = int(ch1 * ex ** e)
  210. # if m != Focus:
  211. # c2 = make_divisible(c2, 8) if c2 != no else c2
  212. args = [c1, c2, *args[1:]]
  213. if m in [BottleneckCSP, C3]:
  214. args.insert(2, n)
  215. n = 1
  216. elif m is nn.BatchNorm2d:
  217. args = [ch[f]]
  218. elif m is Concat:
  219. c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
  220. elif m is Detect:
  221. args.append([ch[x + 1] for x in f])
  222. if isinstance(args[1], int): # number of anchors
  223. args[1] = [list(range(args[1] * 2))] * len(f)
  224. else:
  225. c2 = ch[f]
  226. # *args表示接收任意个数量的参数,调用时会将实际参数打包为一个元组传入实参
  227. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  228. t = str(m)[8:-2].replace('__main__.', '') # module type
  229. np = sum([x.numel() for x in m_.parameters()]) # number params
  230. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  231. logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  232. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  233. layers.append(m_)
  234. ch.append(c2)
  235. return nn.Sequential(*layers), sorted(save)
  236. if __name__ == '__main__':
  237. # 建立参数解析对象parser
  238. parser = argparse.ArgumentParser()
  239. # 添加属性:给xx实例增加一个aa属性,如 xx.add_argument("aa")
  240. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  241. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  242. # 采用parser对象的parse_args函数获取解析的参数
  243. opt = parser.parse_args()
  244. opt.cfg = check_file(opt.cfg) # check file
  245. set_logging()
  246. device = select_device(opt.device)
  247. # Create model
  248. model = Model(opt.cfg).to(device)
  249. model.train()
  250. # Profile
  251. # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  252. # y = model(img, profile=True)
  253. # Tensorboard
  254. # from torch.utils.tensorboard import SummaryWriter
  255. # tb_writer = SummaryWriter()
  256. # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
  257. # tb_writer.add_graph(model.model, img) # add model to tensorboard
  258. # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/652339
推荐阅读
相关标签
  

闽ICP备14008679号