当前位置:   article > 正文

参数量降低之剪枝_m.weight.grad.data.add_(reg*torch.sign(m.weight.da

m.weight.grad.data.add_(reg*torch.sign(m.weight.data))

一、论文

https://openaccess.thecvf.com/content_ICCV_2017/papers/Liu_Learning_Efficient_Convolutional_ICCV_2017_paper.pdf

二、参考代码

GitHub - foolwood/pytorch-slimming: Learning Efficient Convolutional Networks through Network Slimming, In ICCV 2017.

三、剪枝流程(训练->剪枝->在训练)

1、网络结构图

1.1方法 

通过BN归一化里面的γ缩放系数 +稀疏化的L1范数,可以理解为 通过γ系数得到特征图比重大小,然后加上L1范数,进行稀疏化,把重要的值放大,不重要的值弄小

(原因:通过卷积层后是线性相关的,分布杂乱,使用BN归一化后把哪些偏离的离谱的分布给弄到均值为0方差为1的标准分布,这样训练的会很快,但是在BN后,感觉把数值分布强制在了非线性函数的线性区域中。于是用到了BN里面的另外两个参数γ和β,把数值进行缩放和偏移放在非线性区域)

1.2 L1的稀疏化

从图中可以看出L1的导数即梯度 是(-1,1)有稀疏性质

L2的导数即梯度类似一个线性函数

三、常用剪枝方法

四、代码实现 

第一种思路方法:

训练的时候使用L1权重约束项给BN归一化的weight具有稀疏性,

 剪枝的时候:获取所有的BN归一化里面的weight列表,使其进行排序,获取保留weight里面的最后一个值作为阈值。 然后对每一个BN归一化的weight进行mask(值为1表示需要剪枝保留的,值为0表示剪枝不需要的)然后把BN归一化需要的层数保留下来,不需要的丢弃,重新定义一个网络进行训练

1、训练模型(记住给BN里面的权重加上L1约束项,使其有稀疏性)

  1. s = 0.0001
  2. def updateBN():
  3. for m in model.modules():
  4. if isinstance(m, nn.BatchNorm2d):
  5. m.weight.grad.data.add_(s*torch.sign(m.weight.data)) # L1 梯度 大于0的为1,小于0的为-1

2、剪枝模块

  1. percent = 0.9 # 全选择前百分之多少有用的特征图
  2. total = 0 # 记录所有的BN层的特征图的个数
  3. for m in model.modules():
  4. if isinstance(m, nn.BatchNorm2d):
  5. total += m.weight.data.shape[0]
  6. bn = torch.zeros(total) # 创建所有BN层的特征图的γ分值的存储空间
  7. index = 0
  8. for m in model.modules():
  9. if isinstance(m, nn.BatchNorm2d):
  10. size = m.weight.data.shape[0]
  11. bn[index:(index+size)] = m.weight.data.abs().clone()
  12. index += size
  13. y, i = torch.sort(bn) # 给bn里面的γ进行排序
  14. percent = 0.9 # 全选择前百分之多少有用的特征图
  15. thre_index = int(total * percent)
  16. thre = y[thre_index] # 找到最后一个γ的值
  17. pruned = 0
  18. cfg = [] # 记录保留的特征图的个数
  19. cfg_mask = []
  20. for k, m in enumerate(model.modules()):
  21. if isinstance(m, nn.BatchNorm2d):
  22. weight_copy = m.weight.data.clone()
  23. mask = weight_copy.abs().gt(thre).float().cuda() # .gt(thre) 指的是实际的值大于
  24. # thre的值,返回list,里面是0或者1
  25. pruned = pruned + mask.shape[0] - torch.sum(mask) # 用于记录剪枝剪了多少层
  26. m.weight.data.mul_(mask)
  27. m.bias.data.mul_(mask)
  28. cfg.append(int(torch.sum(mask))) # 记录保留的特征图的个数
  29. cfg_mask.append(mask.clone()) # 所有的特征图
  30. print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
  31. format(k, mask.shape[0], int(torch.sum(mask))))
  32. elif isinstance(m, nn.MaxPool2d):
  33. cfg.append('M')
  34. pruned_ratio = pruned/total # 剪枝比例,指得是剪了多少
  35. print('Pre-processing Successful!')
  36. print(cfg)
  37. """
  38. 构建新的网络模型后把 一开始大模型的权重值给新模型做一个初始化,方法 根据索引把每层对应位置的权重筛选出来然后赋值给新的模型
  39. """
  40. layer_id_in_cfg = 0 # 为剪枝后的模型复制权重
  41. start_mask = torch.ones(3) # 输入
  42. end_mask = cfg_mask[layer_id_in_cfg] # 输出
  43. for [m0, m1] in zip(model.modules(), newmodel.modules()):
  44. if isinstance(m0, nn.BatchNorm2d):
  45. idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) # 把>阈值的特
  46. # 征图的索引筛选出来
  47. m1.weight.data = m0.weight.data[idx1].clone()
  48. m1.bias.data = m0.bias.data[idx1].clone()
  49. m1.running_mean = m0.running_mean[idx1].clone()
  50. m1.running_var = m0.running_var[idx1].clone()
  51. layer_id_in_cfg += 1
  52. start_mask = end_mask.clone()
  53. if layer_id_in_cfg < len(cfg_mask): # do not change in Final FC
  54. end_mask = cfg_mask[layer_id_in_cfg]
  55. elif isinstance(m0, nn.Conv2d):
  56. idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
  57. idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
  58. print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
  59. w = m0.weight.data[:, idx0, :, :].clone()
  60. w = w[idx1, :, :, :].clone()
  61. m1.weight.data = w.clone()
  62. # m1.bias.data = m0.bias.data[idx1].clone()
  63. elif isinstance(m0, nn.Linear):
  64. idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
  65. m1.weight.data = m0.weight.data[:, idx0].clone()
  66. torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, args.save)

3、把剪枝后的模型进行微调训练

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/190773
推荐阅读
相关标签
  

闽ICP备14008679号