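- These are my personal notes on hyperparameter tuning, quantization, and pruning with PyTorch. My experience is limited, so mistakes and omissions are likely; corrections are welcome.
- For more content, click through to my personal homepage.
- Python is a cross-platform programming language. It is a high-level scripting language that combines interpreted, compiled, interactive, and object-oriented characteristics. It was originally designed for writing automation scripts (shell), and as new versions and language features arrived it has been used more and more for standalone, large-scale projects.
- PyTorch is an open-source machine learning library for Python, based on Torch and used for applications such as natural language processing. Facebook AI Research (FAIR) released it in January 2017. It is a Python-based scientific computing package that provides two high-level features: (1) tensor computation (like NumPy) with strong GPU acceleration; (2) deep neural networks built on an automatic differentiation system.
- Python 3.x (a high-level, object-oriented language)
- PyTorch (a third-party Python library)
- Hyperparameters: the parameters of a deep learning model that must be set by hand rather than learned, such as the learning rate lr and the batch size batch_size.
- In Python, the Ray Tune package can manage hyperparameter tuning.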
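Tuning amounts to drawing candidate values for such hyperparameters and comparing the models they produce. The drawing step can be sketched with the standard library alone. This is only an illustration, not how Ray Tune is implemented: sample_config is a hypothetical helper, and the ranges mirror the Ray Tune search space defined later in these notes (random.randrange(2, 9) excludes the upper bound, just like np.random.randint(2, 9)).

```python
import math
import random


def sample_config(rng):
    """Draw one hyperparameter configuration (hypothetical helper)."""
    return {
        # Exponents 2..8, so layer widths are powers of two from 4 to 256.
        "nodes_1": 2 ** rng.randrange(2, 9),
        "nodes_2": 2 ** rng.randrange(2, 9),
        # Log-uniform: uniform in log space, so every decade is equally likely.
        "lr": math.exp(rng.uniform(math.log(1e-4), math.log(1e-1))),
        "batch_size": rng.choice([2, 4, 8, 16]),
    }


rng = random.Random(0)  # fixed seed so the draws are repeatable
samples = [sample_config(rng) for _ in range(5)]
for s in samples:
    print(s)
```

Sampling the learning rate in log space is the usual choice because useful values span several orders of magnitude.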
pip install tensorboardX
pip install ray
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, nodes_1=120, nodes_2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, nodes_1)  # configurable number of nodes in fc1
        self.fc2 = nn.Linear(nodes_1, nodes_2)  # configurable number of nodes in fc2
        self.fc3 = nn.Linear(nodes_2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
from ray import tune
import numpy as np

config = {
    # tune.sample_from() with a lambda defines a custom search space:
    # here, powers of two from 2**2 = 4 to 2**8 = 256
    "nodes_1": tune.sample_from(
        lambda _: 2 ** np.random.randint(2, 9)),
    "nodes_2": tune.sample_from(
        lambda _: 2 ** np.random.randint(2, 9)),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 8, 16])
}
import torch
import torchvision
from torchvision import transforms


def load_data(data_dir="./data"):
    train_transforms = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.4914, 0.4822, 0.4465),
            (0.2023, 0.1994, 0.2010))])
    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            (0.4914, 0.4822, 0.4465),
            (0.2023, 0.1994, 0.2010))])
    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=train_transforms)
    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False,
        download=True, transform=test_transforms)
    return trainset, testset
from torch import optim
from torch import nn
from torch.utils.data import random_split


def train_model(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net(config['nodes_1'], config['nodes_2']).to(device=device)  # configurable layer sizes
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=config['lr'], momentum=0.9)  # configurable learning rate

    trainset, testset = load_data()
    test_abs = int(len(trainset) * 0.8)  # 80% for training, 20% for validation
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs])
    trainloader = torch.utils.data.DataLoader(
        train_subset, batch_size=int(config["batch_size"]), shuffle=True)  # configurable batch size
    valloader = torch.utils.data.DataLoader(
        val_subset, batch_size=int(config["batch_size"]), shuffle=True)  # configurable batch size

    for epoch in range(10):
        train_loss = 0.0
        for data in trainloader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        val_loss = 0.0
        total = 0
        correct = 0
        for data in valloader:
            with torch.no_grad():  # disable gradient tracking during validation
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()

        print(f'epoch: {epoch} ',
              f'train_loss: ', f'{train_loss/len(trainloader)}',
              f'val_loss: ', f'{val_loss/len(valloader)}',
              f'val_acc: {correct/total}')
        # Report the validation metrics to Ray Tune once per epoch
        tune.report(loss=(val_loss / len(valloader)),
                    accuracy=correct / total)
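Before running Ray Tune, we need a scheduler and a reporter. The scheduler manages the trials and decides which hyperparameter configurations keep running; the reporter displays the results.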
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

# Scheduler: here the asynchronous successive halving algorithm (ASHA)
scheduler = ASHAScheduler(
    metric="loss",  # metric to optimize
    mode="min",  # minimize the loss
    max_t=1,  # maximum number of training iterations (epochs) per trial
    grace_period=1,
    reduction_factor=2)

# Reporter: a CLI reporter that prints the loss, accuracy, training iteration,
# and the hyperparameters chosen for each run
reporter = CLIReporter(
    metric_columns=["loss", "accuracy", "training_iteration"])
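The idea behind successive halving can be sketched in a few lines of plain Python. This is a simplified, synchronous version under stated assumptions, not Ray's ASHAScheduler: successive_halving and toy_loss are hypothetical names, and toy_loss merely stands in for the validation loss after a given number of epochs. ASHA applies the same promote-the-best rule asynchronously, without waiting for every trial in a round to finish.

```python
import math


def successive_halving(configs, evaluate, min_budget=1, max_budget=8,
                       reduction_factor=2):
    """Score every survivor, keep the best 1/reduction_factor of them,
    multiply the budget by reduction_factor, and repeat."""
    survivors = list(configs)
    budget = min_budget
    while True:
        ranked = sorted(survivors, key=lambda cfg: evaluate(cfg, budget))
        if budget >= max_budget or len(ranked) == 1:
            return ranked[0]
        survivors = ranked[:max(1, len(ranked) // reduction_factor)]
        budget *= reduction_factor


def toy_loss(cfg, budget):
    """Hypothetical stand-in for the validation loss after `budget` epochs."""
    return (math.log10(cfg["lr"]) + 2.5) ** 2 + 1.0 / budget


candidates = [{"lr": lr} for lr in (1e-4, 1e-3, 3e-3, 1e-2, 1e-1)]
best = successive_halving(candidates, toy_loss)
print(best)
```

With max_t=1 and grace_period=1, as configured above, each trial appears to be stopped after its first reported iteration, so no halving rounds actually take place; raising max_t would give the scheduler room to stop weak trials early.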
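Run Ray Tune with the run() method:
from functools import partial

result = tune.run(
    partial(train_model),
    # resources_per_trial={"cpu": 2, "gpu": 1},  # resources allocated to each trial
    resources_per_trial={"cpu": 1, "gpu": 2},
    config=config,
    num_samples=10,  # number of hyperparameter configurations (trials) to sample
    scheduler=scheduler,
    progress_reporter=reporter)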
2023-02-02 02:05:24,885 INFO worker.py:1538 -- Started a local Ray instance.
== Status ==
Current time: 2023-02-02 02:05:27 (running for 00:00:00.63)
Memory usage on this node: 1.7/15.6 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 1.000: None
Resources requested: 1.0/2 CPUs, 2.0/2 GPUs, 0.0/7.15 GiB heap, 0.0/3.58 GiB objects (0.0/1.0 accelerator_type:T4)
Result logdir: /root/ray_results/train_model_2023-02-02_02-05-26
Number of trials: 10/10 (9 PENDING, 1 RUNNING)
2023-02-02 02:16:05,160 INFO tune.py:763 -- Total run time: 638.90 seconds (638.46 seconds for the tuning loop).
(func pid=1840) epoch: 0 train_loss: 1.9899510860443115 val_loss: 1.7621070608139038 val_acc: 0.3433
== Status ==
Current time: 2023-02-02 02:16:05 (running for 00:10:38.48)
Memory usage on this node: 4.4/15.6 GiB
Using AsyncHyperBand: num_stopped=10
Bracket: Iter 1.000: -1.812330790552497
Resources requested: 0/2 CPUs, 0/2 GPUs, 0.0/7.15 GiB heap, 0.0/3.58 GiB objects (0.0/1.0 accelerator_type:T4)
Result logdir: /root/ray_results/train_model_2023-02-02_02-05-26
Number of trials: 10/10 (10 TERMINATED)
+-------------------------+------------+-----+------------+-------------+---------+---------+---------+----------+--------------------+
| Trial name              | status     | loc | batch_size |          lr | nodes_1 | nodes_2 |    loss | accuracy | training_iteration |
|-------------------------+------------+-----+------------+-------------+---------+---------+---------+----------+--------------------|
| train_model_0f495_00000 | TERMINATED |     |          8 |  0.00208745 |      64 |     256 | 1.60968 |   0.4072 |                  1 |
| train_model_0f495_00001 | TERMINATED |     |          8 |    0.011537 |      64 |      32 | 2.11739 |    0.182 |                  1 |
| train_model_0f495_00002 | TERMINATED |     |          2 | 0.000202415 |      64 |     128 |  1.6399 |   0.3874 |                  1 |
| train_model_0f495_00003 | TERMINATED |     |          8 | 0.000397489 |     128 |       8 | 1.84503 |   0.2984 |                  1 |
| train_model_0f495_00004 | TERMINATED |     |         16 | 0.000670083 |      16 |     128 | 1.83333 |   0.3142 |                  1 |
| train_model_0f495_00005 | TERMINATED |     |         16 |  0.00385978 |       8 |      32 | 1.69983 |   0.3735 |                  1 |
| train_model_0f495_00006 | TERMINATED |     |         16 |   0.0461144 |      64 |      64 | 2.31054 |   0.1009 |                  1 |
| train_model_0f495_00007 | TERMINATED |     |          2 |   0.0169714 |       4 |     256 | 2.31998 |    0.099 |                  1 |
| train_model_0f495_00008 | TERMINATED |     |          2 |  0.00162063 |      64 |       8 | 1.79133 |   0.3197 |                  1 |
| train_model_0f495_00009 | TERMINATED |     |         16 |  0.00110084 |      64 |     128 | 1.76211 |   0.3433 |                  1 |
+-------------------------+------------+-----+------------+-------------+---------+---------+---------+----------+--------------------+
best_trial = result.get_best_trial(
    "loss", "min", "last")
print("Best trial config: {}".format(
    best_trial.config))
print("Best trial final validation loss:",
      best_trial.last_result["loss"])
print("Best trial final validation accuracy:",
      best_trial.last_result["accuracy"])
Best trial config: {'nodes_1': 64, 'nodes_2': 256, 'lr': 0.0020874538538687972, 'batch_size': 8}
Best trial final validation loss: 1.6096843579292297
Best trial final validation accuracy: 0.4072
- Quantization means computing with, and accessing memory in, lower-precision data types.
import torch
from torch import nn
import torch.nn.functional as F


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = LeNet5()
for n, p in model.named_parameters():
    print(n, ": ", p.dtype)
conv1.weight : torch.float32
conv1.bias : torch.float32
conv2.weight : torch.float32
conv2.bias : torch.float32
fc1.weight : torch.float32
fc1.bias : torch.float32
fc2.weight : torch.float32
fc2.bias : torch.float32
fc3.weight : torch.float32
fc3.bias : torch.float32
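- The quickest form of quantization is to halve the precision of every computation.
model = model.half()  # convert all parameters to half precision (float16)
for n, p in model.named_parameters():
    print(n, ": ", p.dtype)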
conv1.weight : torch.float16
conv1.bias : torch.float16
conv2.weight : torch.float16
conv2.bias : torch.float16
fc1.weight : torch.float16
fc1.bias : torch.float16
fc2.weight : torch.float16
fc2.bias : torch.float16
fc3.weight : torch.float16
fc3.bias : torch.float16
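The effect on memory can be checked directly. The sketch below uses a single layer with the same shape as fc1 and counts the bytes held by its parameters before and after half(); param_bytes is a hypothetical helper introduced only for this illustration.

```python
import torch
from torch import nn


def param_bytes(module):
    """Total number of bytes occupied by a module's parameters."""
    return sum(p.nelement() * p.element_size() for p in module.parameters())


layer = nn.Linear(16 * 5 * 5, 120)  # same shape as fc1 in LeNet5
bytes_fp32 = param_bytes(layer)
bytes_fp16 = param_bytes(layer.half())  # half() converts the module in place
print(bytes_fp32, bytes_fp16)
```

The layer has 400 * 120 weights plus 120 biases, so the count drops from 4 bytes to 2 bytes per parameter.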
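- In practice, we rarely quantize every computation in the same way. Float16 may also not be enough, and quantizing to an even lower precision can be necessary.
- PyTorch provides three further quantization modes: dynamic quantization, post-training static quantization, and quantization-aware training (QAT).
- Dynamic quantization is the simplest of the three. Weights are converted to int8 ahead of time, and activations are converted to int8 dynamically, just before each computation. The computation itself uses int8 values, but activations are read from and written to memory in floating point.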
import torch.quantization
quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
for n, p in quantized_model.named_parameters():
    print(n, ": ", p.dtype)
conv1.weight : torch.float16
conv1.bias : torch.float16
conv2.weight : torch.float16
conv2.bias : torch.float16
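Only the convolution parameters appear in the output above, presumably because the Linear layers were replaced by dynamically quantized modules whose packed int8 weights are not registered as ordinary parameters. What "converting to int8" means numerically can be sketched in plain Python as an affine mapping with a scale and a zero point. The helper names below are hypothetical, and this is an illustration of the arithmetic rather than PyTorch's internal code.

```python
def choose_qparams(x_min, x_max, q_min=-128, q_max=127):
    """Pick scale and zero_point so that [x_min, x_max] maps onto [q_min, q_max]."""
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # the range must contain 0
    scale = (x_max - x_min) / (q_max - q_min) or 1.0
    zero_point = round(q_min - x_min / scale)
    return scale, max(q_min, min(q_max, zero_point))


def quantize(x, scale, zero_point, q_min=-128, q_max=127):
    """Map a float to an int8 code, clamped to the representable range."""
    return max(q_min, min(q_max, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    """Map an int8 code back to an approximate float."""
    return (q - zero_point) * scale


values = [-1.0, -0.3, 0.0, 0.7, 3.0]
scale, zero_point = choose_qparams(min(values), max(values))
codes = [quantize(v, scale, zero_point) for v in values]
restored = [dequantize(q, scale, zero_point) for q in codes]
print(scale, zero_point, codes, restored)
```

Each restored value differs from the original by at most half a quantization step (scale / 2), which is the precision given up in exchange for storing one byte per value.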
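- Post-training static quantization can reduce latency further. It observes the distributions of the different activations while representative data is fed through the trained model (calibration) and uses them to decide how those activations should be quantized at inference time. This allows quantized values to be passed between operations without converting back and forth between floats and ints in memory.
- Note: quantization depends on the backend that runs the quantized model. At the time of writing, for CPU inference, quantized operators are supported only on x86 (fbgemm) and ARM (qnnpack). Quantization-aware training (QAT), described below, runs in full floating point during training and therefore works on both GPU and CPU.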
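static_quant_model = LeNet5()
static_quant_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(static_quant_model, inplace=True)  # insert observers
# No calibration data is passed through the model here, so the observers never run
# and the converted layers keep the default scale=1.0 and zero_point=0.
torch.quantization.convert(static_quant_model, inplace=True)  # swap in quantized modules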
(conv1): QuantizedConv2d(3, 6, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0)
(conv2): QuantizedConv2d(6, 16, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0)
(fc1): QuantizedLinear(in_features=400, out_features=120, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
(fc2): QuantizedLinear(in_features=120, out_features=84, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
(fc3): QuantizedLinear(in_features=84, out_features=10, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
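The scale=1.0 and zero_point=0 values in the output above suggest that the observers never saw any data. A fuller post-training workflow usually wraps the model input and output in QuantStub and DeQuantStub and runs some calibration batches between prepare and convert. The sketch below makes several assumptions: TinyNet is a hypothetical one-layer model, random tensors stand in for real calibration data, and the backend is picked from whatever engines the local PyTorch build reports.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical one-layer model with explicit quantize/dequantize points."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float to int8 at the input
        self.fc = nn.Linear(8, 4)
        self.dequant = torch.quantization.DeQuantStub()  # int8 to float at the output

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))


torch.manual_seed(0)
engines = torch.backends.quantized.supported_engines
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
torch.backends.quantized.engine = backend

model = TinyNet().eval()
model.qconfig = torch.quantization.get_default_qconfig(backend)
prepared = torch.quantization.prepare(model)  # returns a copy with observers attached

with torch.no_grad():
    for _ in range(8):
        prepared(torch.randn(16, 8))  # calibration: observers record activation ranges

quantized = torch.quantization.convert(prepared)  # replace fc with a quantized Linear
out = quantized(torch.randn(2, 8))
print(quantized)
print(out.dtype, out.shape)
```

Because of the stubs, callers still pass in and receive ordinary float tensors; only the computation in between runs on quantized values.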
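- Quantization-aware training (QAT) usually gives the best accuracy. All weights and activations are "fake quantized" during the forward and backward passes of training: float values are rounded to mimic the corresponding int8 values, but the computation is still carried out in floating point. The weight updates are thereby made "aware" that the model will be quantized after training.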
qat_model = LeNet5()
qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(qat_model, inplace=True)
torch.quantization.convert(qat_model, inplace=True)
/opt/conda/lib/python3.7/site-packages/torch/ao/quantization/utils.py:211: UserWarning: must run observer before calling calculate_qparams. Returning default values.
"must run observer before calling calculate_qparams. " +
(conv1): QuantizedConv2d(3, 6, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0)
(conv2): QuantizedConv2d(6, 16, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0)
(fc1): QuantizedLinear(in_features=400, out_features=120, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
(fc2): QuantizedLinear(in_features=120, out_features=84, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
(fc3): QuantizedLinear(in_features=84, out_features=10, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
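Fake quantization itself can be seen in isolation with torch.fake_quantize_per_tensor_affine, which rounds a float tensor onto the int8 grid and immediately maps it back to float. The scale, zero point, and input values below are arbitrary choices for illustration, not values taken from the model above.

```python
import torch

x = torch.tensor([0.013, 0.5, 1.27, 10.0])
scale, zero_point = 0.01, 0  # assumed quantization parameters

# Arguments: input, scale, zero_point, quant_min, quant_max
fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, -128, 127)
print(fq.dtype, fq)
```

The result is still float32, but 0.013 has been snapped to the nearest grid point and 10.0 has been clamped to the largest representable value, 127 * scale. During QAT the network trains against exactly this kind of rounding and clamping error.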
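- Note: PyTorch's quantization support was still under active development and in beta at the time of writing.
- Pruning is a technique that reduces the number of model parameters with minimal impact on performance. It allows models to be deployed with less memory, smaller processors, and fewer hardware resources.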
import torch
from torch import nn
import torch.nn.functional as F


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
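- LeNet5 has five submodules: conv1, conv2, fc1, fc2, and fc3. The model parameters consist of weights and biases, which can be inspected with the named_parameters() method.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet5().to(device)
print(list(model.conv1.named_parameters()))  # weights and biases of conv1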
[('weight', Parameter containing: tensor([[[[ 0.0807, -0.0330, -0.0133, 0.0424, 0.0620], [ 0.0338, 0.1058, -0.1049, -0.0152, -0.0697], [-0.0215, -0.1002, 0.0803, 0.0423, -0.0491], [ 0.0769, 0.0831, -0.0053, 0.0519, 0.0787], [ 0.0449, 0.0963, 0.1036, -0.0119, 0.0780]], [[ 0.0206, 0.0409, -0.0407, -0.0231, -0.0977], [-0.0069, 0.0188, -0.0466, -0.0172, 0.0372], [-0.0804, 0.0902, 0.1082, -0.0192, 0.0477], [ 0.0057, 0.0447, -0.0272, -0.1057, 0.1135], [ 0.1046, 0.0197, -0.0288, -0.0803, 0.0797]], [[ 0.0961, -0.0309, -0.0433, 0.0510, -0.0408], [ 0.0218, -0.0093, 0.0297, -0.0055, 0.0561], [ 0.0161, -0.0166, 0.0739, -0.0938, 0.0317], [-0.0573, 0.0727, -0.0758, -0.0565, 0.0878], [-0.0913, -0.0770, -0.0225, 0.0828, 0.1036]]],
[[[-0.0178, 0.1112, 0.0027, 0.0701, -0.0215], [ 0.0193, -0.1126, -0.0067, -0.0459, -0.0953], [-0.0825, -0.0526, 0.0168, 0.0145, -0.0125], [-0.0877, 0.0207, 0.0051, -0.0489, 0.0720], [ 0.0074, 0.0232, -0.0267, -0.0912, -0.0016]], [[ 0.1091, 0.0140, 0.0271, -0.0390, -0.0958], [ 0.0068, 0.0734, -0.0895, 0.0667, -0.0704], [ 0.0640, 0.0240, -0.0811, -0.1071, -0.0046], [-0.0286, -0.0557, 0.0219, -0.0797, 0.0399], [ 0.0951, -0.0194, 0.0160, -0.1102, 0.0037]], [[ 0.0625, 0.0565, 0.1011, -0.0599, -0.0048], [-0.0233, -0.0210, 0.0191, 0.0663, -0.0904], [ 0.1000, -0.0677, -0.0137, 0.0629, 0.1139], [-0.0315, 0.0504, -0.1096, -0.0365, -0.0279], [-0.0512, 0.0821, -0.0359, 0.0349, -0.0828]]],
[[[-0.0085, 0.0708, 0.0927, -0.0134, 0.1040], [ 0.1011, 0.0380, -0.0932, 0.0248, -0.0573], [ 0.0597, 0.0865, -0.0899, 0.0878, 0.1042], [-0.0423, 0.0050, -0.0296, -0.0998, 0.0412], [ 0.0276, 0.0230, 0.0052, 0.0527, 0.0328]], [[-0.0116, 0.0606, 0.0782, 0.1016, -0.0558], [-0.0879, -0.0913, 0.0039, -0.0486, 0.0302], [-0.1125, 0.0397, 0.1011, 0.1051, -0.0013], [ 0.0604, 0.0398, -0.0025, -0.0450, 0.0254], [-0.0317, -0.0395, 0.0556, -0.0077, -0.0087]], [[-0.0811, 0.1145, -0.0649, -0.0265, 0.1032], [ 0.0794, -0.0024, -0.0237, 0.0598, -0.0944], [ 0.1095, -0.0970, -0.0178, -0.0926, 0.0684], [ 0.0907, 0.0652, -0.0588, 0.0637, 0.0302], [ 0.1132, -0.0547, 0.0659, 0.0479, 0.1095]]],
[[[ 0.0822, -0.0710, 0.0067, 0.0500, 0.0274], [-0.0423, -0.0655, 0.0858, 0.0685, 0.1024], [-0.0693, -0.0567, 0.0308, 0.0589, 0.0455], [ 0.0904, -0.0133, -0.0870, -0.0671, 0.1025], [-0.0686, -0.0085, 0.0624, 0.1017, -0.0239]], [[ 0.0907, -0.0579, 0.0706, 0.0307, -0.1153], [-0.0122, -0.0377, -0.0445, -0.0538, 0.0338], [-0.0725, -0.1115, 0.0604, -0.0136, 0.0975], [ 0.0648, 0.0492, -0.0770, 0.0845, 0.0173], [-0.0533, 0.0212, 0.0801, -0.1113, 0.0864]], [[-0.0126, -0.0099, 0.0226, -0.1111, 0.0698], [ 0.0987, -0.0507, 0.0460, 0.0509, -0.1049], [ 0.0899, 0.0256, -0.0954, -0.0310, 0.1025], [-0.0658, -0.0842, -0.0705, -0.0690, -0.0596], [ 0.0873, 0.0355, -0.0280, 0.0308, 0.0801]]],
[[[ 0.0558, 0.0660, -0.0859, 0.0719, -0.0570], [-0.0832, 0.1147, 0.0418, -0.0291, -0.0384], [-0.1143, 0.0522, 0.0428, 0.0614, -0.0119], [ 0.0641, 0.0930, 0.0407, -0.0353, -0.0657], [ 0.0042, 0.0132, -0.0557, -0.0803, -0.0464]], [[-0.0611, -0.0598, -0.0383, 0.0453, 0.0462], [ 0.1045, -0.0514, -0.0189, -0.0014, -0.0054], [-0.0372, 0.0966, 0.0741, 0.0870, 0.1023], [-0.0117, -0.0157, -0.1145, 0.0599, -0.0392], [-0.0648, 0.0903, 0.0471, -0.0930, -0.1113]], [[-0.0528, 0.0461, -0.0693, -0.0424, 0.0825], [-0.0244, -0.0363, 0.0469, 0.0252, -0.0127], [ 0.0590, 0.0485, 0.0280, -0.0457, 0.0224], [-0.0290, -0.0319, 0.0266, -0.1103, 0.0002], [-0.1103, -0.0315, 0.0587, -0.0035, -0.0100]]],
[[[ 0.0358, -0.0845, -0.1016, 0.1149, 0.0869], [ 0.0829, -0.0099, 0.0339, -0.1071, 0.0679], [ 0.0901, -0.0212, 0.0468, 0.0042, -0.0929], [-0.0648, -0.0580, -0.0112, -0.0113, -0.0682], [ 0.0406, -0.0807, 0.0634, 0.0170, -0.1031]], [[-0.0955, -0.0185, -0.0148, 0.0005, -0.0372], [-0.0207, 0.1041, -0.0922, 0.0103, 0.0424], [-0.0581, 0.1128, 0.0292, 0.0042, -0.0814], [ 0.0882, -0.0714, -0.0918, -0.1019, -0.0829], [ 0.0179, 0.0246, -0.0940, 0.0159, 0.0944]], [[ 0.0258, 0.0743, 0.0390, -0.1051, -0.0090], [-0.0187, -0.0850, -0.0034, -0.0107, -0.0168], [-0.0350, 0.0346, 0.0705, -0.0884, 0.0876], [-0.0850, -0.0734, -0.1152, 0.0609, -0.1100], [ 0.0363, -0.0489, -0.0183, -0.0161, 0.0226]]]], requires_grad=True)), ('bias', Parameter containing: tensor([ 0.0973, 0.0784, 0.0344, -0.0536, 0.0964, -0.0110], requires_grad=True))]
- Local pruning means pruning only one specified part of the model.
import torch.nn.utils.prune as prune
Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
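A minimal sketch of local pruning on a single layer, assuming random unstructured pruning with an amount of 25% (both are illustrative choices, not values taken from these notes). It also shows how PyTorch reparametrizes the layer: the original tensor moves to weight_orig, a binary weight_mask buffer is added, and weight becomes their product.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 6, 5)  # same shape as conv1 in LeNet5

# Randomly zero 25% of the weights of this one layer (assumed amount).
prune.random_unstructured(conv, name="weight", amount=0.25)

param_names = [n for n, _ in conv.named_parameters()]
buffer_names = [n for n, _ in conv.named_buffers()]
sparsity = float((conv.weight == 0).sum()) / conv.weight.nelement()
print(param_names, buffer_names, round(sparsity, 3))
```

The bias is untouched because only the weight tensor was named in the call.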
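model = LeNet5().to(device)
for name, module in model.named_modules():
    # The pruning calls are cut off in these notes; random_unstructured
    # on 'weight' is an assumed choice, only the amounts are original.
    if isinstance(module, torch.nn.Conv2d):
        prune.random_unstructured(module, name='weight',
                                  amount=0.3)  # prune 30% of every Conv2d layer
    elif isinstance(module, torch.nn.Linear):
        prune.random_unstructured(module, name='weight',
                                  amount=0.5)  # prune 50% of every Linear layer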
- Global pruning prunes the model as a whole. For example, pruning 25% of all the parameters in this model looks like the following.
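model = LeNet5().to(device)
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)
# Prune 25% of these weights across the whole model.
# The original call is cut off in these notes; L1Unstructured is an assumed method.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.25)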
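Global pruning ranks the weights of all listed tensors together, so individual layers can end up more or less sparse than the overall target. The sketch below checks this on a small hypothetical two-layer network, assuming L1 (smallest magnitude) pruning as the method.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 10))
targets = [(net[0], "weight"), (net[2], "weight")]

# Remove the 25% of weights with the smallest magnitude across both layers.
prune.global_unstructured(
    targets, pruning_method=prune.L1Unstructured, amount=0.25)

zeros = sum(int((m.weight == 0).sum()) for m, _ in targets)
total = sum(m.weight.nelement() for m, _ in targets)
global_sparsity = zeros / total
per_layer = [float((m.weight == 0).sum()) / m.weight.nelement()
             for m, _ in targets]
print(global_sparsity, per_layer)
```

The overall sparsity matches the requested 25%, while the per-layer figures depend on how the weight magnitudes happen to be distributed.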
class MyPruningMethod(prune.BasePruningMethod):
    PRUNING_TYPE = 'unstructured'  # pruning type

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0  # zero out every other element
        return mask


def my_unstructured(module, name):
    MyPruningMethod.apply(module, name)
    return module


model = LeNet5().to(device)
my_unstructured(model.fc1, name='bias')
Linear(in_features=400, out_features=120, bias=True)
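To see what the custom method does, the mask can be inspected after applying it, and prune.remove can then make the pruning permanent. The sketch below repeats the same every-other-element idea under a hypothetical class name so that it runs on its own.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


class EveryOtherPruning(prune.BasePruningMethod):
    """Same idea as MyPruningMethod above: mask out every other entry."""
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask


fc = nn.Linear(400, 120)
EveryOtherPruning.apply(fc, "bias")
mask = fc.bias_mask.clone()  # 0 at even positions, 1 at odd positions

# Bake the mask into fc.bias and drop bias_orig and bias_mask.
prune.remove(fc, "bias")
still_pruned = prune.is_pruned(fc)
print(mask[:6], fc.bias[:6], still_pruned)
```

After prune.remove the layer is an ordinary Linear again, with the zeroed entries stored directly in its bias parameter.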