当前位置:   article > 正文

基于YOLOv5n/s/m不同参数量级模型开发构建茶叶嫩芽检测识别模型,使用pruning剪枝技术来对模型进行轻量化处理,探索不同剪枝水平下模型性能影响_yolov5 torch-pruning

yolov5 torch-pruning

今天有点时间就想着之前遗留的一个问题正好拿过来做一下看看,主要的目的就是想要对训练好的目标检测模型进行剪枝处理,这里就以茶叶嫩芽检测数据场景为例了,在我前面的博文中已经有过相关的实践介绍了,感兴趣的话可以自行移步阅读即可:

《融合CBAM注意力机制基于YOLOv5开发构建毛尖茶叶嫩芽检测识别系统》

这里就不再赘述了。

本文选取了n/s/m三款不同量级的模型来依次构建训练模型,所有的参数保持同样的设置,之后探索在不同剪枝处理操作下的性能影响。

简单看下数据集情况:

 三款模型的训练指令如下所示:

  1. #yolov5n
  2. python3 train.py --cfg models/yolov5n.yaml --weights weights/yolov5n.pt --name yolov5n --epochs 100 --batch-size 4 --img-size 416
  3. #yolov5s
  4. python3 train.py --cfg models/yolov5s.yaml --weights weights/yolov5s.pt --name yolov5s --epochs 100 --batch-size 4 --img-size 416
  5. #yolov5m
  6. python3 train.py --cfg models/yolov5m.yaml --weights weights/yolov5m.pt --name yolov5m --epochs 100 --batch-size 4 --img-size 416

主要是两点,一是batchsize这里设置的比较小因为同时在跑三款模型,这里设置的都是4;另一方面是imgsize,这里为了加快实验节奏,设置的是416,比较低的分辨率而不是640。

默认都是100次epoch的迭代计算,接下来依次看下实际训练情况:
【yolov5n】

 【yolov5s】

 【yolov5m】

 从最终模型的评估结果上面来看:s系列的模型结果还不如n系列的模型,或者说是二者差异不大,m系列模型的结果要优于其他两款模型。

为了能够整体直观地对三款不同参数量级的模型进行直观地对比分析,这里对其主要指标进行了可视化处理,如下所示:

【Precision曲线】
精确率曲线(Precision-Recall Curve)是一种用于评估二分类模型在不同阈值下的精确率性能的可视化工具。它通过绘制不同阈值下的精确率和召回率之间的关系图来帮助我们了解模型在不同阈值下的表现。
精确率(Precision)是指被正确预测为正例的样本数占所有预测为正例的样本数的比例。召回率(Recall)是指被正确预测为正例的样本数占所有实际为正例的样本数的比例。
绘制精确率曲线的步骤如下:
使用不同的阈值将预测概率转换为二进制类别标签。通常,当预测概率大于阈值时,样本被分类为正例,否则分类为负例。
对于每个阈值,计算相应的精确率和召回率。
将每个阈值下的精确率和召回率绘制在同一个图表上,形成精确率曲线。
根据精确率曲线的形状和变化趋势,可以选择适当的阈值以达到所需的性能要求。
通过观察精确率曲线,我们可以根据需求确定最佳的阈值,以平衡精确率和召回率。较高的精确率意味着较少的误报,而较高的召回率则表示较少的漏报。根据具体的业务需求和成本权衡,可以在曲线上选择合适的操作点或阈值。
精确率曲线通常与召回率曲线(Recall Curve)一起使用,以提供更全面的分类器性能分析,并帮助评估和比较不同模型的性能。


【Recall曲线】
召回率曲线(Recall Curve)是一种用于评估二分类模型在不同阈值下的召回率性能的可视化工具。它通过绘制不同阈值下的召回率和对应的精确率之间的关系图来帮助我们了解模型在不同阈值下的表现。
召回率(Recall)是指被正确预测为正例的样本数占所有实际为正例的样本数的比例。召回率也被称为灵敏度(Sensitivity)或真正例率(True Positive Rate)。
绘制召回率曲线的步骤如下:
使用不同的阈值将预测概率转换为二进制类别标签。通常,当预测概率大于阈值时,样本被分类为正例,否则分类为负例。
对于每个阈值,计算相应的召回率和对应的精确率。
将每个阈值下的召回率和精确率绘制在同一个图表上,形成召回率曲线。
根据召回率曲线的形状和变化趋势,可以选择适当的阈值以达到所需的性能要求。
通过观察召回率曲线,我们可以根据需求确定最佳的阈值,以平衡召回率和精确率。较高的召回率表示较少的漏报,而较高的精确率意味着较少的误报。根据具体的业务需求和成本权衡,可以在曲线上选择合适的操作点或阈值。
召回率曲线通常与精确率曲线(Precision Curve)一起使用,以提供更全面的分类器性能分析,并帮助评估和比较不同模型的性能。


