当前位置:   article > 正文

state_dict使用详解_.state_dict()

.state_dict()

     在PyTorch中,state_dict是一个非常重要的概念,它是一个包含模型参数的字典对象。每个模型的state_dict都包含了该模型的所有参数(权重和偏置等),用于在训练和推理过程中重现模型的内部状态.

      pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如 model的每一层的weights及偏置等等) (注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等) 优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

1. 保存模型参数

        使用torch.save(model.state_dict(), PATH)可以将state_dict保存到指定路径. 常用的保存 state_dict的格式是".pt"或’.pth’的文件,即下面命令的 PATH="./***.pt". 但是文件名字不影响,只是大家大家默认这个名字有辨识度,你取***.sp照样不影响.

torch.save(model.state_dicr(),PATH)  # PATH为存储的位置例如: path/best.pth

2.初始化模型

       即初始化模型的参数, 使用model.load_state_dict(torch.load(PATH))可以重新加载模型。

  1. modle = MyModel(*args, **kwargs)
  2. model.load_state_dict(torch.load(PATH)

3.取出或更新某一层参数

       前面说了state_dict()中的参数是按字典存取,即每个层都有一个key值索引, 所以按照字典规则取出该值即可. 现在假设某层的名字为 conv1.weight.

weight_data = torch.load('./model_state_dict.pt')['conv1.weight']

        修改某一层的值

  1. # 假设 model 是一个已经初始化的模型
  2. # 更改第一层的权重
  3. model.state_dict()['layer1.weight'] = torch.randn(10, 10)

     在训练过程中,state_dict还用于存储梯度信息。在反向传播过程中,PyTorch会通过state_dict来更新模型参数.

4.控制model的某层是否需要梯度求导

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

  1. for param in list(mode.pretrained.parameters()):
  2. param.requires_grad = True

5.手写网络层及state_dict()使用例子

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torchvision
  5. import numpy as mp
  6. import matplotlib.pyplot as plt
  7. import torch.nn.functional as F
  8. #define model
  9. class TheModelClass(nn.Module):
  10. def __init__(self):
  11. super(TheModelClass,self).__init__()
  12. self.conv1=nn.Conv2d(3,6,5)
  13. self.pool=nn.MaxPool2d(2,2)
  14. self.conv2=nn.Conv2d(6,16,5)
  15. self.fc1=nn.Linear(16*5*5,120)
  16. self.fc2=nn.Linear(120,84)
  17. self.fc3=nn.Linear(84,10)
  18. def forward(self,x):
  19. x=self.pool(F.relu(self.conv1(x)))
  20. x=self.pool(F.relu(self.conv2(x)))
  21. x=x.view(-1,16*5*5)
  22. x=F.relu(self.fc1(x))
  23. x=F.relu(self.fc2(x))
  24. x=self.fc3(x)
  25. return x
  26. def main():
  27. # Initialize model
  28. model = TheModelClass()
  29. #Initialize optimizer
  30. optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
  31. #print model's state_dict
  32. print('Model.state_dict:')
  33. for param_tensor in model.state_dict():
  34. #打印 key value字典
  35. print(param_tensor,'\t',model.state_dict()[param_tensor].size())
  36. #print optimizer's state_dict
  37. print('Optimizer,s state_dict:')
  38. for var_name in optimizer.state_dict():
  39. print(var_name,'\t',optimizer.state_dict()[var_name])
  40. if __name__=='__main__':
  41. main()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/103404
推荐阅读
  

闽ICP备14008679号