当前位置:   article > 正文

7、PyTorch 参数初始化及模型保存、加载、微调及推理_pytorch 初始化

pytorch 初始化


一、PyTorch 参数初始化

参数初始化原理可参考此 blog:深度学习中的参数初始化

1.1、使用 torch.nn.init 进行初始化

# 针对不同类型的激活函数计算增益(方差变化尺度)
torch.nn.init.calculate_gain(nonlinearity, param=None)  # nonlinearity:激活函数名称; param: 激活函数参数
- Linear、Identity、Conv{1,2,3}D、Sigmoid:1
- Tanh:5/3
- ReLU:sqrt(2)
- Leaky Relu:sqrt(2 / (1 + negative_slope**2))
>>> gain = nn.init.calculate_gain('leaky_relu', 0.2)


# Xavier 初始化(均匀分布和正态分布)
torch.nn.init.xavier_uniform_(tensor, gain=1.0)  # U(−a,a),其中 a=gain×sqrt(6/(fan_in+fan_out))
torch.nn.init.xavier_normal_(tensor, gain=1.0)  # N(0,std**2),其中 std=gain×sqrt(2/(fan_in+fan_out))
>>> w = torch.empty(3, 5)
>>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
>>> nn.init.xavier_normal_(w)

# MSRA 初始化(均匀分布和正态分布)
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')  # U(−b,b),其中 b=gain×sqrt(3/fan_mode)
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')  # N(0,std**2),其中 std=gain×sqrt(1/fan_mode)
- a:the negative slope of the rectifier used after this layer (only with 'leaky_relu') (used)
- mode:
	- 'fan_in' preserves the magnitude of the variance of the weights in the forward pass 
	- 'fan_out' preserves the magnitudes in the backwards pass
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
>>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

# 随机初始化(均匀分布和正态分布)
torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
>>> w = torch.empty(3, 5)
>>> nn.init.uniform_(w)
>>> nn.init.normal_(w)


# 常量初始化
torch.nn.init.constant_(tensor, val)
>>> w = torch.empty(3, 5)
>>> nn.init.constant_(w, 0.3)


# 根据不同类型的层,设定不同的权值初始化方法 1
def initialize_weights(self):
	for m in self.modules():
		if isinstance(m, nn.Conv2d):
			torch.nn.init.xavier_normal_(m.weight.data)  
			if m.bias is not None:
				m.bias.data.zero_()
		elif isinstance(m, nn.BatchNorm2d):
			m.weight.data.fill_(1)
			m.bias.data.zero_()
		elif isinstance(m, nn.Linear):
			torch.nn.init.normal_(m.weight.data, 0, 0.01)
			# m.weight.data.normal_(0, 0.01)
			m.bias.data.zero_()

		# 按需自定义初始化方法 #
		if isinstance(m, nn.Conv2d):
			n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
			m.weight.data.normal_(0, math.sqrt(2. / n))  # 对 *.weight.data 赋值进行初始化

# 根据不同类型的层,设定不同的权值初始化方法 2
for name, params in net.named_parameters():
    if name.find('linear') != -1:  # find: 在字符串 name 中查找另一个字符串,找到了返回第一次出现的位置, 没找到返回 -1
        params[0]  # weights
        params[1]  # bias
    elif name.find('conv') != -1:
        pass
    elif name.find('norm') != -1:
        pass
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70

1.2、使用预训练模型进行初始化

  • 保存模型参数: 假设创建了一个 net = Net(),并且经过训练,可通过 torch.save(net.state_dict(), 'net_params.pth') 保存,也可从网上下载一些基础模型
  • 加载预训练模型参数: pretrained_dict = torch.load('net_params.pth')
  • 使用预训练模型对网络进行初始化:
net = Net() # 创建 net,最后的分类层名字改一下用于不同类别的分类任务
net_state_dict = net.state_dict() # 获取已创建 net 的 state_dict
pretrained_dict = torch.load(model_path)

