当前位置:   article > 正文

模型剪枝初级方法_halcon中的pruned model

halcon中的pruned model

信息源:https://www.bilibili.com/video/BV147411W7am/?spm_id_from=333.788.recommend_more_video.2&vd_source=3969f30b089463e19db0cc5e8fe4583a

1、剪枝的含义

把不重要的参数去掉,从而加快计算、减小模型体积。(注意:本文涉及的剪枝方式只是把参数置为0,张量形状不变,因此模型文件不会变小,稠密矩阵运算也不会自动变快。)

2、全连接层的剪枝

全连接层的剪枝就是把一部分不重要的weight置为0(由mask矩阵决定哪些位置置0)。理论上这样可以减少计算量,但需要稀疏计算的支持才能真正提速。

计算掩码矩阵的过程:

 接下来要做的:

(1)给每一层增加一个变量,用于存储mask

(2)设计一个函数,用于计算mask

3、卷积层剪枝

 假如有4个卷积核,分别计算每个卷积核的L2范数,范数值最小的那个卷积核对应的mask全部置为0,如上图灰色部分所示。

4、代码部分

GitHub - mepeichun/Efficient-Neural-Network-Bilibili: B站Efficient-Neural-Network学习分享的配套代码

5、全连接层剪枝

(1)剪枝思路

假设剪枝的比例为50%。

找到每一个linear层,取该层权重绝对值的50%分位数作为阈值,接着构造mask:绝对值大于阈值的位置置为1,小于等于阈值的位置置为0(bias等一维参数不参与剪枝)。

最后weight * mask得到新的weight。

(2)剪枝代码

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torchvision import datasets, transforms
  5. import torch.utils.data
  6. import numpy as np
  7. import math
  8. from copy import deepcopy
  def to_var(x, requires_grad=False):
      """Return a detached copy of tensor *x*, placed on the GPU when CUDA is available.

      :param x: source tensor; it is never modified.
      :param requires_grad: value of ``requires_grad`` on the returned copy.
      :return: a new leaf tensor with the same values as *x*.
      """
      source = x.cuda() if torch.cuda.is_available() else x
      copy = source.clone().detach()
      return copy.requires_grad_(requires_grad)
  class MaskedLinear(nn.Linear):
      """nn.Linear whose weight can be multiplied by a fixed 0/1 pruning mask.

      Until ``set_mask`` is called the layer behaves exactly like ``nn.Linear``.
      """

      def __init__(self, in_features, out_features, bias=True):
          super(MaskedLinear, self).__init__(in_features, out_features, bias)
          self.mask_flag = False  # becomes True once set_mask() has been called
          self.mask = None

      def set_mask(self, mask):
          """Store a private copy of *mask* and zero the masked-out weights.

          :param mask: tensor broadcastable to ``self.weight``; entries equal to 0
              prune the corresponding weight. It is copied (later in-place edits
              of the caller's tensor have no effect) and moved to the weight's
              device and dtype.
          """
          # Follow the weight's device rather than assuming the model is on the
          # GPU whenever CUDA is available.
          self.mask = mask.detach().clone().to(device=self.weight.device,
                                               dtype=self.weight.dtype)
          with torch.no_grad():
              self.weight.mul_(self.mask)
          self.mask_flag = True

      def get_mask(self):
          """Print whether a mask is active and return it (None if never set)."""
          print(self.mask_flag)
          return self.mask

      def forward(self, x):
          # Re-apply the mask on every call. Zeroing the weights once in
          # set_mask() is not enough: any later optimizer step (e.g. fine-tuning)
          # would make the pruned weights non-zero again. Multiplying here keeps
          # them out of the output and also gives them a zero gradient.
          weight = self.weight * self.mask if self.mask_flag else self.weight
          return F.linear(x, weight, self.bias)
  class MLP(nn.Module):
      """Three-layer perceptron (784 -> 200 -> 200 -> 10) built from MaskedLinear layers."""

      def __init__(self):
          super().__init__()
          # Registration order fixes the order of model.parameters().
          self.linear1 = MaskedLinear(28 * 28, 200)
          self.relu1 = nn.ReLU(inplace=True)
          self.linear2 = MaskedLinear(200, 200)
          self.relu2 = nn.ReLU(inplace=True)
          self.linear3 = MaskedLinear(200, 10)

      def forward(self, x):
          flat = x.view(x.size(0), -1)  # flatten everything except the batch dim
          hidden = self.relu1(self.linear1(flat))
          hidden = self.relu2(self.linear2(hidden))
          return self.linear3(hidden)

      def set_masks(self, masks):
          """Apply ``masks[0]``, ``masks[1]``, ``masks[2]`` to linear1..linear3."""
          layers = (self.linear1, self.linear2, self.linear3)
          for position, layer in enumerate(layers):
              layer.set_mask(masks[position])
  def train(model, device, train_loader, optimizer, epoch):
      """Run one training epoch with cross-entropy loss, printing a progress bar.

      :param model: network to train; switched to training mode.
      :param device: device every batch is moved to before the forward pass.
      :param train_loader: DataLoader yielding ``(data, target)`` batches; the
          length of its ``dataset`` is shown in the progress line.
      :param optimizer: optimizer whose ``step()`` is called once per batch.
      :param epoch: epoch number, used only for display.
      """
      model.train()
      total = 0
      n_batches = len(train_loader)
      for batch_idx, (data, target) in enumerate(train_loader):
          data, target = data.to(device), target.to(device)
          optimizer.zero_grad()
          output = model(data)
          loss = F.cross_entropy(output, target)
          loss.backward()
          optimizer.step()
          total += len(data)
          # batch_idx is 0-based: count the current batch as finished so the bar
          # reaches 100% on the last batch even for loaders with few batches.
          progress = math.ceil((batch_idx + 1) / n_batches * 50)
          print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
                (epoch, total, len(train_loader.dataset),
                 '-' * progress + '>', progress * 2), end='')
  66. def test(model, device, test_loader):
  67. model.eval()
  68. test_loss = 0
  69. correct = 0
  70. with torch.no_grad():
  71. for data, target in test_loader:
  72. data, target = data.to(device), target.to(device)
  73. output = model(data)
  74. test_loss += F.cross_entropy(output, target, reduction='sum').item()
  75. pred = output.argmax(dim=1, keepdim=True)
  76. correct += pred.eq(target.view_as(pred)).sum().item()
  77. test_loss /= len(test_loader.dataset)
  78. print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
  79. test_loss, correct, len(test_loader.dataset),
  80. 100. * correct / len(test_loader.dataset)))
  81. return test_loss, correct / len(test_loader.dataset)
  def weight_prune(model, pruning_perc):
      """Build magnitude-pruning masks for every parameter of *model* that is not 1-D.

      For each such parameter the ``pruning_perc``-th percentile of its absolute
      values is the threshold for that layer alone. Entries whose magnitude is
      strictly greater than the threshold are kept (1.0); all others, including
      ties with the threshold, are pruned (0.0). 1-D parameters such as biases
      get no mask.

      :param model: module whose ``parameters()`` are inspected, in order.
      :param pruning_perc: percentile in [0, 100] passed to ``np.percentile``.
      :return: list of float masks, one per non-1-D parameter, each on the same
          device as its parameter.
      """
      masks = []
      for param in model.parameters():
          if len(param.data.size()) == 1:
              continue  # 1-D parameters (e.g. biases) are never pruned
          magnitudes = param.data.abs()
          cutoff = np.percentile(magnitudes.cpu().numpy().flatten(), pruning_perc)
          masks.append((magnitudes > cutoff).float())
      return masks
  def main():
      """Train the MNIST MLP, then evaluate a copy pruned by per-layer 60% magnitude masks.

      :return: ``(model, pruned_model)``, the trained network and its pruned deep copy.
      """
      epochs = 2
      batch_size = 64
      torch.manual_seed(0)
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

      def mnist(train_split):
          # Both splits use the same normalisation; the data must already be
          # on disk because download=False.
          return datasets.MNIST('D:/ai_data/mnist_dataset', train=train_split, download=False,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))
                                ]))

      train_loader = torch.utils.data.DataLoader(mnist(True), batch_size=batch_size, shuffle=True)
      test_loader = torch.utils.data.DataLoader(mnist(False), batch_size=1000, shuffle=True)

      model = MLP().to(device)
      optimizer = torch.optim.Adadelta(model.parameters())
      for epoch in range(1, epochs + 1):
          train(model, device, train_loader, optimizer, epoch)
          test(model, device, test_loader)

      print("\n=====Pruning 60%=======\n")
      pruned_model = deepcopy(model)  # keep the unpruned model intact
      pruned_model.set_masks(weight_prune(pruned_model, 60))
      test(pruned_model, device, test_loader)
      return model, pruned_model
  # Train and prune, then persist both state dicts for the plotting section.
  model, pruned_model = main()
  for net, checkpoint in ((model, ".model.pth"), (pruned_model, ".pruned_model.pth")):
      torch.save(net.state_dict(), checkpoint)
  130. from matplotlib import pyplot as plt
  def plot_weights(model):
      """Show one histogram of the non-zero weight values of each weighted layer.

      Zeros are left out so that for a pruned model only the surviving weights
      appear. Subplots are laid out in a single row, one column per layer.

      :param model: module whose submodules are scanned for a tensor ``weight``.
      """
      # Any submodule exposing a tensor ``weight`` counts (Linear, Conv, ...);
      # modules whose weight is None are skipped instead of crashing on ``.data``.
      weighted = [m for m in model.modules()
                  if isinstance(getattr(m, 'weight', None), torch.Tensor)]
      for k, layer in enumerate(weighted, start=1):
          # (rows, cols, index) form works for any number of layers, unlike the
          # 3-digit ``131 + i`` shorthand which is limited to 3 subplots.
          plt.subplot(1, len(weighted), k)
          w_one_dim = layer.weight.data.cpu().numpy().flatten()
          plt.hist(w_one_dim[w_one_dim != 0], bins=50)
      plt.show()
  # Reload both checkpoints into fresh MLPs and plot their weight distributions.
  # The fresh models live on the CPU, so map the saved tensors there too: a
  # checkpoint saved from a CUDA model then also loads on machines without a GPU.
  model = MLP()
  pruned_model = MLP()
  model.load_state_dict(torch.load('.model.pth', map_location='cpu'))
  pruned_model.load_state_dict(torch.load('.pruned_model.pth', map_location='cpu'))
  plot_weights(model)
  plot_weights(pruned_model)

(3)剪枝前后精确度信息

Train epoch 1: 60000/60000, [-------------------------------------------------->]

100%

Test: average loss: 0.1391, accuracy: 9562/10000 (96%)

Train epoch 2: 60000/60000, [-------------------------------------------------->]

100%

Test: average loss: 0.0870, accuracy: 9741/10000 (97%)  

=====Pruning 60%=======

Test: average loss: 0.0977, accuracy: 9719/10000 (97%)

从上面的数据可以看出,剪掉60%的权重后,测试准确率只从97.41%(9741/10000)降到97.19%(9719/10000),几乎没有下降。

(4)剪枝前后模型参数数据分布

剪枝前的分布:

剪枝后的分布:

6、卷积层剪枝

(1)剪枝思路

假设剪枝的比例为50%。

  • 对于每一个卷积层,按输出通道(即每个卷积核)计算其权重平方的均值(相当于缩放后的L2范数的平方);
  • 将该层所有卷积核的得分除以这些得分的L2范数进行归一化,使不同大小的层可以相互比较,然后取非零得分中的最小值及其对应的channel索引;
  • 比较各层的最小值,取全局最小值对应的那个卷积核,把它的mask全部置为0;
  • 计算整个模型所有参数(包括bias和最后的全连接层)中零值的比例,重复以上3步,直到零值比例达到剪枝比例。

每一个layer的weight * mask就得到了新的weight。

(2)剪枝代码

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torchvision import datasets, transforms
  5. import torch.utils.data
  6. import numpy as np
  7. import math
  def to_var(x, requires_grad=False):
      """Return a detached copy of tensor *x* (on CUDA when available) with the given grad flag.

      :param x: source tensor; it is never modified.
      :param requires_grad: value of ``requires_grad`` on the returned copy.
      """
      if not torch.cuda.is_available():
          return x.clone().detach().requires_grad_(requires_grad)
      return x.cuda().clone().detach().requires_grad_(requires_grad)
  class MaskedConv2d(nn.Conv2d):
      """nn.Conv2d whose weight can be multiplied by a fixed 0/1 pruning mask.

      Until ``set_mask`` is called the layer behaves exactly like ``nn.Conv2d``.
      """

      def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
          super(MaskedConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
          self.mask_flag = False  # becomes True once set_mask() has been called
          self.mask = None  # defined up front so get_mask() works before set_mask()

      def set_mask(self, mask):
          """Store a private copy of *mask* and zero the masked-out weights.

          :param mask: tensor broadcastable to ``self.weight``; entries equal to 0
              prune the corresponding weight. It is copied, so later in-place
              edits of the caller's tensor (or the numpy array it was built from
              with ``torch.from_numpy``) have no effect. It is moved to the
              weight's device and dtype.
          """
          self.mask = mask.detach().clone().to(device=self.weight.device,
                                               dtype=self.weight.dtype)
          with torch.no_grad():
              self.weight.mul_(self.mask)
          self.mask_flag = True

      def get_mask(self):
          """Print whether a mask is active and return it (None if never set)."""
          print(self.mask_flag)
          return self.mask

      def forward(self, x):
          # Re-apply the mask on every call. Zeroing the weights once in
          # set_mask() is not enough: fine-tuning after pruning would otherwise
          # update the pruned filters back to non-zero values. Multiplying here
          # also gives the pruned weights a zero gradient.
          weight = self.weight * self.mask if self.mask_flag else self.weight
          return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
  class ConvNet(nn.Module):
      """Three MaskedConv2d layers followed by a linear classifier with 10 outputs.

      The classifier expects 7*7*64 features, i.e. 1-channel 28x28 inputs reduced
      by the two 2x2 max-pools (the 3x3 convs use padding=1 and keep the size).
      """

      def __init__(self):
          super().__init__()
          # Registration order fixes the order of model.parameters().
          self.conv1 = MaskedConv2d(1, 32, kernel_size=3, padding=1, stride=1)
          self.relu1 = nn.ReLU(inplace=True)
          self.maxpool1 = nn.MaxPool2d(2)
          self.conv2 = MaskedConv2d(32, 64, kernel_size=3, padding=1, stride=1)
          self.relu2 = nn.ReLU(inplace=True)
          self.maxpool2 = nn.MaxPool2d(2)
          self.conv3 = MaskedConv2d(64, 64, kernel_size=3, padding=1, stride=1)
          self.relu3 = nn.ReLU(inplace=True)
          self.linear1 = nn.Linear(7 * 7 * 64, 10)

      def forward(self, x):
          stage1 = self.maxpool1(self.relu1(self.conv1(x)))
          stage2 = self.maxpool2(self.relu2(self.conv2(stage1)))
          features = self.relu3(self.conv3(stage2))
          return self.linear1(features.view(features.size(0), -1))

      def set_masks(self, masks):
          """Apply numpy masks ``masks[0..2]`` to conv1..conv3 (in that order)."""
          convs = (self.conv1, self.conv2, self.conv3)
          for k, conv in enumerate(convs):
              conv.set_mask(torch.from_numpy(masks[k]))
  def train(model, device, train_loader, optimizer, epoch):
      """Run one training epoch with cross-entropy loss, printing a progress bar.

      :param model: network to train; switched to training mode.
      :param device: device every batch is moved to before the forward pass.
      :param train_loader: DataLoader yielding ``(data, target)`` batches; the
          length of its ``dataset`` is shown in the progress line.
      :param optimizer: optimizer whose ``step()`` is called once per batch.
      :param epoch: epoch number, used only for display.
      """
      model.train()
      total = 0
      n_batches = len(train_loader)
      for batch_idx, (data, target) in enumerate(train_loader):
          data, target = data.to(device), target.to(device)
          optimizer.zero_grad()
          output = model(data)
          loss = F.cross_entropy(output, target)
          loss.backward()
          optimizer.step()
          total += len(data)
          # batch_idx is 0-based: count the current batch as finished so the bar
          # reaches 100% on the last batch even for loaders with few batches.
          progress = math.ceil((batch_idx + 1) / n_batches * 50)
          print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
                (epoch, total, len(train_loader.dataset),
                 '-' * progress + '>', progress * 2), end='')
  69. def test(model, device, test_loader):
  70. model.eval()
  71. test_loss = 0
  72. correct = 0
  73. with torch.no_grad():
  74. for data, target in test_loader:
  75. data, target = data.to(device), target.to(device)
  76. output = model(data)
  77. test_loss += F.cross_entropy(output, target, reduction='sum').item()
  78. pred = output.argmax(dim=1, keepdim=True)
  79. correct += pred.eq(target.view_as(pred)).sum().item()
  80. test_loss /= len(test_loader.dataset)
  81. print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
  82. test_loss, correct, len(test_loader.dataset),
  83. 100. * correct / len(test_loader.dataset)))
  84. return test_loss, correct / len(test_loader.dataset)
  def prune_rate(model, verbose=False):
      """Return the percentage of zero-valued entries among all model parameters.

      Only parameters that are not 1-D (conv and linear weights) contribute to
      the zero count, but every parameter, biases included, counts towards the
      total. The result is therefore the pruned share of the whole model.

      :param model: module whose ``parameters()`` are inspected.
      :param verbose: if True, print per-layer pruning ratios and the final rate.
      :return: pruned percentage.
      """
      total = 0
      zeros = 0
      layer_no = 0
      for param in model.parameters():
          count = param.numel()
          total += count
          if len(param.data.size()) == 1:
              continue  # 1-D parameters (biases) are never pruned
          layer_no += 1
          zeros_here = int((param.data == 0).sum().item())
          zeros += zeros_here
          if verbose:
              kind = 'Conv' if len(param.data.size()) == 4 else 'Linear'
              print("Layer {} | {} layer | {:.2f}% parameters pruned".format(
                  layer_no, kind, 100. * zeros_here / count))
      rate = 100. * zeros / total
      if verbose:
          print("Final pruning rate: {:.2f}%".format(rate))
      return rate
  def arg_nonzero_min(a):
      """Return ``(min_value, index)`` of the smallest non-zero element of *a*.

      Ties resolve to the first occurrence. If *a* is empty or contains only
      zeros, ``(np.inf, np.inf)`` is returned, so callers can always unpack two
      values; a warning is printed in the all-zero case.

      :param a: sequence of numbers (list, tuple or 1-D numpy array).
      :return: ``(min_value, index)`` tuple.
      """
      min_v, min_ix = None, None
      seen_any = False
      for i, e in enumerate(a):
          seen_any = True
          # the first non-zero element seeds the search; later ones must be smaller
          if e != 0 and (min_ix is None or e < min_v):
              min_v, min_ix = e, i
      if min_ix is None:
          if seen_any:
              print('Warning: all zero')
          return np.inf, np.inf
      return min_v, min_ix
  def prune_one_filter(model, masks):
      """Prune the single least important conv filter of *model*.

      Every 4-D parameter in ``model.parameters()`` is treated as a conv weight
      of shape (out_channels, in_channels, kh, kw). Each filter (output channel)
      is scored by the mean of its squared weights, and the scores of one layer
      are divided by their L2 norm so layers of different sizes are comparable.
      The mask of the filter with the smallest non-zero score over all layers is
      set to 0.

      :param model: module whose conv weights are scored.
      :param masks: list with one numpy float32 mask per conv weight (in
          ``model.parameters()`` order), or an empty list / None to start from
          all-ones masks. A non-empty list is modified in place.
      :return: the updated list of masks.
      :raises ValueError: if the model has no conv weights, or every filter
          already has a zero score (nothing left to prune).
      """
      build_masks = not masks
      if build_masks:
          masks = []
      candidates = []  # one (score, filter index) pair per conv layer
      for p in model.parameters():
          if len(p.data.size()) != 4:
              continue
          p_np = p.data.cpu().numpy()
          if build_masks:
              masks.append(np.ones(p_np.shape).astype('float32'))
          # scaled (mean) squared L2 norm of each filter
          scores = np.square(p_np).sum(axis=(1, 2, 3)) / (p_np.shape[1] * p_np.shape[2] * p_np.shape[3])
          norm = np.sqrt(np.square(scores).sum())
          if norm == 0:
              # The whole layer is already pruned. Normalising would compute
              # 0/0 = NaN, and np.argmin returns the first NaN, so this dead
              # layer would be "pruned" forever and filter_prune would not finish.
              candidates.append((np.inf, -1))
              continue
          scores = scores / norm  # normalisation (important)
          # Filters that are already pruned score 0; exclude them.
          live = np.where(scores > 0, scores, np.inf)
          filter_ind = int(np.argmin(live))
          candidates.append((live[filter_ind], filter_ind))
      assert len(masks) == len(candidates), "something wrong here"
      if not candidates:
          raise ValueError("model has no 4-D (conv) weights to prune")
      layer_ind = int(np.argmin([score for score, _ in candidates]))
      best_score, filter_ind = candidates[layer_ind]
      if not np.isfinite(best_score):
          raise ValueError("every conv filter is already pruned")
      masks[layer_ind][filter_ind] = 0.
      return masks
  def filter_prune(model, pruning_perc):
      """Prune conv filters one at a time until at least *pruning_perc* % of parameters are zero.

      Each round zeroes one more filter mask via ``prune_one_filter``, applies all
      masks with ``model.set_masks`` and re-measures the rate with ``prune_rate``.
      The rate reached after each round is printed.

      :param model: model providing ``set_masks`` and conv weights.
      :param pruning_perc: target percentage of zero-valued parameters.
      :return: the final list of masks (empty if the target is already met).
      """
      def step(current_masks):
          # one pruning round: pick a filter, apply masks, measure the new rate
          updated = prune_one_filter(model, current_masks)
          model.set_masks(updated)
          return updated, prune_rate(model, verbose=False)

      masks, achieved = [], 0
      while achieved < pruning_perc:
          masks, achieved = step(masks)
          print('{:.2f} pruned'.format(achieved))
      return masks
  def main():
      """Train ConvNet on MNIST, filter-prune it to 50% zeros, evaluate, then fine-tune once."""
      epochs = 2
      batch_size = 64
      torch.manual_seed(0)
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

      def loader(is_train, size):
          # The data must already be on disk because download=False.
          dataset = datasets.MNIST('D:/ai_data/mnist_dataset', train=is_train, download=False,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                   ]))
          return torch.utils.data.DataLoader(dataset, batch_size=size, shuffle=True)

      train_loader = loader(True, batch_size)
      test_loader = loader(False, 1000)

      model = ConvNet().to(device)
      optimizer = torch.optim.Adadelta(model.parameters())
      for epoch in range(1, epochs + 1):
          train(model, device, train_loader, optimizer, epoch)
          test(model, device, test_loader)

      print('\npruning 50%')
      model.set_masks(filter_prune(model, 50))
      test(model, device, test_loader)

      # Fine-tune for one more pass; the progress line reuses the last epoch
      # number from the loop above.
      print('\nfinetune')
      train(model, device, train_loader, optimizer, epoch)
      test(model, device, test_loader)
  # Script entry: run the train -> 50% filter pruning -> fine-tune pipeline in main().
  main()

 (3)精确度及剪枝比例信息:

  1. Train epoch 1: 60000/60000, [-------------------------------------------------->] 100%
  2. Test: average loss: 0.0505, accuracy: 9833/10000 (98%)
  3. Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%
  4. Test: average loss: 0.0311, accuracy: 9893/10000 (99%)
  5. pruning 50%
  6. 0.66 pruned
  7. 1.32 pruned
  8. 1.65 pruned
  9. 1.98 pruned
  10. 2.31 pruned
  11. 2.64 pruned
  12. 2.98 pruned
  13. 3.64 pruned
  14. 3.97 pruned
  15. 4.63 pruned
  16. 4.64 pruned
  17. 4.65 pruned
  18. 4.98 pruned
  19. 5.31 pruned
  20. 5.32 pruned
  21. 5.65 pruned
  22. 6.31 pruned
  23. 6.97 pruned
  24. 7.30 pruned
  25. 7.63 pruned
  26. 8.30 pruned
  27. 8.31 pruned
  28. 8.97 pruned
  29. 9.30 pruned
  30. 9.96 pruned
  31. 10.29 pruned
  32. 10.95 pruned
  33. 11.61 pruned
  34. 11.94 pruned
  35. 12.60 pruned
  36. 13.27 pruned
  37. 13.93 pruned
  38. 14.26 pruned
  39. 14.92 pruned
  40. 15.25 pruned
  41. 15.26 pruned
  42. 15.59 pruned
  43. 16.25 pruned
  44. 16.91 pruned
  45. 17.57 pruned
  46. 17.90 pruned
  47. 18.23 pruned
  48. 18.90 pruned
  49. 19.56 pruned
  50. 19.89 pruned
  51. 20.55 pruned
  52. 20.88 pruned
  53. 21.54 pruned
  54. 21.87 pruned
  55. 21.88 pruned
  56. 22.54 pruned
  57. 22.87 pruned
  58. 23.53 pruned
  59. 24.20 pruned
  60. 24.21 pruned
  61. 24.87 pruned
  62. 25.20 pruned
  63. 25.86 pruned
  64. 26.19 pruned
  65. 26.20 pruned
  66. 26.86 pruned
  67. 27.19 pruned
  68. 27.52 pruned
  69. 28.18 pruned
  70. 28.51 pruned
  71. 29.18 pruned
  72. 29.51 pruned
  73. 29.52 pruned
  74. 29.85 pruned
  75. 29.86 pruned
  76. 30.52 pruned
  77. 30.85 pruned
  78. 31.51 pruned
  79. 32.17 pruned
  80. 32.83 pruned
  81. 33.16 pruned
  82. 33.82 pruned
  83. 34.16 pruned
  84. 34.82 pruned
  85. 35.15 pruned
  86. 35.48 pruned
  87. 36.14 pruned
  88. 36.47 pruned
  89. 37.13 pruned
  90. 37.79 pruned
  91. 37.80 pruned
  92. 38.13 pruned
  93. 38.79 pruned
  94. 38.80 pruned
  95. 39.13 pruned
  96. 39.15 pruned
  97. 39.81 pruned
  98. 40.14 pruned
  99. 40.47 pruned
  100. 40.48 pruned
  101. 41.14 pruned
  102. 41.47 pruned
  103. 41.80 pruned
  104. 41.81 pruned
  105. 42.47 pruned
  106. 43.13 pruned
  107. 43.46 pruned
  108. 43.79 pruned
  109. 44.46 pruned
  110. 44.79 pruned
  111. 44.80 pruned
  112. 45.46 pruned
  113. 45.79 pruned
  114. 45.80 pruned
  115. 46.46 pruned
  116. 46.79 pruned
  117. 47.12 pruned
  118. 47.78 pruned
  119. 47.79 pruned
  120. 47.80 pruned
  121. 48.13 pruned
  122. 48.79 pruned
  123. 49.13 pruned
  124. 49.79 pruned
  125. 49.80 pruned
  126. 50.46 pruned
  127. Test: average loss: 1.6824, accuracy: 6513/10000 (65%)
  128. finetune
  129. Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%
  130. Test: average loss: 0.0324, accuracy: 9889/10000 (99%)

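可以看到,剪枝完成后直接测试的准确率只有65%,非常低;再微调训练一个epoch后,准确率立刻恢复到接近剪枝前的水平(99%)。需要注意:只有在forward中每次都用mask屏蔽被剪掉的权重时,微调才只会更新非零参数;如果只在set_mask中把权重置0一次,微调时优化器会把被剪掉的权重重新更新为非零值,剪枝效果也就丢失了。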

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号