
Using PyTorch's built-in model pruning tool: torch.nn.utils.prune

torch.nn.utils.prune

torch.nn.utils.prune can prune a model's parameters. The official tutorial is here:

https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

Let's go straight to the code.

First, build the model:

import torch
import torch.nn as nn
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SimpleNet(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(in_features=16 * 16 * 24, out_features=num_classes)

    def forward(self, input):
        output = self.conv1(input)
        output = nn.ReLU()(output)
        output = self.conv2(output)
        output = nn.ReLU()(output)
        output = self.pool(output)
        output = self.conv3(output)
        output = nn.ReLU()(output)
        output = self.conv4(output)
        output = nn.ReLU()(output)
        output = output.view(-1, 16 * 16 * 24)
        output = self.fc(output)
        return output

model = SimpleNet().to(device=device)

Take a look at the model summary:

summary(model, input_size=(3, 512, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 12, 512, 512]             336
            Conv2d-2         [-1, 12, 512, 512]           1,308
         MaxPool2d-3         [-1, 12, 256, 256]               0
            Conv2d-4         [-1, 24, 256, 256]           2,616
            Conv2d-5         [-1, 24, 256, 256]           5,208
            Linear-6                   [-1, 10]          61,450
================================================================
Total params: 70,918
Trainable params: 70,918
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 78.00
Params size (MB): 0.27
Estimated Total Size (MB): 81.27
----------------------------------------------------------------

Print the names of the model's parameters:

print(model.state_dict().keys())

Result:

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'conv4.weight', 'conv4.bias', 'fc.weight', 'fc.bias'])
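
Before pruning, it is also worth looking at one module directly. For conv1, the weight and bias are ordinary parameters and the module has no buffers yet; a quick check using the standard nn.Module API:

# conv1 before pruning: plain parameters, no buffers
print([name for name, _ in model.conv1.named_parameters()])  # ['weight', 'bias']
print([name for name, _ in model.conv1.named_buffers()])     # []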

Next, apply pruning to the model:

import torch.nn.utils.prune as prune

# Prune 20% of the weights globally (by L1 magnitude) across the listed tensors.
# Note that conv3 is deliberately left out of the tuple and stays untouched.
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.conv4, 'weight'),
    (model.fc, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
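
With global pruning, amount=0.2 means that roughly 20% of the weights pooled across the four listed tensors are zeroed according to their L1 magnitude, so individual layers can end up more or less sparse than 20%. A minimal sketch to check the achieved sparsity, reusing the parameters_to_prune tuple defined above:

# Compute the sparsity produced by the pruning masks, per layer and globally.
total_zeros, total_elems = 0, 0
for module, name in parameters_to_prune:
    mask = getattr(module, name + '_mask')   # 0 = pruned, 1 = kept
    zeros = int(torch.sum(mask == 0))
    total_zeros += zeros
    total_elems += mask.nelement()
    print(f"{module.__class__.__name__}: {100.0 * zeros / mask.nelement():.2f}% pruned")
print(f"global sparsity: {100.0 * total_zeros / total_elems:.2f}%")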

After the call finishes, print the state_dict keys again:

print(model.state_dict().keys())

Result:

odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'conv3.weight', 'conv3.bias', 'conv4.bias', 'conv4.weight_orig', 'conv4.weight_mask', 'fc.bias', 'fc.weight_orig', 'fc.weight_mask'])

After pruning, the weight entries for the pruned layers (conv1, conv2, conv4 and fc) have disappeared from the state_dict, and each has been replaced by two tensors: weight_orig and weight_mask (conv3.weight is still there because conv3 was not included in parameters_to_prune).

weight_orig is simply the weight from before pruning, while weight_mask contains only 0s and 1s, where a 0 marks a weight that has been pruned. The module's weight is now recomputed on the fly as weight_orig * weight_mask.
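
In other words, pruning re-registers the original tensor as the parameter weight_orig, stores the mask as the buffer weight_mask, and rebuilds the weight attribute as their element-wise product via a forward pre-hook. A quick check along these lines should confirm it:

# After pruning, conv1.weight is a computed attribute, not a parameter.
print('weight' in dict(model.conv1.named_parameters()))       # False
print('weight_orig' in dict(model.conv1.named_parameters()))  # True
print('weight_mask' in dict(model.conv1.named_buffers()))     # True
print(torch.allclose(model.conv1.weight,
                     model.conv1.weight_orig * model.conv1.weight_mask))  # True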

Print conv1's mask (the output below is clearly 0s and 1s):

print(model.state_dict()['conv1.weight_mask'])

tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],
        [[[0., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],
        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],
        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],
        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],
        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],
        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],
        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],
        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],
        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],
        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],
        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]], device='cuda:0')
Incidentally, the prune module also provides prune.remove(module, 'weight'), which makes the pruning permanent.
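
Note that prune.remove only folds weight_orig * weight_mask back into a plain weight parameter and deletes the _orig/_mask tensors; the pruned positions are still stored as explicit zeros, so this does not shrink the model either. A minimal sketch, reusing the parameters_to_prune tuple from above:

# Make pruning permanent: fold mask * weight_orig back into a plain 'weight' parameter.
for module, name in parameters_to_prune:
    prune.remove(module, name)

print(model.state_dict().keys())
# The *_orig / *_mask entries are gone and conv1.weight, conv2.weight, ... are back,
# but the pruned entries are simply stored as zeros.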

After pruning like this, the result is honestly a bit underwhelming: the pruned weights are merely set to zero, so the model size does not change at all. Running summary(model, input_size=(3, 512, 512)) again confirms it, and the effect feels a bit like dropout:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 12, 512, 512]             336
            Conv2d-2         [-1, 12, 512, 512]           1,308
         MaxPool2d-3         [-1, 12, 256, 256]               0
            Conv2d-4         [-1, 24, 256, 256]           2,616
            Conv2d-5         [-1, 24, 256, 256]           5,208
            Linear-6                   [-1, 10]          61,450
================================================================
Total params: 70,918
Trainable params: 70,918
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 78.00
Params size (MB): 0.27
Estimated Total Size (MB): 81.27
----------------------------------------------------------------

The summary is exactly the same as before pruning. In principle the zeroed weights could save computation, but as people on Zhihu have pointed out, this kind of unstructured sparsity has been found to give essentially no speedup in practice.
