当前位置:   article > 正文

pytorch状态字典state_dict, load_state_dict torch.load 以及eval,作用,保存和加载的使用_torch 字典

torch 字典

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

 

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)
 
 
  • 1

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用


 
 
  1. model = TheModelClass(*args, **kwargs)
  2. model.load_state_dict(torch.load(PATH))
  3. model.eval()
  • 1

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

-------------------------------------------------------------------------------------------------------------------------------

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

    torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()

    备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

-----------

2.1)保存整个model的状态,用以下命令

    torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

          # Model class must be defined somewhere

    model = torch.load(PATH)

    model.eval()

--------------------------------------------------------------------------------------------------------------------------------------

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

----------------------------------------------------------------------------------------------------------------------

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
 
 
  • 1

--------------------------------------------------------------------------------------------

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)


 
 
  1. for param in list(model.pretrained.parameters()):
  2. param.requires_grad = False
  • 1

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

 

---------------------------------------------------------------------------------------------

全部测试代码:


 
 
  1. #-*-coding:utf-8-*-
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import torch.optim as optim
  6. # define model
  7. class TheModelClass(nn.Module):
  8. def __init__(self):
  9. super(TheModelClass,self).__init__()
  10. self.conv1 = nn.Conv2d( 3, 6, 5)
  11. self.pool = nn.MaxPool2d( 2, 2)
  12. self.conv2 = nn.Conv2d( 6, 16, 5)
  13. self.fc1 = nn.Linear( 16* 5* 5, 120)
  14. self.fc2 = nn.Linear( 120, 84)
  15. self.fc3 = nn.Linear( 84, 10)
  16. def forward(self,x):
  17. x = self.pool(F.relu(self.conv1(x)))
  18. x = self.pool(F.relu(self.conv2(x)))
  19. x = x.view( -1, 16* 5* 5)
  20. x = F.relu(self.fc1(x))
  21. x = F.relu(self.fc2(x))
  22. x = self.fc3(x)
  23. return x
  24. # initial model
  25. model = TheModelClass()
  26. #initialize the optimizer
  27. optimizer = optim.SGD(model.parameters(),lr= 0.001,momentum= 0.9)
  28. # print the model's state_dict
  29. print( "model's state_dict:")
  30. for param_tensor in model.state_dict():
  31. print(param_tensor, '\t',model.state_dict()[param_tensor].size())
  32. print( "\noptimizer's state_dict")
  33. for var_name in optimizer.state_dict():
  34. print(var_name, '\t',optimizer.state_dict()[var_name])
  35. print( "\nprint particular param")
  36. print( '\n',model.conv1.weight.size())
  37. print( '\n',model.conv1.weight)
  38. print( "------------------------------------")
  39. torch.save(model.state_dict(), './model_state_dict.pt')
  40. # model_2 = TheModelClass()
  41. # model_2.load_state_dict(torch.load('./model_state_dict'))
  42. # model.eval()
  43. # print('\n',model_2.conv1.weight)
  44. # print((model_2.conv1.weight == model.conv1.weight).size())
  45. ## 仅仅加载某一层的参数
  46. conv1_weight_state = torch.load( './model_state_dict.pt')[ 'conv1.weight']
  47. print(conv1_weight_state==model.conv1.weight)
  48. model_2 = TheModelClass()
  49. model_2.load_state_dict(torch.load( './model_state_dict.pt'))
  50. model_2.conv1.requires_grad= False
  51. print(model_2.conv1.requires_grad)
  52. print(model_2.conv1.bias.requires_grad)
  • 1

 

 

 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/一键难忘520/article/detail/785875
推荐阅读
相关标签
  

闽ICP备14008679号