赞
踩
需要先了解 yolov7 代码中的 train.py 和 test.py,
因为剪枝后进行 finetune 时需要用到 train() 函数,而剪枝(prune)过程中需要用 test() 函数来评估模型。
剪枝是一个循序渐进的过程:按 step 一点一点地剪,而不是一上来就剪掉 50%。
剪枝 -> finetune -> 剪枝 -> finetune … 直到满足你的需求(模型剪得足够小、计算量足够低),就可以停止 while 循环了。
如下图所示:
大家可以看一下 Torch-Pruning 作者对工具底层原理的解释:《Torch-Pruning | 轻松实现结构化剪枝算法》
Torch-Pruning的ResNet18 简单示例:
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 1. Importance criterion: weight magnitude (p=2 -> L2 norm of each channel's weights).
imp = tp.importance.MagnitudeImportance(p=2)

# 2. Layers that must not be pruned, e.g. the final classifier
#    (pruning it would change the number of output classes).
ignored_layers = [
    m for m in model.modules()
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000
]  # DO NOT prune the final classifier!

# 3. Build the pruner.
iterative_steps = 5  # iterative pruning: 5 rounds of prune -> finetune
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,  # dummy input used to trace layer dependencies
    importance=imp,  # importance criterion
    iterative_steps=iterative_steps,  # 1 = prune everything in a single step
    ch_sparsity=0.5,  # remove 50% of channels: ResNet18 {64,128,256,512} -> {32,64,128,256}
    ignored_layers=ignored_layers,  # keep the classifier intact
)

# 4. Prune -> finetune loop.
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(
        f"step {i + 1}/{iterative_steps}: "
        f"MACs {base_macs / 1e9:.3f}G -> {macs / 1e9:.3f}G, "
        f"params {base_nparams / 1e6:.3f}M -> {nparams / 1e6:.3f}M"
    )
    # finetune your model here
    # finetune(model)
    # ...
首先我们选择合适的参数重要性评估指标(tp.importance),然后创建合适的剪枝器(tp.pruner),设置合适的稀疏度(即剪枝率),最后调用 pruner.step() 就完成了剪枝。以上代码就是利用 Torch-Pruning 对任意模型剪枝的基本流程。通过打印剪枝后的模型可以看到,在第 5 次迭代完成后,所有被剪枝层的通道数都被正确地调整到了原来的一半(例如 512 => 256)。
在本文的 YOLOv7 剪枝中,我们的选择如下:
1. 使用 MagnitudePruner,即基于权重大小的剪枝,重要性用权重的 L2 范数(L2 norm)来衡量;
2. 此方法不需要进行稀疏化训练(Sparsity Learning);
3. 剪枝比率:以计算量减少 4 倍为标准。因为我想要的目标是计算量减少 4 倍,所以这里的设置和上面的示例不一样哦。
import copy

import torch
import torch_pruning as tp

# Excerpt from a modified YOLOv7 train.py. `model` (the loaded YOLOv7 model),
# `opt`, `device`, `imgsz`, `imgsz_test`, `dataloader`, `testloader`, `logger`,
# `test`, `HiddenPrints`, `clever_format`, `pruner_entry` and `finetune` are
# provided by that script.

# 1. Importance criterion: weight magnitude (p=2 -> L2 norm).
imp = tp.importance.MagnitudeImportance(p=2)

# 2. Layers that must not be pruned.
# Optional pruner settings; only used by the commented-out arguments below.
unwrapped_parameters = []
ch_sparsity_dict = {}
customized_pruners = {}
# DO NOT prune the detection head (IDetect). model.model[-1] is the head layer of
# a YOLOv7 Model for every cfg, so no cfg-specific index (e.g. 105) is hard-coded.
ignored_layers = [model.model[-1]]

# 3. Build the pruner.
example_inputs = torch.randn((1, 3, imgsz, imgsz)).to(device)  # dummy input for dependency tracing
# `pruner_entry` is the pruner class chosen in train.py (MagnitudePruner in this post).
pruner = pruner_entry(
    model,
    example_inputs,  # dummy input used to trace layer dependencies
    importance=imp,  # importance criterion
    iterative_steps=opt.iterative_steps,  # 1 = prune everything in a single step
    ch_sparsity=1.0,  # upper bound of the schedule; model_prune() stops once the speed-up target is met
    # ch_sparsity_dict=ch_sparsity_dict,
    # max_ch_sparsity=opt.max_sparsity,
    ignored_layers=ignored_layers,  # keep the detection head intact
    # unwrapped_parameters=unwrapped_parameters,
    # customized_pruners=customized_pruners,
    root_module_types=[torch.nn.Conv2d, torch.nn.Linear],
)


def _count_flops_params(model, example_inputs):
    """Return (FLOPs, params) of `model`, with FLOPs taken as 2 x MACs."""
    # YOLOv7's test.test() can switch the model to half precision in place (on CUDA),
    # so feed the dummy input in the model's own dtype to avoid a dtype mismatch.
    dtype = next(model.parameters()).dtype
    with HiddenPrints():
        macs, params = tp.utils.count_ops_and_params(model, example_inputs.to(dtype))
    return macs * 2.0, params