# 1、将 pretrained_dict 里不属于 net_state_dict 的键剔除掉
modified_dict =  {k: v for k, v in pretrained_dict.items() if k in net_state_dict}  

# 2、用预训练模型的参数字典对新模型的参数字典 net_state_dict 进行更新
# 新网络,预训练模型中存在的 key 使用预训练模型的权重,不存在的 key 使用默认初始化的权重
net_state_dict.update(modified_dict)

# 3、将更新了参数的字典放回到网络中
net.load_state_dict(net_state_dict)


--------------------
# 简化版本:将 strict 置为 False,只加载预训练模型中有的 key 中的 value,其余部分使用默认初始化的权重
net = Net() # 创建 net,最后的分类层名字改一下用于不同类别的分类任务
pretrained_dict = torch.load(model_path)
net.load_state_dict(pretrained_dict, strict=False)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

二、PyTorch 模型保存和加载

2.1、模型保存与加载

  • 模型保存与加载主要通过 torch.savetorch.load 来完成,其中保存的内容只要是对象都可以;
  • 模型保存操作称为序列化;模型加载操作称为反序列化

在这里插入图片描述

  • 保存成 pth 文件后,可以使用 netron 查看序列化后的模型(下面只包含权重部分):包括名称、形状、值等在这里插入图片描述
  • 转换成 onnx 后,可以通过 netron 查看包含网络结构和权重的 onnx 文件(netron 中打开 Show Names 显示输入输出名称)在这里插入图片描述
# 模型保存函数
torch.save(obj: object, f: FILE_LIKE, 
			pickle_module: Any = pickle, pickle_protocol: int = DEFAULT_PROTOCOL, 
			_use_new_zipfile_serialization: bool = True) -> None:
- obj:对象
- f:输出路径(包含模型名)
- note:The 1.6 release of PyTorch switched `torch.save` to use a new
        zipfile-based file format. `torch.load` still retains the ability to
        load files in the old format. If for any reason you want `torch.save`
        to use the old format, pass the kwarg `_use_new_zipfile_serialization=False`.

# 模型加载函数
torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, **pickle_load_args)
- f:模型路径(包含模型名)
- map_location(a function, class:torch.device):指定模型加载到哪里,cpu or gpu(cuda:0)############### 第一种方法:只保存和加载模型参数(推荐,板端部署基本没啥问题) ###############
# 直接利用 torchvision.models 中自带的预训练模型,只需要在使用时赋予 pretrained 参数为 True 即可
the_model = torchvision.models.resnet18(pretrained=True)  
model_path = os.path.join(BASEDIR, "resnet18-5c106cde.pth")

# 保存模型参数:通过调用 pickle 序列化方法实现的
torch.save(the_model.state_dict(), model_path)

# 加载模型参数(先将权重从硬盘加载到内存;然后将内存中加载好的权重赋值给当前网络)
# 在没有 GPU 的设备上需要加上 map_location 参数在 CPU 上执行(也可以指定为 gpu)
# map_location 指定了参数存放位置,可选 cpu 或 gpu(cuda:0)
# 返回一个有序字典,可以用 keys() 和 values() 查看相应的键和值,用 items() 查看键值对
state_dict_load = torch.load(model_path, map_location="cpu") 
the_model.load_state_dict(state_dict_load)
# self.load_state_dict(pretrained_dict, strict=False)
# 可以将 strict 置为 False,只加载预训练模型中有的 key 中的 value,其余部分使用默认初始化的权重;这样下面就可以不用写了


############### 第二种方法:保存和加载整个模型(不推荐,板端部署可能有问题) ###############
# 由于特定的序列化的数据与其特定的类别(class)相绑定,并且在序列化的时候使用了固定的目录结构,
# 所以在很多情况下,如在其他的一些项目中使用,或者代码进行了较大的重构的时候,很容易出现问题

# 保存模型参数:通过调用 pickle 序列化方法实现的
torch.save(the_model, model_path)

