
What is state_dict in PyTorch, and how do you load partial weights?


In PyTorch, a model's learnable parameters, such as weights and biases, live in its parameters (model.parameters()), and the state_dict is simply a Python dictionary that maps each layer's name to its parameter tensors.

Since a state_dict is a dictionary, it can be saved, updated, and loaded like any other dictionary. Note that only layers with learnable parameters, plus buffers registered via register_buffer (which are not updated by the optimizer but are saved with the model, e.g. BatchNorm's running statistics), appear in the model's state_dict. The optimizer also has a state_dict of its own.
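As a minimal sketch of the save/load round trip (assuming `model` is an existing nn.Module; `weights.pth` is just a placeholder path):

```python
import torch

# Save only the parameter/buffer dictionary, not the whole pickled module
torch.save(model.state_dict(), "weights.pth")

# Restore: construct the model first, then load the dictionary into it
model.load_state_dict(torch.load("weights.pth", map_location="cpu"))
model.eval()  # put BatchNorm/Dropout into inference mode before evaluating
```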

Following the official example, we create a small network:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F  # note: torch.functional does not provide relu

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
```

Now we print the model's state_dict and the optimizer's state_dict:

① First we iterate over net.state_dict(), which holds the weights and biases of every layer. Because state_dict is a dictionary, net.state_dict()[param_tensor] looks up each tensor by its key, and .size() gives its shape. (conv1.weight has shape (6, 3, 5, 5): 6 output channels, 3 input channels, and a 5x5 kernel.)

② Then we iterate over the optimizer's state_dict, which contains the optimizer's state and its hyperparameters.

```python
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in net.state_dict():
    print(param_tensor, "\t", net.state_dict()[param_tensor].size())
print()

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for key, value in optimizer.state_dict().items():
    print(key, "\t", value)
```
```
Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
fc2.weight       torch.Size([84, 120])
fc2.bias         torch.Size([84])
fc3.weight       torch.Size([10, 84])
fc3.bias         torch.Size([10])

Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
```
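Because the optimizer keeps its own state (e.g. SGD momentum buffers) in its state_dict, a common pattern, sketched here with a placeholder file name, is to save both dictionaries together so training can be resumed exactly:

```python
# Save model and optimizer state together
torch.save({
    "model": net.state_dict(),
    "optimizer": optimizer.state_dict(),
}, "checkpoint.pth")

# Resume later
ckpt = torch.load("checkpoint.pth", map_location="cpu")
net.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
```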

A typical use of state_dict: loading pretrained weights before training. These usually come as a .pth file, which is just a saved dictionary. Here we load a set of ResNet-34 weights:

```python
import torch

pthfile = r'D:/AI预训练权重/train_r34_NBt1D.pth'  # faster_rcnn_ckpt.pth
net = torch.load(pthfile, map_location=torch.device('cpu'))
print(type(net))
print(len(net))
for k in net.keys():
    print(k)
print(net["state_dict"])
for key, value in net["state_dict"].items():
    print(key, value, sep=" ")
```
```
<class 'dict'>
1
state_dict
OrderedDict([('encoder.conv1.weight', tensor([[[[ 1.5516e-02,  5.2939e-03,  2.8082e-03,  ..., -6.3492e-02,
          -9.9119e-03,  6.2728e-02],
         [ 7.7608e-03,  4.9472e-02,  5.4932e-02,  ..., -1.7819e-01,
          -1.2713e-01,  5.3156e-03],
         [-9.4686e-03,  4.9467e-02,  2.2223e-01,  ..., -1.0941e-01]),
('encoder.bn1.weight', tensor([4.1093e-01, 4.1710e-01, 4.3806e-08, 2.7257e-01, 3.0985e-01, 4.4599e-01,
        3.2788e-01, 3.9957e-01, 3.8334e-01, 6.7823e-07, 7.3982e-01, 1.5724e-01,
...
```

So the loaded object is a dict with a single key, state_dict, whose value is an OrderedDict mapping each layer name to its weight tensor.
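Since some checkpoints store the OrderedDict directly while this one nests it under a "state_dict" key, a small defensive unwrap (a sketch, reusing pthfile from above) is handy:

```python
ckpt = torch.load(pthfile, map_location=torch.device('cpu'))
# Unwrap if the weights are nested under a 'state_dict' key
pre_weights = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
```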

So how do we load only part of the weights?

First, let's look at the keys in the dictionary.

In resnet34, conv1 is the 7x7 stem convolution; its keys cover the convolution weight plus the BatchNorm parameters and buffers.

```
conv1.weight
bn1.running_mean
bn1.running_var
bn1.weight
bn1.bias
```
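These five keys match the standard torchvision ResNet stem. A sketch of the corresponding modules shows why conv1 contributes only a weight (it is built with bias=False) while bn1 contributes both learnable parameters and registered buffers:

```python
import torch.nn as nn

# conv1 is created with bias=False, so only conv1.weight appears in the state_dict
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# bn1 has learnable weight/bias plus running_mean/running_var buffers
bn1 = nn.BatchNorm2d(64)
print(list(bn1.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
# (recent PyTorch versions also save a num_batches_tracked buffer)
```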

Next, layer1 contains three BasicBlocks, each with two 3x3, 64-channel convolutions.

```
layer1.0.conv1.weight
layer1.0.bn1.running_mean
layer1.0.bn1.running_var
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.conv2.weight
layer1.0.bn2.running_mean
layer1.0.bn2.running_var
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.1.conv1.weight
layer1.1.bn1.running_mean
layer1.1.bn1.running_var
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.conv2.weight
layer1.1.bn2.running_mean
layer1.1.bn2.running_var
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.2.conv1.weight
layer1.2.bn1.running_mean
layer1.2.bn1.running_var
layer1.2.bn1.weight
layer1.2.bn1.bias
layer1.2.conv2.weight
layer1.2.bn2.running_mean
layer1.2.bn2.running_var
layer1.2.bn2.weight
layer1.2.bn2.bias
```

layer2 contains four blocks, each with two 3x3, 128-channel convolutions. Note the extra downsample keys in the first block: BasicBlock downsamples its shortcut branch with a 1x1, 128-channel convolution followed by BN; see the sketch after this list.

```
layer2.0.conv1.weight
layer2.0.bn1.running_mean
layer2.0.bn1.running_var
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.conv2.weight
layer2.0.bn2.running_mean
layer2.0.bn2.running_var
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.downsample.0.weight
layer2.0.downsample.1.running_mean
layer2.0.downsample.1.running_var
layer2.0.downsample.1.weight
layer2.0.downsample.1.bias
layer2.1.conv1.weight
layer2.1.bn1.running_mean
layer2.1.bn1.running_var
layer2.1.bn1.weight
layer2.1.bn1.bias
layer2.1.conv2.weight
layer2.1.bn2.running_mean
layer2.1.bn2.running_var
layer2.1.bn2.weight
layer2.1.bn2.bias
layer2.2.conv1.weight
layer2.2.bn1.running_mean
layer2.2.bn1.running_var
layer2.2.bn1.weight
layer2.2.bn1.bias
layer2.2.conv2.weight
layer2.2.bn2.running_mean
layer2.2.bn2.running_var
layer2.2.bn2.weight
layer2.2.bn2.bias
layer2.3.conv1.weight
layer2.3.bn1.running_mean
layer2.3.bn1.running_var
layer2.3.bn1.weight
layer2.3.bn1.bias
layer2.3.conv2.weight
layer2.3.bn2.running_mean
layer2.3.bn2.running_var
layer2.3.bn2.weight
layer2.3.bn2.bias
```
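The downsample.0/downsample.1 keys above come from the shortcut branch of layer2's first block, which in torchvision's BasicBlock is an nn.Sequential; a sketch of its structure (channel counts as in resnet34):

```python
import torch.nn as nn

# Sequential index 0 -> layer2.0.downsample.0.weight (1x1 conv, bias=False)
# Sequential index 1 -> layer2.0.downsample.1.* (BatchNorm parameters/buffers)
downsample = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)
```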

layer3 contains six blocks, each with two 3x3, 256-channel convolutions.

```
layer3.0.conv1.weight
layer3.0.bn1.running_mean
layer3.0.bn1.running_var
layer3.0.bn1.weight
layer3.0.bn1.bias
layer3.0.conv2.weight
layer3.0.bn2.running_mean
layer3.0.bn2.running_var
layer3.0.bn2.weight
layer3.0.bn2.bias
layer3.0.downsample.0.weight
layer3.0.downsample.1.running_mean
layer3.0.downsample.1.running_var
layer3.0.downsample.1.weight
layer3.0.downsample.1.bias
layer3.1.conv1.weight
layer3.1.bn1.running_mean
layer3.1.bn1.running_var
layer3.1.bn1.weight
layer3.1.bn1.bias
layer3.1.conv2.weight
layer3.1.bn2.running_mean
layer3.1.bn2.running_var
layer3.1.bn2.weight
layer3.1.bn2.bias
layer3.2.conv1.weight
layer3.2.bn1.running_mean
layer3.2.bn1.running_var
layer3.2.bn1.weight
layer3.2.bn1.bias
layer3.2.conv2.weight
layer3.2.bn2.running_mean
layer3.2.bn2.running_var
layer3.2.bn2.weight
layer3.2.bn2.bias
layer3.3.conv1.weight
layer3.3.bn1.running_mean
layer3.3.bn1.running_var
layer3.3.bn1.weight
layer3.3.bn1.bias
layer3.3.conv2.weight
layer3.3.bn2.running_mean
layer3.3.bn2.running_var
layer3.3.bn2.weight
layer3.3.bn2.bias
layer3.4.conv1.weight
layer3.4.bn1.running_mean
layer3.4.bn1.running_var
layer3.4.bn1.weight
layer3.4.bn1.bias
layer3.4.conv2.weight
layer3.4.bn2.running_mean
layer3.4.bn2.running_var
layer3.4.bn2.weight
layer3.4.bn2.bias
layer3.5.conv1.weight
layer3.5.bn1.running_mean
layer3.5.bn1.running_var
layer3.5.bn1.weight
layer3.5.bn1.bias
layer3.5.conv2.weight
layer3.5.bn2.running_mean
layer3.5.bn2.running_var
layer3.5.bn2.weight
layer3.5.bn2.bias
```

layer4 contains three blocks, each with two 3x3, 512-channel convolutions, followed by the final fully connected layer.

```
layer4.0.conv1.weight
layer4.0.bn1.running_mean
layer4.0.bn1.running_var
layer4.0.bn1.weight
layer4.0.bn1.bias
layer4.0.conv2.weight
layer4.0.bn2.running_mean
layer4.0.bn2.running_var
layer4.0.bn2.weight
layer4.0.bn2.bias
layer4.0.downsample.0.weight
layer4.0.downsample.1.running_mean
layer4.0.downsample.1.running_var
layer4.0.downsample.1.weight
layer4.0.downsample.1.bias
layer4.1.conv1.weight
layer4.1.bn1.running_mean
layer4.1.bn1.running_var
layer4.1.bn1.weight
layer4.1.bn1.bias
layer4.1.conv2.weight
layer4.1.bn2.running_mean
layer4.1.bn2.running_var
layer4.1.bn2.weight
layer4.1.bn2.bias
layer4.2.conv1.weight
layer4.2.bn1.running_mean
layer4.2.bn1.running_var
layer4.2.bn1.weight
layer4.2.bn1.bias
layer4.2.conv2.weight
layer4.2.bn2.running_mean
layer4.2.bn2.running_var
layer4.2.bn2.weight
layer4.2.bn2.bias
fc.weight
fc.bias
```

After loading the weights, we iterate over them: whenever a key contains "fc" we add it to a list (del_key), then delete those entries from the weight dictionary, and finally load what remains into our own network with strict=False. Code adapted from the Bilibili creator 霹雳吧啦Wz.

```python
# pre_weights is the pretrained OrderedDict; net is our own model (an nn.Module)
del_key = []
for key, _ in pre_weights.items():
    print(key)
    if "fc" in key:
        del_key.append(key)
print(del_key)

for key in del_key:
    del pre_weights[key]
print(pre_weights)

missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
```

Printing again, we see that all fc-layer entries are gone; the last remaining entry is layer4.2.bn2.bias:

```
('layer4.2.bn2.bias', Parameter containing:
tensor([ 0.1216,  0.1289,  0.1926,  0.1332,  0.0978,  0.1507,  0.1391,  0.1332,
...
```
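A slightly more general variant of the same idea (a sketch, assuming pre_weights and net from above) keeps only the entries whose name and shape both match the target model, which also covers layers that were resized rather than renamed:

```python
model_dict = net.state_dict()
# Keep a pretrained entry only if the model has a key of the same name and shape
filtered = {k: v for k, v in pre_weights.items()
            if k in model_dict and v.shape == model_dict[k].shape}
missing_keys, unexpected_keys = net.load_state_dict(filtered, strict=False)
print("missing:", missing_keys)        # keys the model has but the file lacks
print("unexpected:", unexpected_keys)  # keys the file has but the model lacks
```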
