In PyTorch, the learnable parameters such as weights and biases live in the model's parameters (model.parameters()), and state_dict is simply a dictionary that maps each layer to its parameter tensors.
Since state_dict is a dictionary, it can be saved, updated, and loaded like any other dict. Note that only layers with learnable parameters, together with buffers created via register_buffer (which are not updated during training but are saved with the model), have entries in the model's state_dict. The optimizer also has its own state_dict.
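As a minimal sketch (the module name TinyNet and the buffer name running_stat are made up for illustration), the snippet below shows that both learnable parameters and registered buffers appear in state_dict(), and that the dictionary can be saved with torch.save and loaded back with load_state_dict:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)                             # learnable weight and bias
        self.register_buffer("running_stat", torch.zeros(2))  # saved with the model, never trained

    def forward(self, x):
        return self.fc(x) + self.running_stat

tiny = TinyNet()
print(list(tiny.state_dict().keys()))  # contains 'fc.weight', 'fc.bias' and 'running_stat'

torch.save(tiny.state_dict(), "tiny_net.pth")     # save the dict
tiny.load_state_dict(torch.load("tiny_net.pth"))  # load it back
```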
Following the official example code, we create a network:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
```
Now let's print the model's state_dict:
① First we iterate over the model's state_dict, which holds the weights and biases of every layer. Since state_dict is a dictionary, net.state_dict()[param_tensor].size() uses the key to look up the corresponding tensor. (The conv1 weight has shape (6, 3, 5, 5): 6 output channels, 3 input channels, and a 5x5 kernel.)
② Then we iterate over the optimizer's state_dict, which contains the optimizer's state and its hyperparameters.
```python
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in net.state_dict():
    print(param_tensor, "\t", net.state_dict()[param_tensor].size())

print()

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for key, value in optimizer.state_dict().items():
    print(key, "\t", value)
```
```
Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
fc2.weight       torch.Size([84, 120])
fc2.bias         torch.Size([84])
fc3.weight       torch.Size([10, 84])
fc3.bias         torch.Size([10])

Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
```
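Since both are plain dictionaries of tensors, a common pattern is to save them together as a checkpoint and restore them later. The snippet below is only a sketch; the file name checkpoint.pth and the key names are arbitrary choices:

```python
# Save the model's and optimizer's state_dicts together (sketch; names are arbitrary)
torch.save({
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, "checkpoint.pth")

# Restore them later
checkpoint = torch.load("checkpoint.pth", map_location="cpu")
net.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
```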
Application scenarios of state_dict:
Before training we usually start from pretrained weights, typically a .pth file that holds a dictionary. Here we load resnet34 weights:
```python
import torch

pthfile = r'D:/AI预训练权重/train_r34_NBt1D.pth'  # faster_rcnn_ckpt.pth
net = torch.load(pthfile, map_location=torch.device('cpu'))
print(type(net))
print(len(net))
for k in net.keys():
    print(k)
print(net["state_dict"])
for key, value in net["state_dict"].items():
    print(key, value, sep=" ")
```
```
<class 'dict'>
1
state_dict
OrderedDict([('encoder.conv1.weight', tensor([[[[ 1.5516e-02,  5.2939e-03,  2.8082e-03, ..., -6.3492e-02,
           -9.9119e-03,  6.2728e-02],
         [ 7.7608e-03,  4.9472e-02,  5.4932e-02, ..., -1.7819e-01,
          -1.2713e-01,  5.3156e-03],
         [-9.4686e-03,  4.9467e-02,  2.2223e-01, ..., -1.0941e-01]),
 ('encoder.bn1.weight', tensor([4.1093e-01, 4.1710e-01, 4.3806e-08, 2.7257e-01, 3.0985e-01, 4.4599e-01,
         3.2788e-01, 3.9957e-01, 3.8334e-01, 6.7823e-07, 7.3982e-01, 1.5724e-01,
         ...
```
The result: the file's state_dict is an OrderedDict that maps every layer name to its weights.
So how do we load only part of the weights?
Let's print the keys in this dictionary.
In resnet34, the keys for conv1 (the 7x7 convolution layer) cover the convolution weight plus the BatchNorm parameters and statistics:
- conv1.weight
- bn1.running_mean
- bn1.running_var
- bn1.weight
- bn1.bias
Next, layer1 contains three blocks, each with two 3x3, 64-channel convolutions:
- layer1.0.conv1.weight
- layer1.0.bn1.running_mean
- layer1.0.bn1.running_var
- layer1.0.bn1.weight
- layer1.0.bn1.bias
- layer1.0.conv2.weight
- layer1.0.bn2.running_mean
- layer1.0.bn2.running_var
- layer1.0.bn2.weight
- layer1.0.bn2.bias
- layer1.1.conv1.weight
- layer1.1.bn1.running_mean
- layer1.1.bn1.running_var
- layer1.1.bn1.weight
- layer1.1.bn1.bias
- layer1.1.conv2.weight
- layer1.1.bn2.running_mean
- layer1.1.bn2.running_var
- layer1.1.bn2.weight
- layer1.1.bn2.bias
- layer1.2.conv1.weight
- layer1.2.bn1.running_mean
- layer1.2.bn1.running_var
- layer1.2.bn1.weight
- layer1.2.bn1.bias
- layer1.2.conv2.weight
- layer1.2.bn2.running_mean
- layer1.2.bn2.running_var
- layer1.2.bn2.weight
- layer1.2.bn2.bias
layer2 contains four blocks, each with two 3x3, 128-channel convolutions. Note the downsample entries here: the first BasicBlock in layer2 downsamples its shortcut with a 1x1, 128-channel convolution (followed by BatchNorm):
- layer2.0.conv1.weight
- layer2.0.bn1.running_mean
- layer2.0.bn1.running_var
- layer2.0.bn1.weight
- layer2.0.bn1.bias
- layer2.0.conv2.weight
- layer2.0.bn2.running_mean
- layer2.0.bn2.running_var
- layer2.0.bn2.weight
- layer2.0.bn2.bias
- layer2.0.downsample.0.weight
- layer2.0.downsample.1.running_mean
- layer2.0.downsample.1.running_var
- layer2.0.downsample.1.weight
- layer2.0.downsample.1.bias
- layer2.1.conv1.weight
- layer2.1.bn1.running_mean
- layer2.1.bn1.running_var
- layer2.1.bn1.weight
- layer2.1.bn1.bias
- layer2.1.conv2.weight
- layer2.1.bn2.running_mean
- layer2.1.bn2.running_var
- layer2.1.bn2.weight
- layer2.1.bn2.bias
- layer2.2.conv1.weight
- layer2.2.bn1.running_mean
- layer2.2.bn1.running_var
- layer2.2.bn1.weight
- layer2.2.bn1.bias
- layer2.2.conv2.weight
- layer2.2.bn2.running_mean
- layer2.2.bn2.running_var
- layer2.2.bn2.weight
- layer2.2.bn2.bias
- layer2.3.conv1.weight
- layer2.3.bn1.running_mean
- layer2.3.bn1.running_var
- layer2.3.bn1.weight
- layer2.3.bn1.bias
- layer2.3.conv2.weight
- layer2.3.bn2.running_mean
- layer2.3.bn2.running_var
- layer2.3.bn2.weight
- layer2.3.bn2.bias
layer3 contains six blocks, each with two 3x3, 256-channel convolutions:
- layer3.0.conv1.weight
- layer3.0.bn1.running_mean
- layer3.0.bn1.running_var
- layer3.0.bn1.weight
- layer3.0.bn1.bias
- layer3.0.conv2.weight
- layer3.0.bn2.running_mean
- layer3.0.bn2.running_var
- layer3.0.bn2.weight
- layer3.0.bn2.bias
- layer3.0.downsample.0.weight
- layer3.0.downsample.1.running_mean
- layer3.0.downsample.1.running_var
- layer3.0.downsample.1.weight
- layer3.0.downsample.1.bias
- layer3.1.conv1.weight
- layer3.1.bn1.running_mean
- layer3.1.bn1.running_var
- layer3.1.bn1.weight
- layer3.1.bn1.bias
- layer3.1.conv2.weight
- layer3.1.bn2.running_mean
- layer3.1.bn2.running_var
- layer3.1.bn2.weight
- layer3.1.bn2.bias
- layer3.2.conv1.weight
- layer3.2.bn1.running_mean
- layer3.2.bn1.running_var
- layer3.2.bn1.weight
- layer3.2.bn1.bias
- layer3.2.conv2.weight
- layer3.2.bn2.running_mean
- layer3.2.bn2.running_var
- layer3.2.bn2.weight
- layer3.2.bn2.bias
- layer3.3.conv1.weight
- layer3.3.bn1.running_mean
- layer3.3.bn1.running_var
- layer3.3.bn1.weight
- layer3.3.bn1.bias
- layer3.3.conv2.weight
- layer3.3.bn2.running_mean
- layer3.3.bn2.running_var
- layer3.3.bn2.weight
- layer3.3.bn2.bias
- layer3.4.conv1.weight
- layer3.4.bn1.running_mean
- layer3.4.bn1.running_var
- layer3.4.bn1.weight
- layer3.4.bn1.bias
- layer3.4.conv2.weight
- layer3.4.bn2.running_mean
- layer3.4.bn2.running_var
- layer3.4.bn2.weight
- layer3.4.bn2.bias
- layer3.5.conv1.weight
- layer3.5.bn1.running_mean
- layer3.5.bn1.running_var
- layer3.5.bn1.weight
- layer3.5.bn1.bias
- layer3.5.conv2.weight
- layer3.5.bn2.running_mean
- layer3.5.bn2.running_var
- layer3.5.bn2.weight
- layer3.5.bn2.bias
layer4 contains three blocks, each with two 3x3, 512-channel convolutions, followed by the final fully connected layer:
- layer4.0.conv1.weight
- layer4.0.bn1.running_mean
- layer4.0.bn1.running_var
- layer4.0.bn1.weight
- layer4.0.bn1.bias
- layer4.0.conv2.weight
- layer4.0.bn2.running_mean
- layer4.0.bn2.running_var
- layer4.0.bn2.weight
- layer4.0.bn2.bias
- layer4.0.downsample.0.weight
- layer4.0.downsample.1.running_mean
- layer4.0.downsample.1.running_var
- layer4.0.downsample.1.weight
- layer4.0.downsample.1.bias
- layer4.1.conv1.weight
- layer4.1.bn1.running_mean
- layer4.1.bn1.running_var
- layer4.1.bn1.weight
- layer4.1.bn1.bias
- layer4.1.conv2.weight
- layer4.1.bn2.running_mean
- layer4.1.bn2.running_var
- layer4.1.bn2.weight
- layer4.1.bn2.bias
- layer4.2.conv1.weight
- layer4.2.bn1.running_mean
- layer4.2.bn1.running_var
- layer4.2.bn1.weight
- layer4.2.bn1.bias
- layer4.2.conv2.weight
- layer4.2.bn2.running_mean
- layer4.2.bn2.running_var
- layer4.2.bn2.weight
- layer4.2.bn2.bias
- fc.weight
- fc.bias
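To cross-check these key names against the architecture itself, one option (a sketch, assuming torchvision is installed) is to instantiate resnet34 from torchvision and print the keys and shapes of its own state_dict, which follow the same conv1/bn1/layerX.Y/downsample naming shown above:

```python
import torchvision.models as models

resnet34 = models.resnet34()  # randomly initialized; used here only to inspect the structure
for name, tensor in resnet34.state_dict().items():
    print(name, tuple(tensor.shape))
```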
After loading the pretrained weights, we iterate over them: whenever a key contains "fc" we append it to a list (del_key), then delete those entries from the pretrained weights, and finally load the remaining weights into our own model with strict=False. The code is adapted from the Bilibili creator 霹雳吧啦.
```python
# pre_weights is the pretrained state_dict loaded above; net is our own model
del_key = []
for key, _ in pre_weights.items():
    print(key)
    if "fc" in key:
        del_key.append(key)
print(del_key)

for key in del_key:
    del pre_weights[key]
print(pre_weights)

missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
```
Printing the weights again, we find that all of the fc-layer entries have been removed; the last remaining keys belong to layer4.2. With strict=False, load_state_dict also returns missing_keys (keys our model expects but the dict does not provide) and unexpected_keys (keys in the dict that our model does not have).
```
('layer4.2.bn2.bias', Parameter containing:
tensor([ 0.1216,  0.1289,  0.1926,  0.1332,  0.0978,  0.1507,  0.1391,  0.1332,
        ...
```
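As an alternative to deleting keys by name, a possible sketch is to keep only the pretrained entries whose key and shape both match our own model; this avoids shape-mismatch errors when, for example, the classifier head has a different number of classes (model_dict below is just a local variable name introduced for illustration):

```python
# Sketch: keep only pretrained tensors whose key and shape match our model
model_dict = net.state_dict()
pre_dict = {k: v for k, v in pre_weights.items()
            if k in model_dict and v.shape == model_dict[k].shape}
missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)
```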