# 加载模型参数:在没有 GPU 的设备上需要加上 map_location 参数在 CPU 上执行 
the_model = torch.load(model_path, map_location="cpu")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44

在这里插入图片描述

  • 低版本上导入高版本模型Pytorch 1.6 之后的版本默认使用 zip 格式来保存权重,导致这些权重文件无法直接被 1.5 及以下的 pytorch 加载,解决方案要么升级 pytorch,要么使用高版本的环境把权重文件改成非 zip 格式
import torch

# load 模型参数
state_dict = torch.load("xxx.pth")
# 以 unzip 方式 save 模型参数
torch.save(state_dict, "xxx_unzip.pth", _use_new_zipfile_serialization=False)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

2.2、多 GPU 训练后模型的加载

from collections import OrderedDict

# load 模型参数
path_state_dict = "./model_in_multi_gpu.pkl"
state_dict_load = torch.load(path_state_dict, map_location="cpu") # 在没有 GPU 的设备上需要加上 map_location 参数在 CPU 上执行 
print("state_dict_load:\n{}".format(state_dict_load))

# 更改模型的 key 
new_state_dict = OrderedDict()
for k, v in state_dict_load.items():
    namekey = k[7:] if k.startswith('module.') else k  # remove module.
    new_state_dict[namekey] = v
print("new_state_dict:\n{}".format(new_state_dict))

# load 新的模型字典
net.load_state_dict(new_state_dict)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

2.3、断点重训

# 每隔 checkpoint_interval 个 epoch 保存一下模型参数,优化器参数及 epoch
if (epoch + 1) % checkpoint_interval == 0:
    checkpoint = {"model_state_dict": net.state_dict(),
                  "optimizer_state_dict": optimizer.state_dict(),
                  "epoch": epoch
                  "loss": loss,
                  ...}
    path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
    torch.save(checkpoint, path_checkpoint)


# 断点重新训练
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
scheduler.last_epoch = start_epoch  # 需要将学习率下降策略中的 epoch 设置一下
loss = checkpoint['loss']


# 从 start_epoch 开始继续进行训练
for epoch in range(start_epoch + 1, MAX_EPOCH):
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

三、PyTorch 模型微调

一个良好的权值初始化,可以使收敛速度加快,甚至可以获得更好的精度。在实际应用中,我们通常采用一个已经训练模型的模型的权值参数作为我们模型的初始化参数,也称之为 Finetune

3.1、所有层参数采用同一学习率

使用预训练模型进行参数初始化后,采用统一的学习率进行模型训练,optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)

3.2、固定预训练模型中相关层参数

使用预训练模型进行参数初始化后,冻结特征提取层的参数,可设置预训练模型中的参数层 requires_grad =False

path_pretrained_model = os.path.join(BASEDIR, "resnet18-5c106cde.pth")
state_dict_load = torch.load(path_pretrained_model)
resnet18_ft.load_state_dict(state_dict_load)

# 冻结预训练模型中的参数,不用进行参数更新
for param in resnet18_ft.parameters():
    param.requires_grad = False

# 替换 fc 层,替换后 fc 层的权重和偏置 requires_grad = True
num_ftrs = resnet18_ft.fc.in_features

# 默认采用 nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 进行初始化
resnet18_ft.fc = nn.Linear(num_ftrs, classes)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

3.3、优化器中设置不同的学习率

  • 不同层设置不同的学习率: 利用内存地址作为过滤条件,将需要单独设定的那部分参数,从总的参数中剔除(可以划分成两个、甚至更多的参数组,然后分别进行设置学习率)
# 将原始参数切分成 fc3 层参数和其余参数,为 fc3 层设置更大的学习率
# net.fc3.parameters() 是一个 generator,所以迭代返回其中的 parameter,这里有 weight 和 bias
# 最终返回 fc3.weight 和 fc3.bias 所在内存的地址
ignored_params = list(map(id, net.fc3.parameters()))   

# 将 fc3 层的参数从原始网络参数中剔除,是一个 list,每个元素是一个 Parameter 类
base_params = filter(lambda p: id(p) not in ignored_params, net.parameters())   