【F1值曲线】
F1值曲线是一种用于评估二分类模型在不同阈值下的性能的可视化工具。它通过绘制不同阈值下的精确率(Precision)、召回率(Recall)和F1分数的关系图来帮助我们理解模型的整体性能。
F1分数是精确率和召回率的调和平均值,它综合考虑了两者的性能指标。F1值曲线可以帮助我们确定在不同精确率和召回率之间找到一个平衡点,以选择最佳的阈值。
绘制F1值曲线的步骤如下:
使用不同的阈值将预测概率转换为二进制类别标签。通常,当预测概率大于阈值时,样本被分类为正例,否则分类为负例。
对于每个阈值,计算相应的精确率、召回率和F1分数。
将每个阈值下的精确率、召回率和F1分数绘制在同一个图表上,形成F1值曲线。
根据F1值曲线的形状和变化趋势,可以选择适当的阈值以达到所需的性能要求。
F1值曲线通常与接收者操作特征曲线(ROC曲线)一起使用,以帮助评估和比较不同模型的性能。它们提供了更全面的分类器性能分析,可以根据具体应用场景来选择合适的模型和阈值设置。

 【loss曲线】

 整体来看m系列模型无意识这三款不同参数量级模型中表现最好的,n和s系列模型的表现相近。

接下来要对三款模型进行剪枝处理,这里使用到一个很好用的第三方模块torch_pruning,官方项目地址在这里,如下所示:

 安装方式很简单如下所示:

  1. pip install torch-pruning
  2. 或者
  3. git clone https://github.com/VainF/Torch-Pruning.git

在结构修剪中,“组”被定义为可以在深度网络中移除的最小单元。这些组由多个相互依赖的层组成,需要一起修剪,以保持生成结构的完整性。然而,深度网络的层之间往往存在复杂的依赖关系,这使得结构修剪成为一项具有挑战性的任务。这项工作通过引入一种名为“DepGraph”的自动化机制来解决这一挑战。DepGraph允许无缝的参数分组,并有助于在各种类型的深度网络中进行修剪。

官方提供很多可用的实例,如下所示:

【Naive pruning】

