
Reducing Model Parameters: Model Pruning (Pruning Deep Neural Networks)

Principles of Model Pruning

Introduction

Model pruning ranks the neurons in a network by how much they contribute, then removes the lowest-ranked neurons, yielding a smaller, faster model.

Schematic of the basic idea (figure not reproduced here).

Model pruning ranks neurons by the L1/L2 norm of their weights. Accuracy drops after pruning, so the network is usually recovered by iterating train-prune-train-prune. If too much is pruned at once, the network may be damaged beyond recovery. In practice this is therefore an iterative process, commonly called iterative pruning: prune, retrain, repeat.
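As a rough sketch of this prune/retrain loop, using PyTorch's built-in torch.nn.utils.prune for unstructured magnitude pruning (not the channel-level method used in the rest of this post; train_one_round is a placeholder for your own fine-tuning step):

import torch.nn as nn
import torch.nn.utils.prune as prune

def iterative_prune(model, train_one_round, rounds=3, amount=0.2):
    for _ in range(rounds):  # prune a little, retrain, repeat
        for module in model.modules():
            if isinstance(module, nn.Conv2d):
                # zero the `amount` fraction of weights with the smallest |w| (L1 magnitude)
                prune.l1_unstructured(module, name='weight', amount=amount)
        train_one_round(model)  # placeholder: fine-tune to recover accuracy
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            prune.remove(module, 'weight')  # bake the accumulated masks into the weights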

Applying L1 regularization during training drives parameters toward exact zeros, i.e. it sparsifies them.

L1: sparsity and feature selection; L2: smooth shrinkage of features.
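A minimal illustration of the difference, as generic penalties added to an arbitrary task loss (lambda1 and lambda2 are hypothetical hyperparameters):

l1 = sum(p.abs().sum() for p in model.parameters())   # L1: drives weights to exactly 0 (sparsity, feature selection)
l2 = sum(p.pow(2).sum() for p in model.parameters())  # L2: shrinks weights smoothly, rarely to exactly 0
loss = task_loss + lambda1 * l1 + lambda2 * l2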

Code Implementation

Pre-training:

The original network must be built from Conv2d + BatchNorm2d + ReLU treated as a unit.
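For reference, a block of the expected form (the channel counts here are arbitrary):

import torch.nn as nn

block = nn.Sequential(
    nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False),  # bias is redundant before BN
    nn.BatchNorm2d(32),  # its gamma/beta are the handles used for pruning
    nn.ReLU(inplace=True),
)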

During training, add an L1 penalty on the BatchNorm layers (sparse training) so that each feature map's scaling factor gamma (γ) reflects its importance: the smaller γ is, the less important the corresponding feature map. The L1 penalty on γ is what gives it this feature-selection role.
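In effect, sparse training minimizes the network-slimming objective

    L = Σ_(x,y) l(f(x, W), y) + s · Σ_γ |γ|

where l is the ordinary task loss over the training data, the second sum runs over all BN scaling factors, and s (called sr in the code below) controls how strongly the γ values are pushed toward zero.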

import torch
import torch.nn as nn

def updateBN(model, s):
    # Add the L1 subgradient s * sign(gamma) to each BN weight's gradient.
    # torch.sign: +1 for positive, -1 for negative, 0 stays 0.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))

# ... called inside the training step ...
loss.backward()
# sparsity penalty on the BN scaling factors
sr = 0.0001
if sr:
    updateBN(self.model, sr)
self.optimizer.step()
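Equivalently, the same penalty can be expressed through the loss so that autograd produces the subgradient itself; a minimal sketch using the same sr:

def bn_l1(model):
    # sum of |gamma| over all BN layers
    return sum(m.weight.abs().sum()
               for m in model.modules() if isinstance(m, nn.BatchNorm2d))

loss = criterion(output, target) + sr * bn_l1(model)
loss.backward()
optimizer.step()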

Pruning:

Load the pre-trained model, prune it, and save the pruned model.

Specify --percent (the pruning ratio), --model (the path to the pre-trained model), and --save (the path for the pruned model).
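For example, assuming the script below is saved as prune.py (script and checkpoint names here are illustrative):

python prune.py --model model_best.pth --save pruned.pth --percent 0.5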

import os
import argparse
import torch
import torch.nn as nn
# from vgg import vgg
from model.model import ASPNET
import numpy as np

# Prune settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
parser.add_argument('--dataset', type=str, default='cifar10',
                    help='training dataset (default: cifar10)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--percent', type=float, default=0.5,
                    help='scale sparse rate (default: 0.5)')
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to raw trained model (default: none)')
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save pruned model (default: none)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# model = vgg()
model = ASPNET()
if args.cuda:
    model.cuda()
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['monitor_best']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.model))
print(model)

# Total number of BN channels: every feature map has its own gamma and beta
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

# Collect |gamma| from every BN layer
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index + size)] = m.weight.data.abs().clone()
        index += size

# Global threshold: the |gamma| value at the --percent quantile
y, i = torch.sort(bn)
thre_index = int(total * args.percent)
thre = y[thre_index]

pruned = 0
cfg = []       # channels kept per layer; 'M' marks a max-pool layer
cfg_mask = []  # per-layer binary channel masks
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.clone()
        mask = weight_copy.abs().gt(thre).float().cuda()  # .gt(): keep channels with |gamma| > threshold
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)  # zero out pruned gammas
        m.bias.data.mul_(mask)    # and the matching betas
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'
              .format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')
pruned_ratio = pruned / total
print('Pre-processing Successful!')

# Build the pruned model
print(cfg)
# newmodel = vgg(cfg=cfg)
newmodel = ASPNET(net_name=cfg)  # pruned architecture
newmodel.cuda()

# Copy the surviving weights into the pruned model
layer_id_in_cfg = 0
start_mask = torch.ones(1)            # input-channel mask
end_mask = cfg_mask[layer_id_in_cfg]  # output-channel mask
for m0, m1 in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))  # indices of kept channels
        m1.weight.data = m0.weight.data[idx1].clone()
        m1.bias.data = m0.bias.data[idx1].clone()
        m1.running_mean = m0.running_mean[idx1].clone()
        m1.running_var = m0.running_var[idx1].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()  # becomes the next layer's input mask
        if layer_id_in_cfg < len(cfg_mask):  # do not change in the final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        # print('In shape: {:d} Out shape: {:d}'.format(idx0.shape[0], idx1.shape[0]))
        w = m0.weight.data[:, idx0, :, :].clone()  # select surviving input channels
        w = w[idx1, :, :, :].clone()               # then surviving output channels
        m1.weight.data = w.clone()                 # assign to the pruned model
        # m1.bias.data = m0.bias.data[idx1].clone()
    elif isinstance(m0, nn.Linear):
        # idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        # m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.weight.data = m0.weight.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, args.save)
print(newmodel)
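A quick sanity check that the pruned model still runs end to end (assuming 3x32x32 CIFAR-sized inputs, per the --dataset default):

with torch.no_grad():
    x = torch.randn(1, 3, 32, 32).cuda()
    print(newmodel(x).shape)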

Parameter count before pruning: 27,102; after pruning: 7,185.
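Counts like these can be reproduced with a small helper once both models exist (a sketch):

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(model), n_params(newmodel))  # e.g. 27102 -> 7185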

Retraining with the pruned model:

Use the pruned network architecture and initialize it with the pruned model's parameters.

refine = 'pruned.pth'  # path to the pruned checkpoint saved above (name is illustrative)
if refine:
    checkpoint = torch.load(refine)
    print(checkpoint['cfg'])
    model = ASPNET(net_name=checkpoint['cfg'])  # rebuild the pruned architecture
    model.cuda()
    model.load_state_dict(checkpoint['state_dict'])
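From here, fine-tuning is an ordinary training loop; a minimal sketch (train_loader and num_epochs are placeholders):

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
    for data, target in train_loader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()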
