Model pruning ranks the neurons in a network by how much they contribute; the lowest-ranked neurons can then be removed, yielding a smaller, faster model.
[Figure: schematic of the basic idea]
Neurons are ranked by the L1/L2 norm of their weights. Accuracy drops after pruning, so the network is usually recovered through trained-pruned-trained-pruned cycles. Pruning too much at once can damage the network beyond recovery, so in practice this is an iterative process, commonly called iterative pruning: prune / retrain / repeat.
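A minimal sketch of that loop, with caller-supplied prune_step and finetune_step callables (both hypothetical placeholders, not part of the code below):

def iterative_pruning(model, prune_step, finetune_step, rounds=5):
    # Prune a small fraction, then fine-tune, and repeat; pruning too
    # aggressively in a single round risks unrecoverable accuracy loss.
    for _ in range(rounds):
        prune_step(model)     # remove a small share of low-ranked neurons
        finetune_step(model)  # retrain so the surviving weights compensate
    return model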
Applying L1 regularization during training drives the parameters toward sparsity.
L1: sparsity and feature selection; L2: smooth features.
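The difference is easy to verify: the L1 gradient has constant magnitude regardless of a weight's size, so small weights get pushed all the way to zero, while the L2 gradient is proportional to the weight and only shrinks it. A quick check (illustrative values only):

import torch

w = torch.tensor([0.01, 0.5, -2.0], requires_grad=True)
s = 0.1

l1 = s * w.abs().sum()
l1.backward()
print(w.grad)  # tensor([ 0.1000,  0.1000, -0.1000]) -- constant push toward 0
w.grad.zero_()

l2 = s * (w ** 2).sum()
l2.backward()
print(w.grad)  # tensor([ 0.0020,  0.1000, -0.4000]) -- proportional to the weight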
Pre-training:
The original network must be built from Conv2d + BatchNorm2d + ReLU blocks treated as a unit.
During training, add an L1 penalty on the BatchNorm layers for sparsity training. Each feature map then has a corresponding scaling factor gamma: the smaller γ is, the less important that feature map. L1 regularization on γ is what gives it this feature-selection effect.
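For reference, a BatchNorm layer normalizes each channel and then scales it by its own γ, and the slimming objective adds an L1 term on these γ (s is the sparsity weight, used as sr in the code below):

\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y = \gamma \hat{x} + \beta

L = \sum_{(x,\,t)} l\big(f(x, W),\, t\big) + s \sum_{\gamma} |\gamma|

The subgradient of s|γ| with respect to γ is s · sign(γ), which is exactly what updateBN below adds to the BN weight gradients.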
def updateBN(model, s):
    # Add the subgradient of the L1 penalty s*|gamma| to every BN weight gradient.
    # torch.sign(): +1 for positive values, -1 for negative, 0 stays 0.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))
Call it inside the training loop, between loss.backward() and optimizer.step():

loss.backward()
# sparsity regularization on the BN scaling factors
sr = 0.0001
if sr:
    updateBN(self.model, sr)
self.optimizer.step()
Pruning:
Load the pre-trained model, prune it, and save the pruned model.
Specify --percent (the pruning ratio), --model (the pre-trained model), and --save (the filename for the pruned model).
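For example (the script name prune.py and the checkpoint paths are placeholders, not fixed by the project):

python prune.py --model checkpoint.pth.tar --save pruned.pth.tar --percent 0.5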
import os
import argparse
import torch
import torch.nn as nn
from torchvision import datasets, transforms

# from vgg import vgg
from model.model import ASPNET
import numpy as np

# Prune settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
parser.add_argument('--dataset', type=str, default='cifar10',
                    help='training dataset (default: cifar10)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--percent', type=float, default=0.5,
                    help='scale sparse rate (default: 0.5)')
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to raw trained model (default: none)')
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save prune model (default: none)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# model = vgg()
model = ASPNET()
if args.cuda:
    model.cuda()
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['monitor_best']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.model))

print(model)
total = 0  # total number of feature maps (BN channels) across all layers
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

bn = torch.zeros(total)  # gather every gamma; each feature map has its own γ and β
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index + size)] = m.weight.data.abs().clone()
        index += size

# sort all |gamma| values and take the global threshold at the requested percentile
y, i = torch.sort(bn)
thre_index = int(total * args.percent)
thre = y[thre_index]

pruned = 0
cfg = []       # surviving channel count per layer ('M' marks a max-pooling layer)
cfg_mask = []  # per-layer binary masks of surviving channels
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.clone()
        mask = weight_copy.abs().gt(thre).float().cuda()  # .gt(): elementwise greater-than
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)  # zero out gamma of the pruned channels
        m.bias.data.mul_(mask)    # zero out beta as well
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'
              .format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned / total

print('Pre-processing Successful!')


# Build the pruned model
print(cfg)
# newmodel = vgg(cfg=cfg)  # pruned model
newmodel = ASPNET(net_name=cfg)  # pruned model
newmodel.cuda()
# Copy the surviving weights into the pruned model
layer_id_in_cfg = 0
start_mask = torch.ones(1)            # input-channel mask of the first conv
end_mask = cfg_mask[layer_id_in_cfg]  # output-channel mask
for m0, m1 in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))  # surviving channel indices
        m1.weight.data = m0.weight.data[idx1].clone()
        m1.bias.data = m0.bias.data[idx1].clone()
        m1.running_mean = m0.running_mean[idx1].clone()
        m1.running_var = m0.running_var[idx1].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()  # becomes the input mask of the next layer
        if layer_id_in_cfg < len(cfg_mask):  # do not change in final FC
            end_mask = cfg_mask[layer_id_in_cfg]  # output-channel mask
    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print(idx0)
        print(idx1)
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        # print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
        w = m0.weight.data[:, idx0, :, :].clone()  # take the originally trained weights
        w = w[idx1, :, :, :].clone()
        m1.weight.data = w.clone()  # assign the selected weights to the pruned model
        # m1.bias.data = m0.bias.data[idx1].clone()
    elif isinstance(m0, nn.Linear):
        # idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        # m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.weight.data = m0.weight.data.clone()


torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, args.save)

print(newmodel)
Parameter count before pruning: 27102; after pruning: 7185.
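Counts like these can be reproduced with a small helper (a sketch; the exact numbers depend on the ASPNET configuration):

def count_params(model):
    # total number of trainable scalars in the model
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_params(model), count_params(newmodel))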
Retraining with the pruned model:
Build the network from the pruned architecture and initialize it with the pruned model's parameters.
refine = 'pruned.pth.tar'  # path to the pruned checkpoint saved by --save above
if refine:
    checkpoint = torch.load(refine)
    print(checkpoint['cfg'])
    model = ASPNET(net_name=checkpoint['cfg'])  # build the pruned architecture
    model.cuda()
    model.load_state_dict(checkpoint['state_dict'])
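From here, fine-tuning is ordinary training on the pruned model. A minimal sketch, assuming a train_loader DataLoader and an SGD setup (the epoch count and learning rate are assumptions, not values from the original):

import torch.optim as optim
import torch.nn.functional as F

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()
for epoch in range(20):                # assumed fine-tuning budget
    for data, target in train_loader:  # assumed DataLoader over the training set
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()                # sparsity training is done; no updateBN here
        optimizer.step()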