当前位置:   article > 正文

yolov7示例 | 如何写一个剪枝代码?_编程 剪枝

编程 剪枝

前言:知识储备

  1. 剪枝的系统介绍

    • 首先, 选择剪枝的颗粒度:规律 or 不规则
    • 然后, 选择在哪里剪枝:权重 or 结构
    • 其次,选择剪枝程度:计算量减少5倍?
  2. yolov7代码中的 train.py, test.py要了解

    因为我们剪枝进行 finetune 的时候需要train()这个函数,prune 的时候需要test()这个函数

1. 剪枝流程

剪枝是循序渐进的过程,有step的一点一点的剪枝,而不是上来就剪掉50%。
剪枝 -> finetune -> 剪枝 -> finetune … 直到 满足你的需求(剪到模型足够小了,计算量足够低了),就可以停下while循环了。

如下图所示:
在这里插入图片描述
在这里插入图片描述

2. 剪枝工具 Torch-Pruning

大家可以看一下 Torch-Pruning 的作者,对工具底层的解释:Torch-Pruning | 轻松实现结构化剪枝算法

Torch-Pruning的ResNet18 简单示例:

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 1. 选择合适的重要性评估指标,这里使用权值大小
imp = tp.importance.MagnitudeImportance(p=2)

# 2. 忽略无需剪枝的层,例如最后的分类层(总不能剪完类别都变少了叭?)
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

# 3. 初始化剪枝器
iterative_steps = 5 # 迭代式剪枝,重复5次Pruning-Finetuning的循环完成剪枝。
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs, # 用于分析依赖的伪输入
    importance=imp, # 重要性评估指标
    iterative_steps=iterative_steps, # 迭代剪枝,设为1则一次性完成剪枝
    ch_sparsity=0.5, # 目标稀疏性,这里我们移除50%的通道 ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers, # 忽略掉最后的分类层
)

# 4. Pruning-Finetuning的循环
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

首先我们选择合适的参数重要性评估(tp.importance),然后创建合适的剪枝器(tp.pruner),设置合适的稀疏度(=剪枝率),最后调用pruner.step()就完成了剪枝。以上代码就是利用Torch-Pruning对任意模型剪枝的基本流程。通过打印剪枝后的模型我们可以看到,在第5次迭代完成后,所有层的通道数都被正确地调整到了一半 (512 => 256)

3. yolov7 剪枝示例 | 伪代码

在示例中,我们选择
1. MagnitudePruner基于权重大小的剪枝,方法是L2 normalization
2. 此方法不需要进行 稀疏化训练(Sparity Learning)
3. 剪枝比率:计算量减少4倍:因为我想要的是计算量减少4倍为标准,所以这个地方和示例写的不一样哦。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 1. 选择合适的重要性评估指标,这里使用权值大小
imp = tp.importance.MagnitudeImportance(p=2)

# 2. 忽略无需剪枝的层,例如最后的分类层(总不能剪完类别都变少了叭?)
 unwrapped_parameters = []
 ignored_layers = []
 ch_sparsity_dict = {}
 customized_pruners = {}

ignored_layers = []
# for cfg/training/yolov7-tiny.yaml
ignored_layers.append(model.model[105])  # DO NOT prune the IDETECT

# 3. 初始化剪枝器
example_inputs = torch.randn((1, 3, imgsz, imgsz)).to(device)
pruner = pruner_entry(
    model,
    example_inputs, # 用于分析依赖的伪输入
    importance=imp, # 重要性评估指标
    iterative_steps=opt.iterative_steps, # 迭代剪枝,设为1则一次性完成剪枝
    ch_sparsity=1.0, # 目标稀疏性
    # ch_sparsity_dict=ch_sparsity_dict,
    # max_ch_sparsity=opt.max_sparsity,
    ignored_layers=ignored_layers, #  # 忽略掉最后的分类层
    # unwrapped_parameters=unwrapped_parameters,
    # customized_pruners=customized_pruners,
    root_module_types=[nn.Conv2d, nn.Linear]
)

# 4. Pruning-Finetuning的循环
def model_prune(opt, model, prune, example_inputs, testloader, imgsz_test):
    model.eval()
    base_model = copy.deepcopy(model)
    with HiddenPrints():
        ori_flops, ori_params = tp.utils.count_ops_and_params(base_model, example_inputs)
    ori_flops = ori_flops * 2.0
    ori_flops_f, ori_params_f = clever_format([ori_flops, ori_params], "%.3f")
    ori_result, _, _ = test.test(opt.data, None, batch_size=opt.batch_size * 2,
                                 imgsz=imgsz_test, plots=False, model=base_model, dataloader=testloader)
    # def test():
    #   return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t
    ori_map50, ori_map = ori_result[2], ori_result[3]
    iter_idx, prune_flops = 0, ori_flops
    speed_up = 1.0
    logger.info('begin pruning...')
    while speed_up < opt.speed_up:
        iter_idx += 1
        prune.step(interactive=False)
        prune_result, _, _ = test.test(opt.data, None, batch_size=opt.batch_size * 2,
                                 imgsz=imgsz_test, plots=False, model=model, dataloader=testloader)
        prune_map50, prune_map = prune_result[2], prune_result[3]
        with HiddenPrints():
            prune_flops, prune_params = tp.utils.count_ops_and_params(model, example_inputs)
        prune_flops = prune_flops * 2.0
        prune_flops_f, prune_params_f = clever_format([prune_flops, prune_params], "%.3f")
        speed_up = ori_flops / prune_flops # ori_model_GFLOPs / prune_model_GFLOPs
        if prune.current_step == prune.iterative_steps:
            break
        
    return model

# 5. finetune(model)
finetune(opt, model, dataloader, testloader, device) # 和 train()差不多,改一改就行了

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71

4. 实验结果

  • 计算量减少4倍哦!
  • 掉的点,之后可以蒸馏找回精度!【后续会更新】
    在这里插入图片描述

5. 动手学习剪枝代码:完成你自己的剪枝代码

  • 把这个示例加到你的yolov7: detect.py 代码里面,就是一个剪枝的步骤 + 剪枝模型的detect
    你可以自己在上面改动,加上finetune,一步步的剪枝
  • 把这个示例加到train.py代码里面,就是剪枝+ finetune
    # Load model
    model = attempt_load(weights, map_location=device)  # load FP32 model
    print(model)
    
    ################################################################################
    # Pruning
    example_inputs = torch.randn(1, 3, 224, 224).to(device)
    imp = tp.importance.MagnitudeImportance(p=2) # L2 norm pruning

    ignored_layers = []
    from models.yolo import Detect
    for m in model.modules():
        if isinstance(m, Detect):
            ignored_layers.append(m)
    print(ignored_layers)

    iterative_steps = 1 # progressive pruning
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs,
        importance=imp,
        iterative_steps=iterative_steps,
        ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
        ignored_layers=ignored_layers,
    )
    base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
    

    pruner.step()
    pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(model)
    print("Before Pruning: MACs=%f G, #Params=%f G"%(base_macs/1e9, base_nparams/1e9))
    print("After Pruning: MACs=%f G, #Params=%f G"%(pruned_macs/1e9, pruned_nparams/1e9))
    ####################################################################################

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

参考

  1. Torch-Pruning | 轻松实现结构化剪枝算法
  2. 剪枝的系统介绍 | 剪枝有哪些方法
  3. Torch-Pruning example
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/690393
推荐阅读
相关标签
  

闽ICP备14008679号