当前位置:   article > 正文

YOLOv9大幅度按比例减小模型计算量!加快训练!_github yolov9

github yolov9

 

一、代码及论文链接:

代码链接:GitHub - WongKinYiu/yolov9: Implementation of paper - YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information

论文链接:https://arxiv.org/abs/2402.13616(代码仓库:https://github.com/WongKinYiu/yolov9/tree/main)

二、 说明

        本文方法并不能直接替代YOLOv9原作者尚未开源的两个小模型,但可以按比例减小模型尺寸。类似YOLOv5、v8等,可以方便测试YOLOv9在数据集上的性能!方法来源于网络。

三、使用步骤

        参照之前的YOLOv9代码,我们运行yolov9-c.yaml得到的计算量是239 GFLOPs。

        我们将以下代码替换掉YOLOv9工程下models包下yolo.py脚本中的代码。

# Module preamble: standard-library imports, repository-root path setup, project imports.
# NOTE(review): names used further down (torch, nn, math, contextlib, Conv, DFL, Proto, colorstr, ...)
# are not imported explicitly here; presumably they arrive through the star imports below - confirm.
import argparse
import os
import platform
import sys
from copy import deepcopy
from pathlib import Path

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLO root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if platform.system() != 'Windows':
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import *
from models.experimental import *
from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
from utils.plots import feature_visualization
from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
                               time_sync)
from utils.tal.anchor_generator import make_anchors, dist2bbox