def _evaluate(opt, model, testloader, imgsz_test):
    """Run YOLOv7's test.test() on `model` and return (mAP@0.5, mAP@0.5:0.95)."""
    # test.test() returns ((mp, mr, map50, map, *losses), maps, t)
    results, _, _ = test.test(opt.data, None, batch_size=opt.batch_size * 2,
                              imgsz=imgsz_test, plots=False, model=model,
                              dataloader=testloader)
    return results[2], results[3]


# 4. Prune step by step until the compute target is reached.
def model_prune(opt, model, prune, example_inputs, testloader, imgsz_test):
    """Prune `model` in place until its FLOPs speed-up reaches `opt.speed_up`.

    Each iteration calls `prune.step()` once, then re-measures mAP on the test
    set and FLOPs/params. Pruning stops when ori_flops / pruned_flops reaches
    `opt.speed_up`, or when the pruner has used all of its `iterative_steps`
    (a warning is logged if the target was not reached).

    Args:
        opt: train.py options; reads `speed_up`, `data` and `batch_size`.
        model: the YOLOv7 model; pruned in place.
        prune: the torch_pruning pruner built on `model`.
        example_inputs: dummy input tensor used to count FLOPs/params.
        testloader: validation dataloader passed to test.test().
        imgsz_test: test image size.

    Returns:
        The pruned model (the same object as `model`).
    """
    model.eval()
    base_model = copy.deepcopy(model)  # untouched copy for the baseline numbers
    ori_flops, ori_params = _count_flops_params(base_model, example_inputs)
    ori_flops_f, ori_params_f = clever_format([ori_flops, ori_params], "%.3f")
    ori_map50, ori_map = _evaluate(opt, base_model, testloader, imgsz_test)
    logger.info(f'original model: FLOPs={ori_flops_f} params={ori_params_f} '
                f'mAP50={ori_map50:.4f} mAP50-95={ori_map:.4f}')

    iter_idx, speed_up = 0, 1.0
    logger.info('begin pruning...')
    while speed_up < opt.speed_up:
        iter_idx += 1
        prune.step(interactive=False)
        prune_map50, prune_map = _evaluate(opt, model, testloader, imgsz_test)
        prune_flops, prune_params = _count_flops_params(model, example_inputs)
        prune_flops_f, prune_params_f = clever_format([prune_flops, prune_params], "%.3f")
        speed_up = ori_flops / prune_flops  # ori_model_GFLOPs / prune_model_GFLOPs
        logger.info(f'pruning iter {iter_idx}: FLOPs={prune_flops_f} params={prune_params_f} '
                    f'mAP50={prune_map50:.4f} mAP50-95={prune_map:.4f} speed_up={speed_up:.3f}')
        # `>=` rather than `==`: stepping past iterative_steps prunes nothing more.
        if prune.current_step >= prune.iterative_steps:
            break
    if speed_up < opt.speed_up:
        logger.warning(f'pruner reached its last iterative step with speed_up={speed_up:.3f} '
                       f'< target {opt.speed_up}; consider increasing iterative_steps')
    return model


model = model_prune(opt, model, pruner, example_inputs, testloader, imgsz_test)

# 5. Finetune the pruned model.
finetune(opt, model, dataloader, testloader, device)  # similar to train(); adapt it as needed
detect.py
代码里面,就是一个剪枝的步骤 + 用剪枝后的模型做 detect。
train.py
# The code here: pruning + finetune.
# Load model
model = attempt_load(weights, map_location=device)  # load FP32 model
print(model)

################################################################################
# Pruning
import models.yolo as yolo  # YOLOv7 model definitions (head classes)

example_inputs = torch.randn(1, 3, 224, 224).to(device)  # dummy input used to trace layer dependencies
imp = tp.importance.MagnitudePruner if False else tp.importance.MagnitudeImportance(p=2)  # L2 norm pruning

# DO NOT prune the detection head. YOLOv7 training cfgs use IDetect rather than
# Detect, so checking only Detect would leave the head prunable; ignore every
# head class that this models/yolo.py actually defines.
head_types = tuple(
    getattr(yolo, name)
    for name in ('Detect', 'IDetect', 'IAuxDetect', 'IBin', 'IKeypoint')
    if hasattr(yolo, name)
)
ignored_layers = [m for m in model.modules() if isinstance(m, head_types)]
print(ignored_layers)

iterative_steps = 1  # prune everything in a single step
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5,  # remove 50% of the channels of each prunable layer
    ignored_layers=ignored_layers,
)
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(model)
print("Before Pruning: MACs=%f G, #Params=%f M" % (base_macs / 1e9, base_nparams / 1e6))
print("After Pruning: MACs=%f G, #Params=%f M" % (pruned_macs / 1e9, pruned_nparams / 1e6))
################################################################################
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。