torch.nn.utils.prune can be used to prune a model. The official tutorial is here:
https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
Let's go straight to the code.
First, build the network:
import torch
import torch.nn as nn
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SimpleNet(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(in_features=16 * 16 * 24, out_features=num_classes)

    def forward(self, input):
        output = self.conv1(input)
        output = nn.ReLU()(output)
        output = self.conv2(output)
        output = nn.ReLU()(output)
        output = self.pool(output)
        output = self.conv3(output)
        output = nn.ReLU()(output)
        output = self.conv4(output)
        output = nn.ReLU()(output)
        output = output.view(-1, 16 * 16 * 24)
        output = self.fc(output)
        return output

model = SimpleNet().to(device=device)
Take a look at the model summary:
summary(model, input_size=(3, 512, 512))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 12, 512, 512]             336
            Conv2d-2         [-1, 12, 512, 512]           1,308
         MaxPool2d-3         [-1, 12, 256, 256]               0
            Conv2d-4         [-1, 24, 256, 256]           2,616
            Conv2d-5         [-1, 24, 256, 256]           5,208
            Linear-6                   [-1, 10]          61,450
================================================================
Total params: 70,918
Trainable params: 70,918
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 78.00
Params size (MB): 0.27
Estimated Total Size (MB): 81.27
----------------------------------------------------------------
Print the names of each layer in the model's state dict:
print(model.state_dict().keys())
Result:
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'conv4.weight', 'conv4.bias', 'fc.weight', 'fc.bias'])
Next, apply pruning to it:
import torch.nn.utils.prune as prune

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.conv4, 'weight'),
    (model.fc, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
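Note that with global pruning, amount=0.2 removes 20% of the pooled weights across all four layers, not 20% of each layer individually. A quick sanity-check sketch (assuming the model and parameters_to_prune defined above) to measure the resulting sparsity:

# Rough sparsity check: fraction of zeroed entries per pruned layer and overall.
# After pruning, module.weight is already the masked tensor (weight_orig * weight_mask).
total_zeros, total_elems = 0, 0
for module, name in parameters_to_prune:
    weight = getattr(module, name)
    zeros = int(torch.sum(weight == 0))
    total_zeros += zeros
    total_elems += weight.nelement()
    print(f"{module.__class__.__name__}: {100.0 * zeros / weight.nelement():.2f}% zeros")
print(f"global sparsity: {100.0 * total_zeros / total_elems:.2f}%")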
After it finishes, print the keys again:
print(model.state_dict().keys())
Result:
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'conv3.weight', 'conv3.bias', 'conv4.bias', 'conv4.weight_orig', 'conv4.weight_mask', 'fc.bias', 'fc.weight_orig', 'fc.weight_mask'])
We can see that after pruning, conv*.weight has disappeared from the state dict and two new tensors have taken its place: weight_orig and weight_mask (conv3.weight is untouched because conv3 was not included in parameters_to_prune).
weight_orig is simply the weight from before pruning, while weight_mask contains only 0s and 1s, where 0 marks an entry that has been pruned.
Print the mask for conv1:
print(model.state_dict()['conv1.weight_mask'])
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]], device='cuda:0')
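During the forward pass, the pruned weight is simply the element-wise product of the two tensors: weight_orig stays a parameter, weight_mask is registered as a buffer, and module.weight is recomputed from them by a forward pre-hook. A minimal check on conv1 (assuming the pruned model above):

# The effective weight equals weight_orig * weight_mask, and the mask lives in the buffers.
print(torch.equal(model.conv1.weight,
                  model.conv1.weight_orig * model.conv1.weight_mask))   # True
print(dict(model.conv1.named_buffers()).keys())                         # dict_keys(['weight_mask'])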
To make the pruning permanent, call prune.remove on each pruned module; this removes the re-parametrization, so weight_orig and weight_mask disappear and weight becomes the masked tensor:

for module, name in parameters_to_prune:
    prune.remove(module, name)
After pruning like this, the result is honestly a bit underwhelming: the pruned weights are just set to zero, so the model size does not change at all. Printing the summary again below, it feels a bit like dropout:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 12, 512, 512]             336
            Conv2d-2         [-1, 12, 512, 512]           1,308
         MaxPool2d-3         [-1, 12, 256, 256]               0
            Conv2d-4         [-1, 24, 256, 256]           2,616
            Conv2d-5         [-1, 24, 256, 256]           5,208
            Linear-6                   [-1, 10]          61,450
================================================================
Total params: 70,918
Trainable params: 70,918
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 78.00
Params size (MB): 0.27
Estimated Total Size (MB): 81.27
----------------------------------------------------------------
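If the goal is to actually shrink what gets stored, torch.nn.utils.prune will not do that for you. One possible follow-up (just a sketch, not part of the official pruning workflow) is to keep the zeroed weights in a sparse layout after prune.remove, so that only the surviving entries are stored:

# After prune.remove, conv1.weight is an ordinary dense tensor with zeros baked in.
# Converting it to COO sparse format keeps only the non-zero values and their indices.
dense_w = model.conv1.weight.detach()
sparse_w = dense_w.to_sparse()
print(dense_w.nelement(), sparse_w.values().numel())   # total entries vs. non-zero entries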