# thop is optional: when it is missing, FLOPs are reported as 0 by the profiling code.
try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None
  24. class Detect(nn.Module):
  25. # YOLO Detect head for detection models
  26. dynamic = False # force grid reconstruction
  27. export = False # export mode
  28. shape = None
  29. anchors = torch.empty(0) # init
  30. strides = torch.empty(0) # init
  31. def __init__(self, nc=80, ch=(), inplace=True): # detection layer
  32. super().__init__()
  33. self.nc = nc # number of classes
  34. self.nl = len(ch) # number of detection layers
  35. self.reg_max = 16
  36. self.no = nc + self.reg_max * 4 # number of outputs per anchor
  37. self.inplace = inplace # use inplace ops (e.g. slice assignment)
  38. self.stride = torch.zeros(self.nl) # strides computed during build
  39. c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
  40. self.cv2 = nn.ModuleList(
  41. nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
  42. self.cv3 = nn.ModuleList(
  43. nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
  44. self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
  45. def forward(self, x):
  46. shape = x[0].shape # BCHW
  47. for i in range(self.nl):
  48. x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
  49. if self.training:
  50. return x
  51. elif self.dynamic or self.shape != shape:
  52. self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
  53. self.shape = shape
  54. box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
  55. dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
  56. y = torch.cat((dbox, cls.sigmoid()), 1)
  57. return y if self.export else (y, x)
  58. def bias_init(self):
  59. # Initialize Detect() biases, WARNING: requires stride availability
  60. m = self # self.model[-1] # Detect() module
  61. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
  62. # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
  63. for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
  64. a[-1].bias.data[:] = 1.0 # box
  65. b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
  66. class DDetect(nn.Module):
  67. # YOLO Detect head for detection models
  68. dynamic = False # force grid reconstruction
  69. export = False # export mode
  70. shape = None
  71. anchors = torch.empty(0) # init
  72. strides = torch.empty(0) # init
  73. def __init__(self, nc=80, ch=(), inplace=True): # detection layer
  74. super().__init__()
  75. self.nc = nc # number of classes
  76. self.nl = len(ch) # number of detection layers
  77. self.reg_max = 16
  78. self.no = nc + self.reg_max * 4 # number of outputs per anchor
  79. self.inplace = inplace # use inplace ops (e.g. slice assignment)
  80. self.stride = torch.zeros(self.nl) # strides computed during build
  81. c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128)))) # channels
  82. self.cv2 = nn.ModuleList(
  83. nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch)
  84. self.cv3 = nn.ModuleList(
  85. nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
  86. self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
  87. def forward(self, x):
  88. shape = x[0].shape # BCHW
  89. for i in range(self.nl):
  90. x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
  91. if self.training:
  92. return x
  93. elif self.dynamic or self.shape != shape:
  94. self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
  95. self.shape = shape
  96. box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
  97. dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
  98. y = torch.cat((dbox, cls.sigmoid()), 1)
  99. return y if self.export else (y, x)
  100. def bias_init(self):
  101. # Initialize Detect() biases, WARNING: requires stride availability
  102. m = self # self.model[-1] # Detect() module
  103. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
  104. # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
  105. for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
  106. a[-1].bias.data[:] = 1.0 # box
  107. b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
  108. class DualDetect(nn.Module):
  109. # YOLO Detect head for detection models
  110. dynamic = False # force grid reconstruction
  111. export = False # export mode
  112. shape = None
  113. anchors = torch.empty(0) # init
  114. strides = torch.empty(0) # init
  115. def __init__(self, nc=80, ch=(), inplace=True): # detection layer
  116. super().__init__()
  117. self.nc = nc # number of classes
  118. self.nl = len(ch) // 2 # number of detection layers
  119. self.reg_max = 16
  120. self.no = nc + self.reg_max * 4 # number of outputs per anchor
  121. self.inplace = inplace # use inplace ops (e.g. slice assignment)
  122. self.stride = torch.zeros(self.nl) # strides computed during build
  123. c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
  124. c4, c5 = max((ch[self.nl] // 4, self.reg_max * 4, 16)), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
  125. self.cv2 = nn.ModuleList(
  126. nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch[:self.nl])
  127. self.cv3 = nn.ModuleList(
  128. nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
  129. self.cv4 = nn.ModuleList(
  130. nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, 4 * self.reg_max, 1)) for x in ch[self.nl:])
  131. self.cv5 = nn.ModuleList(
  132. nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
  133. self.dfl = DFL(self.reg_max)
  134. self.dfl2 = DFL(self.reg_max)
  135. def forward(self, x):
  136. shape = x[0].shape # BCHW
  137. d1 = []
  138. d2 = []
  139. for i in range(self.nl):
  140. d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
  141. d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
  142. if self.training:
  143. return [d1, d2]
  144. elif self.dynamic or self.shape != shape:
  145. self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
  146. self.shape = shape
  147. box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
  148. dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
  149. box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
  150. dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
  151. y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
  152. return y if self.export else (y, [d1, d2])
  153. def bias_init(self):
  154. # Initialize Detect() biases, WARNING: requires stride availability
  155. m = self # self.model[-1] # Detect() module
  156. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
  157. # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
  158. for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
  159. a[-1].bias.data[:] = 1.0 # box
  160. b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
  161. for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
  162. a[-1].bias.data[:] = 1.0 # box
  163. b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
  164. class DualDDetect(nn.Module):
  165. # YOLO Detect head for detection models
  166. dynamic = False # force grid reconstruction
  167. export = False # export mode
  168. shape = None
  169. anchors = torch.empty(0) # init
  170. strides = torch.empty(0) # init
  171. def __init__(self, nc=80, ch=(), inplace=True): # detection layer
  172. super().__init__()
  173. self.nc = nc # number of classes
  174. self.nl = len(ch) // 2 # number of detection layers
  175. self.reg_max = 16
  176. self.no = nc + self.reg_max * 4 # number of outputs per anchor
  177. self.inplace = inplace # use inplace ops (e.g. slice assignment)
  178. self.stride = torch.zeros(self.nl) # strides computed during build
  179. c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128)))) # channels
  180. c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
  181. self.cv2 = nn.ModuleList(
  182. nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
  183. self.cv3 = nn.ModuleList(
  184. nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
  185. self.cv4 = nn.ModuleList(
  186. nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4), nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:])
  187. self.cv5 = nn.ModuleList(
  188. nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
  189. self.dfl = DFL(self.reg_max)
  190. self.dfl2 = DFL(self.reg_max)
  191. def forward(self, x):
  192. shape = x[0].shape # BCHW
  193. d1 = []
  194. d2 = []
  195. for i in range(self.nl):
  196. d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
  197. d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
  198. if self.training:
  199. return [d1, d2]
  200. elif self.dynamic or self.shape != shape:
  201. self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
  202. self.shape = shape
  203. box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
  204. dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
  205. box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
  206. dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
  207. y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
  208. return y if self.export else (y, [d1, d2])
  209. #y = torch.cat((dbox2, cls2.sigmoid()), 1)
  210. #return y if self.export else (y, d2)
  211. #y1 = torch.cat((dbox, cls.sigmoid()), 1)
  212. #y2 = torch.cat((dbox2, cls2.sigmoid()), 1)
  213. #return [y1, y2] if self.export else [(y1, d1), (y2, d2)]
  214. #return [y1, y2] if self.export else [(y1, y2), (d1, d2)]
  215. def bias_init(self):
  216. # Initialize Detect() biases, WARNING: requires stride availability
  217. m = self # self.model[-1] # Detect() module
  218. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
  219. # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
  220. for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
  221. a[-1].bias.data[:] = 1.0 # box
  222. b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
  223. for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
  224. a[-1].bias.data[:] = 1.0 # box
  225. b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
  226. class TripleDetect(nn.Module):
  227. # YOLO Detect head for detection models
  228. dynamic = False # force grid reconstruction
  229. export = False # export mode
  230. shape = None
  231. anchors = torch.empty(0) # init
  232. strides = torch.empty(0) # init
  233. def __init__(self, nc=80, ch=(), inplace=True): # detection layer
  234. super().__init__()
  235. self.nc = nc # number of classes
  236. self.nl = len(ch) // 3 # number of detection layers
  237. self.reg_max = 16
  238. self.no = nc + self.reg_max * 4 # number of outputs per anchor
  239. self.inplace = inplace # use inplace ops (e.g. slice assignment)
  240. self.stride = torch.zeros(self.nl) # strides computed during build
  241. c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
  242. c4, c5 = max((ch[self.nl] // 4, self.reg_max * 4, 16)), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
  243. c6, c7 = max((ch[self.nl * 2] // 4, self.reg_max * 4, 16)), max((ch[self.nl * 2], min((self.nc * 2, 128)))) # channels
  244. self.cv2 = nn.ModuleList(
  245. nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch[:self.nl])
  246. self.cv3 = nn.ModuleList(
  247. nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
  248. self.cv4 = nn.ModuleList(
  249. nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, 4 * self.reg_max, 1)) for x in ch[self.nl:self.nl*2])
  250. self.cv5 = nn.ModuleList(
  251. nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:self.nl*2])
  252. self.cv6 = nn.ModuleList(
  253. nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3), nn.Conv2d(c6, 4 * self.reg_max, 1)) for x in ch[self.nl*2:self.nl*3])
  254. self.cv7 = nn.ModuleList(
  255. nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nc, 1)) for x in ch[self.nl*2:self.nl*3])
  256. self.dfl = DFL(self.reg_max)
  257. self.dfl2 = DFL(self.reg_max)
  258. self.dfl3 = DFL(self.reg_max)
  259. def forward(self, x):
  260. shape = x[0].shape # BCHW
  261. d1 = []
  262. d2 = []
  263. d3 = []
  264. for i in range(self.nl):
  265. d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
  266. d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
  267. d3.append(torch.cat((self.cv6[i](x[self.nl*2+i]), self.cv7[i](x[self.nl*2+i])), 1))
  268. if self.training:
  269. return [d1, d2, d3]
  270. elif self.dynamic or self.shape != shape:
  271. self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
  272. self.shape = shape
  273. box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
  274. dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
  275. box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
  276. dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
  277. box3, cls3 = torch.cat([di.view(shape[0], self.no, -1) for di in d3], 2).split((self.reg_max * 4, self.nc), 1)
  278. dbox3 = dist2bbox(self.dfl3(box3), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
  279. y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1), torch.cat((dbox3, cls3.sigmoid()), 1)]
  280. return y if self.export else (y, [d1, d2, d3])
  281. def bias_init(self):
  282. # Initialize Detect() biases, WARNING: requires stride availability
  283. m = self # self.model[-1] # Detect() module
  284. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
  285. # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
  286. for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
  287. a[-1].bias.data[:] = 1.0 # box
  288. b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
  289. for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
  290. a[-1].bias.data[:] = 1.0 # box
  291. b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
  292. for a, b, s in zip(m.cv6, m.cv7, m.stride): # from
  293. a[-1].bias.data[:] = 1.0 # box
  294. b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
  295. class TripleDDetect(nn.Module):
  296. # YOLO Detect head for detection models
  297. dynamic = False # force grid reconstruction
  298. export = False # export mode
  299. shape = None
  300. anchors = torch.empty(0) # init
  301. strides = torch.empty(0) # init
  302. def __init__(self, nc=80, ch=(), inplace=True): # detection layer
  303. super().__init__()
  304. self.nc = nc # number of classes
  305. self.nl = len(ch) // 3 # number of detection layers
  306. self.reg_max = 16
  307. self.no = nc + self.reg_max * 4 # number of outputs per anchor
  308. self.inplace = inplace # use inplace ops (e.g. slice assignment)
  309. self.stride = torch.zeros(self.nl) # strides computed during build
  310. c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), \
  311. max((ch[0], min((self.nc * 2, 128)))) # channels
  312. c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), \
  313. max((ch[self.nl], min((self.nc * 2, 128)))) # channels
  314. c6, c7 = make_divisible(max((ch[self.nl * 2] // 4, self.reg_max * 4, 16)), 4), \
  315. max((ch[self.nl * 2], min((self.nc * 2, 128)))) # channels
  316. self.cv2 = nn.ModuleList(
  317. nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4),
  318. nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
  319. self.cv3 = nn.ModuleList(
  320. nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
  321. self.cv4 = nn.ModuleList(
  322. nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4),
  323. nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:self.nl*2])
  324. self.cv5 = nn.ModuleList(
  325. nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:self.nl*2])
  326. self.cv6 = nn.ModuleList(
  327. nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3, g=4),
  328. nn.Conv2d(c6, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl*2:self.nl*3])
  329. self.cv7 = nn.ModuleList(
  330. nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nc, 1)) for x in ch[self.nl*2:self.nl*3])
  331. self.dfl = DFL(self.reg_max)
  332. self.dfl2 = DFL(self.reg_max)
  333. self.dfl3 = DFL(self.reg_max)
  334. def forward(self, x):
  335. shape = x[0].shape # BCHW
  336. d1 = []
  337. d2 = []
  338. d3 = []
  339. for i in range(self.nl):
  340. d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
  341. d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
  342. d3.append(torch.cat((self.cv6[i](x[self.nl*2+i]), self.cv7[i](x[self.nl*2+i])), 1))
  343. if self.training:
  344. return [d1, d2, d3]
  345. elif self.dynamic or self.shape != shape:
  346. self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
  347. self.shape = shape
  348. box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
  349. dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
  350. box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
  351. dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
  352. box3, cls3 = torch.cat([di.view(shape[0], self.no, -1) for di in d3], 2).split((self.reg_max * 4, self.nc), 1)
  353. dbox3 = dist2bbox(self.dfl3(box3), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
  354. #y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1), torch.cat((dbox3, cls3.sigmoid()), 1)]
  355. #return y if self.export else (y, [d1, d2, d3])
  356. y = torch.cat((dbox3, cls3.sigmoid()), 1)
  357. return y if self.export else (y, d3)
  358. def bias_init(self):
  359. # Initialize Detect() biases, WARNING: requires stride availability
  360. m = self # self.model[-1] # Detect() module
  361. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
  362. # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
  363. for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
  364. a[-1].bias.data[:] = 1.0 # box
  365. b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
  366. for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
  367. a[-1].bias.data[:] = 1.0 # box
  368. b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
  369. for a, b, s in zip(m.cv6, m.cv7, m.stride): # from
  370. a[-1].bias.data[:] = 1.0 # box
  371. b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
  372. class Segment(Detect):
  373. # YOLO Segment head for segmentation models
  374. def __init__(self, nc=80, nm=32, npr=256, ch=(), inplace=True):
  375. super().__init__(nc, ch, inplace)
  376. self.nm = nm # number of masks
  377. self.npr = npr # number of protos
  378. self.proto = Proto(ch[0], self.npr, self.nm) # protos
  379. self.detect = Detect.forward
  380. c4 = max(ch[0] // 4, self.nm)
  381. self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
  382. def forward(self, x):
  383. p = self.proto(x[0])
  384. bs = p.shape[0]
  385. mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
  386. x = self.detect(self, x)
  387. if self.training:
  388. return x, mc, p
  389. return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
  390. class Panoptic(Detect):
  391. # YOLO Panoptic head for panoptic segmentation models
  392. def __init__(self, nc=80, sem_nc=93, nm=32, npr=256, ch=(), inplace=True):
  393. super().__init__(nc, ch, inplace)
  394. self.sem_nc = sem_nc
  395. self.nm = nm # number of masks
  396. self.npr = npr # number of protos
  397. self.proto = Proto(ch[0], self.npr, self.nm) # protos
  398. self.uconv = UConv(ch[0], ch[0]//4, self.sem_nc+self.nc)
  399. self.detect = Detect.forward
  400. c4 = max(ch[0] // 4, self.nm)
  401. self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
  402. def forward(self, x):
  403. p = self.proto(x[0])
  404. s = self.uconv(x[0])
  405. bs = p.shape[0]
  406. mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
  407. x = self.detect(self, x)
  408. if self.training:
  409. return x, mc, p, s
  410. return (torch.cat([x, mc], 1), p, s) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p, s))
  411. class BaseModel(nn.Module):
  412. # YOLO base model
  413. def forward(self, x, profile=False, visualize=False):
  414. return self._forward_once(x, profile, visualize) # single-scale inference, train
  415. def _forward_once(self, x, profile=False, visualize=False):
  416. y, dt = [], [] # outputs
  417. for m in self.model:
  418. if m.f != -1: # if not from previous layer
  419. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  420. if profile:
  421. self._profile_one_layer(m, x, dt)
  422. x = m(x) # run
  423. y.append(x if m.i in self.save else None) # save output
  424. if visualize:
  425. feature_visualization(x, m.type, m.i, save_dir=visualize)
  426. return x
  427. def _profile_one_layer(self, m, x, dt):
  428. c = m == self.model[-1] # is final layer, copy input as inplace fix
  429. o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
  430. t = time_sync()
  431. for _ in range(10):
  432. m(x.copy() if c else x)
  433. dt.append((time_sync() - t) * 100)
  434. if m == self.model[0]:
  435. LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
  436. LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
  437. if c:
  438. LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
  439. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  440. LOGGER.info('Fusing layers... ')
  441. for m in self.model.modules():
  442. if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
  443. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  444. delattr(m, 'bn') # remove batchnorm
  445. m.forward = m.forward_fuse # update forward
  446. self.info()
  447. return self
  448. def info(self, verbose=False, img_size=640): # print model information
  449. model_info(self, verbose, img_size)
  450. def _apply(self, fn):
  451. # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
  452. self = super()._apply(fn)
  453. m = self.model[-1] # Detect()
  454. if isinstance(m, (Detect, DualDetect, TripleDetect, DDetect, DualDDetect, TripleDDetect, Segment)):
  455. m.stride = fn(m.stride)
  456. m.anchors = fn(m.anchors)
  457. m.strides = fn(m.strides)
  458. # m.grid = list(map(fn, m.grid))
  459. return self
  460. class DetectionModel(BaseModel):
  461. # YOLO detection model
  462. def __init__(self, cfg='yolo.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
  463. super().__init__()
  464. if isinstance(cfg, dict):
  465. self.yaml = cfg # model dict
  466. else: # is *.yaml
  467. import yaml # for torch hub
  468. self.yaml_file = Path(cfg).name
  469. with open(cfg, encoding='ascii', errors='ignore') as f:
  470. self.yaml = yaml.safe_load(f) # model dict
  471. # Define model
  472. ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
  473. if nc and nc != self.yaml['nc']:
  474. LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
  475. self.yaml['nc'] = nc # override yaml value
  476. if anchors:
  477. LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
  478. self.yaml['anchors'] = round(anchors) # override yaml value
  479. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
  480. self.names = [str(i) for i in range(self.yaml['nc'])] # default names
  481. self.inplace = self.yaml.get('inplace', True)
  482. # Build strides, anchors
  483. m = self.model[-1] # Detect()
  484. if isinstance(m, (Detect, DDetect, Segment)):
  485. s = 256 # 2x min stride
  486. m.inplace = self.inplace
  487. forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment)) else self.forward(x)
  488. m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
  489. # check_anchor_order(m)
  490. # m.anchors /= m.stride.view(-1, 1, 1)
  491. self.stride = m.stride
  492. m.bias_init() # only run once
  493. if isinstance(m, (DualDetect, TripleDetect, DualDDetect, TripleDDetect)):
  494. s = 256 # 2x min stride
  495. m.inplace = self.inplace
  496. #forward = lambda x: self.forward(x)[0][0] if isinstance(m, (DualSegment)) else self.forward(x)[0]
  497. forward = lambda x: self.forward(x)[0]
  498. m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
  499. # check_anchor_order(m)
  500. # m.anchors /= m.stride.view(-1, 1, 1)
  501. self.stride = m.stride
  502. m.bias_init() # only run once
  503. # Init weights, biases
  504. initialize_weights(self)
  505. self.info()
  506. LOGGER.info('')
  507. def forward(self, x, augment=False, profile=False, visualize=False):
  508. if augment:
  509. return self._forward_augment(x) # augmented inference, None
  510. return self._forward_once(x, profile, visualize) # single-scale inference, train
  511. def _forward_augment(self, x):
  512. img_size = x.shape[-2:] # height, width
  513. s = [1, 0.83, 0.67] # scales
  514. f = [None, 3, None] # flips (2-ud, 3-lr)
  515. y = [] # outputs
  516. for si, fi in zip(s, f):
  517. xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
  518. yi = self._forward_once(xi)[0] # forward
  519. # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  520. yi = self._descale_pred(yi, fi, si, img_size)
  521. y.append(yi)
  522. y = self._clip_augmented(y) # clip augmented tails
  523. return torch.cat(y, 1), None # augmented inference, train
  524. def _descale_pred(self, p, flips, scale, img_size):
  525. # de-scale predictions following augmented inference (inverse operation)
  526. if self.inplace:
  527. p[..., :4] /= scale # de-scale
  528. if flips == 2:
  529. p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
  530. elif flips == 3:
  531. p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
  532. else:
  533. x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
  534. if flips == 2:
  535. y = img_size[0] - y # de-flip ud
  536. elif flips == 3:
  537. x = img_size[1] - x # de-flip lr
  538. p = torch.cat((x, y, wh, p[..., 4:]), -1)
  539. return p
  540. def _clip_augmented(self, y):
  541. # Clip YOLO augmented inference tails
  542. nl = self.model[-1].nl # number of detection layers (P3-P5)
  543. g = sum(4 ** x for x in range(nl)) # grid points
  544. e = 1 # exclude layer count
  545. i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
  546. y[0] = y[0][:, :-i] # large
  547. i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
  548. y[-1] = y[-1][:, i:] # small
  549. return y
  550. Model = DetectionModel # retain YOLO 'Model' class for backwards compatibility
  551. class SegmentationModel(DetectionModel):
  552. # YOLO segmentation model
  553. def __init__(self, cfg='yolo-seg.yaml', ch=3, nc=None, anchors=None):
  554. super().__init__(cfg, ch, nc, anchors)
  class ClassificationModel(BaseModel):
      """YOLO classification model.

      Built either from an existing detection model (truncated and capped with a
      Classify head) or from a yaml config; the yaml path is currently a stub.
      """

      def __init__(self, cfg=None, model=None, nc=1000, cutoff=10):
          """Create the classifier.

          Args:
              cfg: yaml config, used only when `model` is None.
              model: detection model to convert; takes precedence over `cfg`.
              nc: number of classes.
              cutoff: layer index at which the detection model is truncated.
          """
          super().__init__()
          if model is None:
              self._from_yaml(cfg)
          else:
              self._from_detection_model(model, nc, cutoff)

      def _from_detection_model(self, model, nc=1000, cutoff=10):
          """Create a YOLO classification model from a YOLO detection model.

          Keeps the first `cutoff` layers of `model.model` and replaces the last
          of them with a Classify head. Note that `model` itself is modified.
          """
          if isinstance(model, DetectMultiBackend):
              model = model.model  # unwrap DetectMultiBackend
          model.model = model.model[:cutoff]  # backbone
          tail = model.model[-1]  # last layer

          # Channel count going into the last layer.
          if hasattr(tail, 'conv'):
              in_ch = tail.conv.in_channels
          else:
              in_ch = tail.cv1.conv.in_channels

          head = Classify(in_ch, nc)
          # Carry over the bookkeeping attributes: index, from, type.
          head.i = tail.i
          head.f = tail.f
          head.type = 'models.common.Classify'
          model.model[-1] = head  # replace

          self.model = model.model
          self.stride = model.stride
          self.save = []
          self.nc = nc

      def _from_yaml(self, cfg):
          """Create a YOLO classification model from a *.yaml file.

          Stub: `cfg` is ignored and `self.model` is simply set to None.
          """
          self.model = None
  def parse_model(d, ch):  # model_dict, input_channels(3)
      """Build the layer stack described by a YOLO model.yaml dictionary.

      Args:
          d: model dictionary. Keys read here: 'anchors', 'nc', 'depth_multiple',
              'width_multiple', 'backbone', 'head' and, optionally, 'activation'.
              Each backbone/head row unpacks as (from, number, module, args).
              The `args` lists inside `d` are modified in place.
          ch: list of channel counts; ch[-1] seeds the running output channel and
              ch[f] is looked up for each layer's input.

      Returns:
          A tuple (nn.Sequential of the built layers, sorted list of the layer
          indices collected in the save list).

      NOTE(review): module names, argument strings and the activation are passed
      through eval(); only load configuration files you trust.
      NOTE(review): `contextlib` is used below but is not among the explicit
      imports visible at the top of this file - presumably it arrives through one
      of the star imports; confirm.
      NOTE(review): the listing this was recovered from had lost its indentation;
      the nesting of the RepNCSPELAN4 width-scaling block under `if c2 != no` is
      a reconstruction - confirm against the intended code.
      """
      LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
      anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
      if act:
          Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
          LOGGER.info(f"{colorstr('activation:')} {act}")  # print
      na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
      no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
      layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
      for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
          m = eval(m) if isinstance(m, str) else m  # eval strings
          for j, a in enumerate(args):
              # Strings that do not name anything in scope raise NameError and are kept as plain strings.
              with contextlib.suppress(NameError):
                  args[j] = eval(a) if isinstance(a, str) else a  # eval strings
          # Depth gain is applied only to rows with n > 1; n_ keeps the value for the log line below.
          n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
          if m in {
              Conv, AConv, ConvTranspose,
              Bottleneck, SPP, SPPF, DWConv, BottleneckCSP, nn.ConvTranspose2d, DWConvTranspose2d, SPPCSPC, ADown,
              RepNCSPELAN4, SPPELAN}:
              c1, c2 = ch[f], args[0]
              if c2 != no:  # if not output
                  # Width gain: scale the output channels and round via make_divisible(..., 8).
                  c2 = make_divisible(c2 * gw, 8)
                  if m in (RepNCSPELAN4, ):
                      # Scale this module's second and third args by the width gain as well.
                      args[1] = make_divisible(args[1] * gw, 8)
                      args[2] = make_divisible(args[2] * gw, 8)
              args = [c1, c2, *args[1:]]
              if m in {BottleneckCSP, SPPCSPC}:
                  args.insert(2, n)  # number of repeats
                  n = 1  # repeats are now an argument, so build the module only once
          elif m is nn.BatchNorm2d:
              args = [ch[f]]
          elif m is Concat:
              c2 = sum(ch[x] for x in f)
          elif m is Shortcut:
              c2 = ch[f[0]]
          elif m is ReOrg:
              c2 = ch[f] * 4
          elif m is CBLinear:
              # args[0] is iterated as a list of channel counts; each one gets the width gain.
              c2 = [make_divisible(i * gw, 8) for i in args[0]]
              c1 = ch[f]
              args = [c1, c2, *args[1:]]
          elif m is CBFuse:
              c2 = ch[f[-1]]
          # TODO: channel, gw, gd
          elif m in {Detect, DualDetect, TripleDetect, DDetect, DualDDetect, TripleDDetect, Segment}:
              # Append the list of input channel counts, one per 'from' index.
              args.append([ch[x] for x in f])
              # if isinstance(args[1], int): # number of anchors
              # args[1] = [list(range(args[1] * 2))] * len(f)
              if m in {Segment}:
                  args[2] = make_divisible(args[2] * gw, 8)
          elif m is Contract:
              c2 = ch[f] * args[0] ** 2
          elif m is Expand:
              c2 = ch[f] // args[0] ** 2
          else:
              c2 = ch[f]
          m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
          t = str(m)[8:-2].replace('__main__.', '')  # module type
          np = sum(x.numel() for x in m_.parameters())  # number params
          m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
          LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}')  # print
          # Every 'from' entry other than -1 is recorded; x % i turns a negative offset into an absolute index.
          save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
          layers.append(m_)
          if i == 0:
              # After the first layer, ch is restarted so that ch[k] is the output channels of layer k.
              ch = []
          ch.append(c2)
      return nn.Sequential(*layers), sorted(save)
  if __name__ == '__main__':
      # Command-line entry point: build the model named by --cfg, then either
      # profile it, smoke-test every yolo*.yaml under models/, or fuse it.
      cli_options = (
          ('--cfg', dict(type=str, default='yolo.yaml', help='model.yaml')),
          ('--batch-size', dict(type=int, default=1, help='total batch size for all GPUs')),
          ('--device', dict(default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')),
          ('--profile', dict(action='store_true', help='profile model speed')),
          ('--line-profile', dict(action='store_true', help='profile model speed layer by layer')),
          ('--test', dict(action='store_true', help='test all yolo*.yaml')),
      )
      parser = argparse.ArgumentParser()
      for flag, spec in cli_options:
          parser.add_argument(flag, **spec)
      opt = parser.parse_args()
      opt.cfg = check_yaml(opt.cfg)  # check YAML
      print_args(vars(opt))
      device = select_device(opt.device)

      # Create model and a random input batch on the selected device
      im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
      model = Model(opt.cfg).to(device)
      model.eval()

      # Options (first matching flag wins)
      if opt.line_profile:
          # profile layer by layer
          model(im, profile=True)
      elif opt.profile:
          # profile forward-backward
          results = profile(input=im, ops=[model], n=3)
      elif opt.test:
          # test all models: report construction failures and keep going
          for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
              try:
                  _ = Model(cfg)
              except Exception as e:
                  print(f'Error in {cfg}: {e}')
      else:
          # report fused model summary
          model.fuse()

        最后,修改模型配置文件(yaml)中的深度系数 depth_multiple 与宽度系数 width_multiple,再运行即可。如果运行报错,可以参考我之前的文章,或在评论区提问。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/312799
推荐阅读
相关标签
  

闽ICP备14008679号