赞
踩
Pytorch提供了两种方式进行保存模型。
import torch import torchvision from torch import nn from torch.nn import Sequential, Conv2d, MaxPool2d vgg16 = torchvision.models.vgg16(pretrained=False) # 保存方式1:模型结构+模型参数 torch.save(vgg16, "vgg16_method1.pth") # 保存模型结构及参数 # 保存方式2:模型参数,保存成字典的形式(官方推荐) torch.save(vgg16.state_dict(), "vgg16_method2.pth") # 陷阱1:方式1保存模型,陷阱在加载处 class Model(nn.Module): def __init__(self) -> None: super().__init__() # 初始化父类属性 self.model1 = Sequential( Conv2d(3, 32, 5, stride=1, padding=2), MaxPool2d(2), Conv2d(32, 32, 5, padding=2), MaxPool2d(2), Conv2d(32, 64, 5, padding=2), MaxPool2d(2), ) def forward(self, x): x = self.model1(x) return x model = Model() torch.save(model, "model_method.pth") # 保存模型结构及参数
Pytorch提供了两种方式进行读取模型。
注意:读取模型时,必须引入该模型结构的class定义,否则加载模型时报错缺少类定义。
import torch import torchvision.models from Model import Model # 引入模型类,防止加载自定义模型报错 # 方式1:加载模型 model1 = torch.load("vgg16_method1.pth") # 加载模型结构及参数 print("方式1:\n", model1) # 打印模型网络结构 # 方式2:加载模型 model_data = torch.load("vgg16_method2.pth") # 加载模型参数 print("方式2:\n", model_data) # 打印模型网络参数 vgg16 = torchvision.models.vgg16(pretrained=False) # vgg16网络模型 vgg16.load_state_dict(model_data) # 将模型参数加载到模型里 # 陷阱1:导入模型时报错缺少类定义(AttributeError) # 解决方法:在当前文件加载import该类 from Model import Model Model.py文件里定义了Model类 model = torch.load("model_method.pth") # 加载模型结构及参数 print("陷阱1:\n", model)
输出:
方式1: VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (18): ReLU(inplace=True) (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (25): ReLU(inplace=True) (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (27): ReLU(inplace=True) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (classifier): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace=True) (5): Dropout(p=0.5, inplace=False) (6): Linear(in_features=4096, out_features=1000, bias=True) ) ) 方式2: OrderedDict([('features.0.weight', tensor([[[[ 3.7737e-04, 4.1346e-02, 6.0702e-02], [ 7.0125e-02, 3.7126e-02, -7.6289e-02], [ 1.2145e-01, 4.2173e-02, -1.1606e-01]], [[-2.3715e-02, 1.9658e-02, -7.4128e-02], [-2.9713e-02, 3.6599e-03, 9.9301e-03], [-4.9300e-02, 5.1934e-02, 1.0522e-01]], [[ 1.4076e-02, 5.1264e-02, -5.4800e-02], [-3.5250e-02, 2.0560e-02, -2.7887e-03], [ 2.2512e-02, 5.9779e-02, 4.9314e-02]]], [[[-8.1202e-03, -4.0062e-02, -4.1275e-02], [ 1.3463e-02, -4.1142e-02, 1.1663e-01], [-1.6806e-02, 7.7193e-02, 5.9772e-02]], [[-3.7491e-03, 7.0595e-02, 3.9575e-02], [-1.7332e-01, 5.7054e-02, 1.2022e-01], [ 1.6720e-02, -1.2557e-02, 8.1462e-02]], [[ 2.0320e-02, -9.4389e-03, -2.6056e-02], [-9.8172e-03, 1.4638e-01, -2.9588e-04], [ 1.9194e-02, -5.7499e-02, 4.5579e-02]]], [[[ 8.1152e-02, -3.3212e-02, 4.4831e-02], [-2.5436e-02, -3.9699e-02, -4.9673e-02], [-2.0726e-02, 1.9308e-02, 1.5040e-02]], [[ 1.0469e-01, 2.3499e-02, 2.0060e-02], [-9.3836e-02, -3.8625e-02, -4.0413e-02], [ 7.2539e-02, 2.8679e-02, 3.7398e-02]], [[-1.9462e-03, -9.2730e-02, 2.1433e-03], [-1.2013e-01, 6.4750e-02, 8.3451e-02], [-8.4348e-02, 5.1198e-02, -1.5884e-01]]], ..., ..., ..., [-0.0068, 0.0025, 0.0026, ..., -0.0150, -0.0085, -0.0084], [ 0.0023, -0.0015, -0.0213, ..., 0.0131, -0.0111, -0.0071], [ 0.0091, -0.0014, -0.0073, ..., -0.0146, 0.0060, 0.0087]])), ('classifier.0.bias', tensor([0., 0., 0., ..., 0., 0., 0.])), ('classifier.3.weight', tensor([[-0.0036, 0.0033, 0.0061, ..., 0.0100, 0.0028, -0.0114], [-0.0017, -0.0052, 0.0002, ..., 0.0097, 0.0015, 0.0009], [ 0.0189, -0.0090, 0.0017, ..., -0.0046, 0.0094, -0.0055], ..., [-0.0081, -0.0144, 0.0065, ..., 0.0009, -0.0081, -0.0141], [ 0.0085, 0.0051, 0.0092, ..., 0.0080, -0.0117, 0.0045], [-0.0038, -0.0033, 0.0118, ..., -0.0112, -0.0121, -0.0056]])), ('classifier.3.bias', tensor([0., 0., 0., ..., 0., 0., 0.])), ('classifier.6.weight', tensor([[-0.0070, 0.0144, 0.0028, ..., 0.0072, 0.0221, 0.0056], [ 0.0203, -0.0066, 0.0003, ..., 0.0057, -0.0002, 0.0077], [-0.0004, 0.0128, 0.0234, ..., 0.0073, 0.0079, 0.0003], ..., [-0.0023, 0.0004, -0.0097, ..., 0.0037, -0.0093, 0.0014], [-0.0048, -0.0078, -0.0077, ..., 0.0131, -0.0044, 0.0071], [-0.0050, -0.0099, -0.0006, ..., -0.0062, -0.0243, -0.0062]])), ('classifier.6.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))]) 陷阱1: Model( (model1): Sequential( (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) )
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。