
YOLOv8 in Practice: Model Pruning

        Model pruning is an optimization technique for trained models. It removes unnecessary parameters from a neural network, reducing the model's complexity and computational load and thereby improving its efficiency.

        The pruning workflow has three stages: constrained training, pruning (prune), and fine-tuning (finetune).

        This post records my full YOLOv8 pruning process, mainly following the reference: YOLOv8剪枝全过程

Contents

I. Constrained Training

1. Parameter Settings

2. Sparsity Training

II. Pruning (prune)

III. Fine-tuning (finetune)

1. Code Changes

2. Retraining


I. Constrained Training

1. Parameter Settings

         Set amp=False in ./ultralytics/cfg/default.yaml. (Mixed-precision training scales gradients, which would interfere with the manual gradient modification added in the next step.)

2. Sparsity Training

        Main idea: apply L1 regularization to the BatchNorm (BN) layers, so the scale factors of unimportant channels are pushed toward zero.

        Concretely, add the following to the backward pass in ./ultralytics/engine/trainer.py:

```python
# Backward
self.scaler.scale(self.loss).backward()

# ========== added ==========
# 1 constrained training: L1 penalty on BN weights
#   (make sure `from torch import nn` is available in trainer.py)
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
# ========== added ==========

# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate:
    self.optimizer_step()
    last_opt_step = ni
```
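To see why this update encourages sparsity, here is a minimal, framework-free sketch with hypothetical numbers: the added term l1_lambda * sign(w) shrinks every BN weight by a fixed amount per optimizer step, regardless of its magnitude, so weights that the task loss does not defend drift to zero.

```python
# Minimal sketch of the L1 (sign) pull added to the BN-weight gradients above.
# With no task gradient, each step shrinks w by lr * l1_lambda toward zero.
def l1_step(w, task_grad=0.0, lr=0.1, l1_lambda=0.05):
    sign = 1.0 if w > 0 else (-1.0 if w < 0 else 0.0)
    grad = task_grad + l1_lambda * sign  # same form as m.weight.grad.data.add_(...)
    return w - lr * grad

w = 0.5
for _ in range(60):
    w = l1_step(w)
print(round(w, 3))  # → 0.2 (shrunk by 60 * 0.1 * 0.05 = 0.3)
```

The hyperparameters here (lr, l1_lambda) are illustrative only; in the trainer code above, l1_lambda additionally decays with the epoch.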

        Then start training (/yolov8/train.py):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.yaml')
results = model.train(data='./data/data_nc5/data_nc5.yaml', batch=8, epochs=300, save=True)
```

II. Pruning (prune)

        This step prunes the model ./runs/detect/train2/weights/last.pt produced by the previous stage. Create a new file prune.py under /yolov8/ with the following content:

```python
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect

# Load a model
yolo = YOLO("./runs/detect/train2/weights/last.pt")
model = yolo.model

# Collect the absolute BN weights (gamma) and biases (beta) of every BN layer
ws = []
bs = []
for name, m in model.named_modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        w = m.weight.abs().detach()
        b = m.bias.abs().detach()
        ws.append(w)
        bs.append(b)
        # print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())

# keep ratio: the global threshold keeps roughly the top `factor` fraction of gammas
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)

def prune_conv(conv1: Conv, conv2: Conv):
    """Prune output channels of conv1 and the matching input channels of conv2."""
    gamma = conv1.bn.weight.data.detach()
    beta = conv1.bn.bias.data.detach()
    keep_idxs = []
    local_threshold = threshold
    while len(keep_idxs) < 8:  # keep at least 8 channels per layer
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
        local_threshold = local_threshold * 0.5
    n = len(keep_idxs)
    # n = max(int(len(idxs) * 0.8), p)
    # print(n / len(gamma) * 100)
    # scale = len(idxs) / n
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data = beta[keep_idxs]
    conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.num_features = n
    conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
    conv1.conv.out_channels = n
    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]
    if not isinstance(conv2, list):
        conv2 = [conv2]
    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]

def prune(m1, m2):
    if isinstance(m1, C2f):  # C2f as a top conv
        m1 = m1.cv2
    if not isinstance(m2, list):  # m2 is just one module
        m2 = [m2]
    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1
    prune_conv(m1, m2)

# Prune inside every Bottleneck block
for name, m in model.named_modules():
    if isinstance(m, Bottleneck):
        prune_conv(m.cv1, m.cv2)

# Prune between consecutive backbone stages
seq = model.model
for i in range(3, 9):
    if i in [6, 4, 9]: continue
    prune(seq[i], seq[i + 1])

# Prune the layers feeding the Detect head and the head's own conv stacks
detect: Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
    prune(last_input, [colast, cv2[0], cv3[0]])
    prune(cv2[0], cv2[1])
    prune(cv2[1], cv2[2])
    prune(cv3[0], cv3[1])
    prune(cv3[1], cv3[2])

for name, p in yolo.model.named_parameters():
    p.requires_grad = True

yolo.val()  # validate the pruned model (use yolo.val(workers=0) if needed)
yolo.export(format="onnx")  # export to an ONNX file
# yolo.train(data="./data/data_nc5/data_nc5.yaml", epochs=100)  # fine-tune right after pruning
torch.save(yolo.ckpt, "./runs/detect/train2/weights/prune.pt")
print("done")
```

Here factor = 0.8 is the keep ratio: the smaller factor is, the more channels are pruned away. Pruning too aggressively is generally not recommended.
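The role of factor can be checked on toy numbers (hypothetical gamma values, plain Python in place of torch.sort for brevity):

```python
# Toy illustration of the global threshold: BN gammas are sorted in
# descending order and the value at index int(len * factor) becomes the
# cut-off, so roughly the top `factor` fraction of channels is kept.
gammas = [0.9, 0.05, 0.7, 0.01, 0.8, 0.3, 0.02, 0.6, 0.4, 0.1]
factor = 0.8  # keep ratio

sorted_g = sorted(gammas, reverse=True)
threshold = sorted_g[int(len(sorted_g) * factor)]
kept = [g for g in gammas if g >= threshold]
print(threshold, len(kept))  # → 0.02 9
```

Note that channels exactly at the threshold are kept (the comparison is >=), so the kept fraction can be slightly above factor.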

        Run prune.py to obtain the pruned model prune.pt, saved under ./runs/detect/train2/weights/. The same folder also contains last.onnx; the ONNX file is smaller than before pruning, and its structure (inspectable with an ONNX viewer such as Netron) differs slightly from the pre-pruning ONNX.
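The core mechanics of prune_conv can be illustrated without any framework: removing output channels of one layer requires removing the matching input channels of the layer that consumes them, so shapes stay consistent (toy weights and a hypothetical keep_idxs):

```python
# conv1 weight: 4 output channels x 2 input channels (as nested lists)
w1 = [[1, 1], [2, 2], [3, 3], [4, 4]]
# conv2 weight: 2 output channels x 4 input channels
w2 = [[5, 6, 7, 8], [9, 10, 11, 12]]
keep_idxs = [0, 2, 3]  # channels whose gamma survived the threshold

w1_pruned = [w1[i] for i in keep_idxs]                   # slice out-channels
w2_pruned = [[row[i] for i in keep_idxs] for row in w2]  # slice matching in-channels
print(len(w1_pruned), [len(r) for r in w2_pruned])  # → 3 [3, 3]
```

In the real script the same slicing is done with tensor indexing (weight.data[keep_idxs] and weight.data[:, keep_idxs]), plus the BN running statistics.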

III. Fine-tuning (finetune)

1. Code Changes

        First, comment out the L1-regularization block previously added to ./ultralytics/engine/trainer.py:

```python
# Backward
self.scaler.scale(self.loss).backward()

# # ========== added ==========
# # 1 constrained training
# l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
# for k, m in self.model.named_modules():
#     if isinstance(m, nn.BatchNorm2d):
#         m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
#         m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
# # ========== added ==========

# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate:
    self.optimizer_step()
    last_opt_step = ni
```

        Then, around line 543 of the same file, add the line self.model = weights to setup_model():

```python
def setup_model(self):
    """Load/create/download model for any task."""
    if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return

    model, weights = self.model, None
    ckpt = None
    if str(model).endswith(".pt"):
        weights, ckpt = attempt_load_one_weight(model)
        cfg = weights.yaml
    else:
        cfg = model
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
    # ========== added ==========
    # 2 finetune
    self.model = weights
    # ========== added ==========
    return ckpt
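A contrived, hypothetical mock of why that one line matters: get_model rebuilds the network from the yaml config at the original channel widths, which no longer match the pruned checkpoint, so the trainer must be pointed at the already-loaded weights module instead. All names below are illustrative stand-ins, not the real ultralytics API.

```python
# Hypothetical mock: a cfg-built model has the original widths; the loaded
# checkpoint module keeps the pruned widths.
class MockModule:
    def __init__(self, out_channels):
        self.out_channels = out_channels

def get_model_from_cfg(cfg):
    return MockModule(cfg["width"])  # rebuilt at the original width

cfg = {"width": 64}        # width the yaml config would specify
weights = MockModule(41)   # a pruned layer kept only 41 channels

model = get_model_from_cfg(cfg)  # default setup_model behaviour
model = weights                  # the added `self.model = weights`
print(model.out_channels)  # → 41
```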

2. Retraining

         Using the pruned model prune.pt, start training again (/yolov8/train.py):

```python
from ultralytics import YOLO

model = YOLO('./runs/detect/train5/weights/prune.pt')
results = model.train(data='./data/data_nc5/data_nc5.yaml', batch=8, epochs=100, save=True)
```

Note that model now loads "prune.pt" instead of the original "yolov8n.yaml".

        The retrained model is saved under ./runs/detect/train3/weights/. From there, model inference and deployment can proceed as needed.