# 为 fc3 层设置需要的学习率
optimizer = optim.SGD([
    {'params': base_params, 'lr': lr_init*0},
    {'params': net.fc3.parameters(), 'lr': lr_init*10}],  lr_init, momentum=0.9, weight_decay=1e-4)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

四、PyTorch 模型推理

  • Inference 代码基本步骤:
    • 获取数据与模型
    • 数据变换及预处理,如 RGB → 4D-Tensor
    • 前向传播
    • 输出保存预测结果
  • Inference 阶段注意事项:
    • 确保 model 处于 eval 状态而非 training
    • 设置 torch.no_grad(),减少内存消耗
    • 数据预处理需保持一致, RGB o BGR?
import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# config
vis = True
vis_row = 4

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

inference_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

classes = ["ants", "bees"]


def img_transform(img_rgb, transform=None):
    """
    将数据转换为模型读取的形式
    :param img_rgb: PIL Image
    :param transform: torchvision.transform
    :return: tensor
    """
    if transform is None:
        raise ValueError("找不到transform!必须有transform对img进行处理")

    img_t = transform(img_rgb)
    return img_t


def get_img_name(img_dir, format="jpg"):
    """
    获取文件夹下format格式的文件名
    :param img_dir: str
    :param format: str
    :return: list
    """
    file_names = os.listdir(img_dir)
    img_names = list(filter(lambda x: x.endswith(format), file_names))

    if len(img_names) < 1:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return img_names


def get_model(m_path, vis_model=False):
    resnet18 = models.resnet18()
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, 2)

    checkpoint = torch.load(m_path)
    resnet18.load_state_dict(checkpoint['model_state_dict'])

    if vis_model:
        from torchsummary import summary
        summary(resnet18, input_size=(3, 224, 224), device="cpu")

    return resnet18


if __name__ == "__main__":
    img_dir = os.path.join("..", "data/hymenoptera_data/val/bees")
    model_path = "./checkpoint_14_epoch.pkl"
    time_total = 0
    img_list, img_pred = list(), list()

    # 1. data
    img_names = get_img_name(img_dir)
    num_img = len(img_names)

    # 2. model
    resnet18 = get_model(model_path, True)
    resnet18.to(device)
    resnet18.eval()

    with torch.no_grad():
        for idx, img_name in enumerate(img_names):
            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img
            img_rgb = Image.open(path_img).convert('RGB')

            # step 2/4 : img --> tensor
            img_tensor = img_transform(img_rgb, inference_transform)
            img_tensor.unsqueeze_(0)  # 3 维到 4 维
            img_tensor = img_tensor.to(device)

            # step 3/4 : tensor --> vector
            time_tic = time.time()
            outputs = resnet18(img_tensor)  # 推理
            time_toc = time.time()

            # step 4/4 : visualization
            _, pred_int = torch.max(outputs.data, 1)
            pred_str = classes[int(pred_int)]

            if vis:
                img_list.append(img_rgb)
                img_pred.append(pred_str)

                if (idx+1) % (vis_row*vis_row) == 0 or num_img == idx+1:
                    for i in range(len(img_list)):
                        plt.subplot(vis_row, vis_row, i+1).imshow(img_list[i])
                        plt.title("predict:{}".format(img_pred[i]))
                    plt.show()
                    plt.close()
                    img_list, img_pred = list(), list()

            time_s = time_toc-time_tic
            time_total += time_s

            print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

    print("\ndevice:{} total time:{:.1f}s mean:{:.3f}s".
          format(device, time_total, time_total/num_img))
    if torch.cuda.is_available():
        print("GPU name:{}".format(torch.cuda.get_device_name()))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129

五、参考资料

1、https://pytorch.org/docs/stable/nn.init.html
2、PyTorch 学习笔记(五):Finetune和各层定制学习率

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/神奇cpp/article/detail/998538
推荐阅读
相关标签
  

闽ICP备14008679号