赞
踩
剪枝的一般步骤只是在正常训练的后面加上了稀疏化训练和剪枝的步骤。
稀疏化训练的代码和正常训练的代码的差别主要体现在
①反向传播 ②优化器 ③parse_opt代码
接下来从代码执行训练简单分析
(下面代码均为稀疏化训练的代码)
加入了这两行!!!
parser.add_argument('--st', action='store_true',default=True, help='train with L1 sparsity normalization')
parser.add_argument('--sr', type=float, default=0.0001, help='L1 normal sparse rate')
sr:平衡因子lamda(就是论文里面的这个红圈里面的东西)
train.py
# Backward
scaler.scale(loss).backward()
train_sparsity.py
loss.backward()
③优化器
# Optimize
if ni - last_opt_step >= accumulate:
scaler.step(optimizer) # optimizer.step
scaler.update()
optimizer.zero_grad()
if ema:
ema.update(model)
last_opt_step = ni
# # ============================= sparsity training ========================== #
srtmp = opt.sr*(1 - 0.9*epoch/epochs)
if opt.st:
ignore_bn_list = []
for k, m in model.named_modules():
if isinstance(m, Bottleneck):
if m.add:
ignore_bn_list.append(k.rsplit(".", 2)[0] + ".cv1.bn")
ignore_bn_list.append(k + '.cv1.bn')
ignore_bn_list.append(k + '.cv2.bn')
if isinstance(m, nn.BatchNorm2d) and (k not in ignore_bn_list):
m.weight.grad.data.add_(srtmp * torch.sign(m.weight.data)) # L1
m.bias.grad.data.add_(opt.sr*10 * torch.sign(m.bias.data)) # L1
# # ============================= sparsity training ========================== #
1)for k, m in model.named_modules()
此处使用k,m两个变量是因为model.named_modules()在网络中的所有模块上返回一个迭代器,该迭代器不仅包含模块名称,还包含模块本身。因此在这段代码中k应该是模块名称,m应该是模块本身
2)不进行稀疏化的内容
图片来自知乎的博主
3)调节权重和bias
m.weight.grad.data.add_(srtmp * torch.sign(m.weight.data)) # L1
m.bias.grad.data.add_(opt.sr*10 * torch.sign(m.bias.data)) # L1
直接给出稀疏化训练的完整代码:
# YOLOv5 声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/运维做开发/article/detail/961812
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。