当前位置:   article > 正文

yolov5剪枝 实战_yolov5模型剪枝的代码

yolov5模型剪枝的代码

(1)步骤

剪枝的一般步骤只是在正常训练的后面加上了稀疏化训练和剪枝的步骤。

(2)稀疏化训练

在这里插入图片描述

主要区别

稀疏化训练的代码和正常训练的代码的差别主要体现在
①反向传播 ②优化器 ③parse_opt代码

接下来从代码执行训练简单分析
(下面代码均为稀疏化训练的代码)

(1)parse_opt代码

加入了这两行!!!

    parser.add_argument('--st', action='store_true',default=True, help='train with L1 sparsity normalization')
    parser.add_argument('--sr', type=float, default=0.0001, help='L1 normal sparse rate')
  • 1
  • 2

sr:平衡因子lamda(就是论文里面的这个红圈里面的东西)
在这里插入图片描述

(2)反向传播

train.py

# Backward
            scaler.scale(loss).backward()     
  • 1
  • 2

train_sparsity.py

            loss.backward()
  • 1

③优化器

     # Optimize
            if ni - last_opt_step >= accumulate:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)
                last_opt_step = ni
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
            # # ============================= sparsity training ========================== #
            srtmp = opt.sr*(1 - 0.9*epoch/epochs)
            if opt.st:
                ignore_bn_list = []
                for k, m in model.named_modules():
                    if isinstance(m, Bottleneck):
                        if m.add:
                            ignore_bn_list.append(k.rsplit(".", 2)[0] + ".cv1.bn")
                            ignore_bn_list.append(k + '.cv1.bn')
                            ignore_bn_list.append(k + '.cv2.bn')
                    if isinstance(m, nn.BatchNorm2d) and (k not in ignore_bn_list):
                        m.weight.grad.data.add_(srtmp * torch.sign(m.weight.data))  # L1
                        m.bias.grad.data.add_(opt.sr*10 * torch.sign(m.bias.data))  # L1
            # # ============================= sparsity training ========================== #

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

1)for k, m in model.named_modules() 此处使用k,m两个变量是因为model.named_modules()在网络中的所有模块上返回一个迭代器,该迭代器不仅包含模块名称,还包含模块本身。因此在这段代码中k应该是模块名称,m应该是模块本身

2)不进行稀疏化的内容
图片来自知乎的博主
在这里插入图片描述
3)调节权重和bias

                        m.weight.grad.data.add_(srtmp * torch.sign(m.weight.data))  # L1
                        m.bias.grad.data.add_(opt.sr*10 * torch.sign(m.bias.data))  # L1
  • 1
  • 2

直接给出稀疏化训练的完整代码:

# YOLOv5 
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/运维做开发/article/detail/961812
推荐阅读
相关标签