赞
踩
一、概述
模型剪枝按照结构划分,主要包括结构化剪枝和非结构化剪枝:
(1)结构化剪枝:剪掉神经元节点之间的不重要的连接。相当于把权重矩阵中的单个权重值设置为0。
(2)非结构化剪枝:把权重矩阵中某个神经元节点去掉,则和神经元相连接的突触也要全部去除。相当于同时去除权重矩阵中的某一行和列。如何判断神经元节点的重要程度呢?可以通过计算神经元对应的行和列的权重值的平方和的根的大小进行排序,把排序在后面一定比例的神经元节点去掉
二、pytorch中模型剪枝:
Pytorch中模型的剪枝方法有三种,局部剪枝、全局剪枝和自定义剪枝。与剪枝有关的接口封装在torch.nn.utils.prune中。接下来开始演示三种剪枝在LeNet网络中的应用效果,我们首先给出LeNet网络结构。
import torch
from torch import nn
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
(1)局部剪枝
在本人理解就是一层一层的单独剪枝,下面代码还附有多参数多网络结构剪枝:
def part_cut(model):
'''
######################################局部剪枝#########################################
剪枝之后会产生一个mask
剪枝api:prune.random_unstructured(layer1, name="weight", amount=0.3)
amount:剪枝的比例
layer1:需要剪的层对象
name:指定剪的权重还是偏执
剪枝固化api:prune.remove(layer1, 'weight')
参数不用过多介绍,功能是剪枝后的模型固化(永久化)
'''
layer1 = model.conv1
print("--------------------------------------剪枝前----------------------------------")
# print(list(layer1.named_parameters()))
# print(list(layer1.named_buffers()))
prune.random_unstructured(layer1, name="weight", amount=0.3)
print("--------------------------------------剪枝后-----------------------------------")
# print(list(layer1.named_parameters()))
# print(list(layer1.named_buffers()))
prune.remove(layer1, 'weight')
print("-------------------------------------模型固化后---------------------------------")
# print(list(layer1.named_parameters()))
# print(list(layer1.named_buffers()))
'''-------------------------------------多参数多网络结构剪枝---------------------------------'''
for name, module in model.named_modules():
print(name,module)
# prune 20% of connections in all 2D-conv layers
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
prune.remove(module, 'weight')
# prune 40% of connections in all linear layers
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
prune.remove(module, 'weight')
print(dict(model.named_buffers()).keys()) # to verify that all masks exist
return 0
(2)全局剪枝:
剪枝所占比例是按照所有参数来算的,不是按照每层的数量来算的,剪枝时候也按整体来算。
def glob_cut(model):
'''
全局剪枝:
'''
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.6,
)
print(list(model.named_parameters()))
print(list(model.named_buffers()))
(3)自定义剪枝
该方法不说了,饿了,要吃饭了,急的话参考下官方教程,最后有。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。