当前位置:   article > 正文

模型优化之模型剪枝

模型剪枝

一、概述
模型剪枝按照结构划分,主要包括结构化剪枝和非结构化剪枝:
(1)结构化剪枝:剪掉神经元节点之间的不重要的连接。相当于把权重矩阵中的单个权重值设置为0。
在这里插入图片描述
(2)非结构化剪枝:把权重矩阵中某个神经元节点去掉,则和神经元相连接的突触也要全部去除。相当于同时去除权重矩阵中的某一行和列。如何判断神经元节点的重要程度呢?可以通过计算神经元对应的行和列的权重值的平方和的根的大小进行排序,把排序在后面一定比例的神经元节点去掉
在这里插入图片描述
二、pytorch中模型剪枝:
Pytorch中模型的剪枝方法有三种,局部剪枝、全局剪枝和自定义剪枝。与剪枝有关的接口封装在torch.nn.utils.prune中。接下来开始演示三种剪枝在LeNet网络中的应用效果,我们首先给出LeNet网络结构。

import torch
from torch import nn

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120) 
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

(1)局部剪枝
在本人理解就是一层一层的单独剪枝,下面代码还附有多参数多网络结构剪枝:

def part_cut(model):
    '''
    ######################################局部剪枝#########################################
    剪枝之后会产生一个mask
    剪枝api:prune.random_unstructured(layer1, name="weight", amount=0.3)
            amount:剪枝的比例
            layer1:需要剪的层对象
            name:指定剪的权重还是偏执
    剪枝固化api:prune.remove(layer1, 'weight')
            参数不用过多介绍,功能是剪枝后的模型固化(永久化)
    '''
    layer1 = model.conv1
    print("--------------------------------------剪枝前----------------------------------")
    # print(list(layer1.named_parameters()))
    # print(list(layer1.named_buffers()))
    prune.random_unstructured(layer1, name="weight", amount=0.3)
    print("--------------------------------------剪枝后-----------------------------------")
    # print(list(layer1.named_parameters()))
    # print(list(layer1.named_buffers()))
    prune.remove(layer1, 'weight')
    print("-------------------------------------模型固化后---------------------------------")
    # print(list(layer1.named_parameters()))
    # print(list(layer1.named_buffers()))


    '''-------------------------------------多参数多网络结构剪枝---------------------------------'''
    for name, module in model.named_modules():
        print(name,module)
        # prune 20% of connections in all 2D-conv layers
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=0.2)
            prune.remove(module, 'weight')
        # prune 40% of connections in all linear layers
        elif isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.4)
            prune.remove(module, 'weight')

    print(dict(model.named_buffers()).keys())  # to verify that all masks exist
    return 0
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39

(2)全局剪枝:
剪枝所占比例是按照所有参数来算的,不是按照每层的数量来算的,剪枝时候也按整体来算。

def glob_cut(model):
    '''
    全局剪枝:
    '''
    parameters_to_prune = (
        (model.conv1, 'weight'),
        (model.conv2, 'weight'),
        (model.fc1, 'weight'),
        (model.fc2, 'weight'),
        (model.fc3, 'weight'),
    )

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.6,
    )
    print(list(model.named_parameters()))
    print(list(model.named_buffers()))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

(3)自定义剪枝
该方法不说了,饿了,要吃饭了,急的话参考下官方教程,最后有。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/190748
推荐阅读
相关标签
  

闽ICP备14008679号