为了演示依赖性的含义,让我们在ResNet-18上尝试结构化修剪。以下代码片段尝试从第一个model.conv1中删除由0和1索引的通道:

  1. from torchvision.models import resnet18
  2. import torch_pruning as tp
  3. model = resnet18(pretrained=True).eval()
  4. tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) # remove channel 0 and channel 1
  5. output = model(torch.randn(1,3,224,224)) # test
  6. 输出
  7. ResNet(
  8. (conv1): Conv2d(3, 62, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  9. (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  10. (relu): ReLU(inplace=True)
  11. (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  12. (layer1): Sequential(
  13. (0): BasicBlock(
  14. (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  15. (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  16. (relu): ReLU(inplace=True)
  17. (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  18. (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  19. )
  20. ...

【An improved version】

事实上,上述情况下的依赖关系比我们已经观察到的要复杂得多。让我们改进我们的代码,看看如果处理BN和Conv会发生什么。

  1. from torchvision.models import resnet18
  2. import torch_pruning as tp
  3. model = resnet18(pretrained=True).eval()
  4. tp.prune_conv_out_channels(model.conv1, idxs=[0,1])
  5. tp.prune_batchnorm_out_channels(model.bn1, idxs=[0,1])
  6. tp.prune_batchnorm_in_channels(model.layer1[0].conv1, idxs=[0,1])
  7. output = model(torch.randn(1,3,224,224))

【A Minimal Example】

  1. import torch
  2. from torchvision.models import resnet18
  3. import torch_pruning as tp
  4. model = resnet18(pretrained=True).eval()
  5. # 1. build dependency graph for resnet18
  6. DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))
  7. # 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
  8. group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )
  9. # 3. prune all grouped layers that are coupled with model.conv1 (included).
  10. if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
  11. group.prune()
  12. # 4. Save & Load
  13. model.zero_grad() # We don't want to store gradient information
  14. torch.save(model, 'model.pth') # without .state_dict
  15. model = torch.load('model.pth') # load the model object
  16. 上面的示例演示了使用DepGraph的基本修剪管道。目标层resnet.conv1与多个层耦合,这需要在结构修剪中同时移除。让我们打印该组,并观察修剪操作是如何“触发”其他修剪操作的。在以下输出中,A=>B表示修剪操作A触发修剪操作B。group[0]表示DG.get_pruning_group中的修剪根。
  17. --------------------------------
  18. Pruning Group
  19. --------------------------------
  20. [0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
  21. [1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
  22. [2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
  23. [3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
  24. [4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
  25. [5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
  26. [6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
  27. [7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
  28. [8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
  29. [9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
  30. [10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
  31. [11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
  32. [12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
  33. [13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
  34. [14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
  35. [15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
  36. --------------------------------

【High-level Pruners】

利用DependencyGraph,我们在这个存储库中开发了几个高级修剪器,以方便轻松修剪。通过指定所需的通道稀疏性,可以修剪整个模型,并使用自己的训练代码对其进行微调。有关这个过程的详细信息,请参阅本教程,它展示了如何从头开始实现瘦身修剪器。此外,您可以在benchmarks/main.py中找到更实用的示例。

  1. import torch
  2. from torchvision.models import resnet18
  3. import torch_pruning as tp
  4. model = resnet18(pretrained=True)
  5. # Importance criteria
  6. example_inputs = torch.randn(1, 3, 224, 224)
  7. imp = tp.importance.TaylorImportance()
  8. ignored_layers = []
  9. for m in model.modules():
  10. if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
  11. ignored_layers.append(m) # DO NOT prune the final classifier!
  12. iterative_steps = 5 # progressive pruning
  13. pruner = tp.pruner.MagnitudePruner(
  14. model,
  15. example_inputs,
  16. importance=imp,
  17. iterative_steps=iterative_steps,
  18. ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
  19. ignored_layers=ignored_layers,
  20. )
  21. base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
  22. for i in range(iterative_steps):
  23. if isinstance(imp, tp.importance.TaylorImportance):
  24. # Taylor expansion requires gradients for importance estimation
  25. loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
  26. loss.backward() # before pruner.step()
  27. pruner.step()
  28. macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
  29. # finetune your model here
  30. # finetune(model)
  31. # ...

还有很多其他的功能实例,这里就不再一一赘述了,可以参考使用即可,这里我借鉴官方的实例来完成对yolov5n/s/m三款不同参数量级模型的剪枝处理。剪枝完成后结果如下所示:

 接下来我想直接使用剪枝后的模型来进行评估测试,不出意外的话结果应该会很差的,先来简单看下吧。

【yolov5n_layer_pruning】

 【yolov5s_layer_pruning】

 【yolov5m_layer_pruning】

 果然是惨不忍睹,直接使用剪枝后的模型文件是不行的,这样破坏了原始完整的模型结构,导致原有学习后的知识已经无效了。

接下来就需要基于剪枝后的结构来进行微调训练。这里我同样保持了与最初模型训练一样的参数设置,如下所示:

  1. #yolov5n
  2. python3 train.py --weights yolov5n_layer_pruning.pt --pt --name yolov5n_pruning --epochs 100 --batch-size 4 --img-size 416
  3. #yolov5s
  4. python3 train.py --weights yolov5s_layer_pruning.pt --pt --name yolov5s_pruning --epochs 100 --batch-size 4 --img-size 416
  5. #yolov5m
  6. python3 train.py --weights yolov5m_layer_pruning.pt --pt --name yolov5m_pruning --epochs 100 --batch-size 4 --img-size 416

这里其实也可以不用训练100次epoch,只不过我想默认保持一样的参数设置,等待一段时间后来看下结果记录。

【yolov5n_pruning】

 【yolov5s_pruning】

 【yolov5m_pruning】

 这里从评估结果上来看:n<s<m。接下来我们同样对其进行对比可视化分析展示。

【F1值】

 【精确率】

 【召回率】

 【loss】

 上述的三组剪枝实验结果是建立在剪枝30%的基础上,产生的结果,可以看到:甚至剪枝后的效果还要优于原始的模型,这也说明了原始的模型中存在相当量的参数冗余。

接下来我们想要进一步探索不同程度剪枝水平对于模型性能的影响程度,前车之鉴,这里写CSDN博文都不敢一篇文章写太多内容,不然突然页面崩溃就会好心酸。。。。。。

我把这部分的内容放在下一篇博文中,如下所示:
《基于YOLOv5n/s/m不同参数量级模型开发构建茶叶嫩芽检测识别模型,预计pruning剪枝技术来对模型进行轻量化处理,探索不同剪枝水平下模型性能影响【续】》

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/147460
推荐阅读
相关标签
  

闽ICP备14008679号