当前位置:   article > 正文

model.parameters(),model.state_dict(),model .load_state_dict()以及torch.load()_model.load_state_dict(torch.load(pth))

model.load_state_dict(torch.load(pth))

一.model.parameters()与model.state_dict()

model.parameters()model.state_dict()都Pytorch中用于查看网络参数的方法

一般来说,前者多见于优化器的初始化,例如:

后者多见于模型的保存,如:

当我们对网络调参或者查看网络的参数是否具有可复现性时,可能会查看网络的参数

  1. pretrained_dict = torch.load(yolov4conv137weight)
  2. model_dict = _model.state_dict() #查看模型的权重和biass系数
  3. pretrained_dict = {k1: v for (k, v), k1 in zip(pretrained_dict.items(), model_dict)}
  4. model_dict.update(pretrained_dict) #更新model网络模型的参数的权值和biass,这相当于是一个浅拷贝,对这个更新改变会更改模型的权重和biass

model.state_dict()其实返回的是一个OrderDict,存储了网络结构的名字和对应的参数。

例子:

  1. #encoding:utf-8
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. import torchvision
  6. import numpy as mp
  7. import matplotlib.pyplot as plt
  8. import torch.nn.functional as F
  9. #define model
  10. class TheModelClass(nn.Module):
  11. def __init__(self):
  12. super(TheModelClass,self).__init__()
  13. self.conv1=nn.Conv2d(3,6,5)
  14. self.pool=nn.MaxPool2d(2,2)
  15. self.conv2=nn.Conv2d(6,16,5)
  16. self.fc1=nn.Linear(16*5*5,120)
  17. self.fc2=nn.Linear(120,84)
  18. self.fc3=nn.Linear(84,10)
  19. def forward(self,x):
  20. x=self.pool(F.relu(self.conv1(x)))
  21. x=self.pool(F.relu(self.conv2(x)))
  22. x=x.view(-1,16*5*5)
  23. x=F.relu(self.fc1(x))
  24. x=F.relu(self.fc2(x))
  25. x=self.fc3(x)
  26. return x
  27. def main():
  28. # Initialize model
  29. model = TheModelClass()
  30. #Initialize optimizer
  31. optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
  32. #print model's state_dict
  33. print('Model.state_dict:')
  34. for param_tensor in model.state_dict():
  35. #打印 key value字典
  36. print(param_tensor,'\t',model.state_dict()[param_tensor].size())
  37. #print optimizer's state_dict
  38. print('Optimizer,s state_dict:')
  39. for var_name in optimizer.state_dict():
  40. print(var_name,'\t',optimizer.state_dict()[var_name])
  41. if __name__=='__main__':
  42. main()

具体的输出结果如下:可以很清晰的观测到state_dict中存放的key和value的值

  1. Model.state_dict:
  2. conv1.weight torch.Size([6, 3, 5, 5])
  3. conv1.bias torch.Size([6])
  4. conv2.weight torch.Size([16, 6, 5, 5])
  5. conv2.bias torch.Size([16])
  6. fc1.weight torch.Size([120, 400])
  7. fc1.bias torch.Size([120])
  8. fc2.weight torch.Size([84, 120])
  9. fc2.bias torch.Size([84])
  10. fc3.weight torch.Size([10, 84])
  11. fc3.bias torch.Size([10])
  12. Optimizer,s state_dict:
  13. state {}
  14. param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [367949288, 367949432, 376459056, 381121808, 381121952, 381122024, 381121880, 381122168, 381122096, 381122312]}]

二.torch.load()和load_state_dict()

load_state_dict(state_dict, strict=True)

从 state_dict 中复制参数和缓冲区到 Module 及其子类中 

state_dict:包含参数和缓冲区的 Module 状态字典

strict:默认 True,是否严格匹配 state_dict 的键值和 Module.state_dict()的键值
 

  1. model = nn.Sequential(self.down1, self.down2, self.down3, self.down4, self.down5, self.neek)
  2. pretrained_dict = torch.load(yolov4conv137weight) #加载已经训练好的模型参数
  3. model_dict = model.state_dict() #查看权重和偏重
  4. # 1. filter out unnecessary keys
  5. pretrained_dict = {k1: v for (k, v), k1 in zip(pretrained_dict.items(), model_dict)}
  6. # 2. overwrite entries in the existing state dict
  7. model_dict.update(pretrained_dict) #更新已有的模型的权重和偏重
  8. model.load_state_dict(model_dict) #将更新后的参数重新加载至网络模型中

官方推荐的方法,只保存和恢复模型中的参数

  1. # save
  2. torch.save(model.state_dict(), PATH)
  3. # load
  4. model = MyModel(*args, **kwargs)
  5. model.load_state_dict(torch.load(PATH))
  6. model.eval()

torch.load("path路径")表示加载已经训练好的模型

而model.load_state_dict(torch.load(PATH))表示将训练好的模型参数重新加载至网络模型中

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/103496
推荐阅读
相关标签
  

闽ICP备14008679号