当前位置:   article > 正文

Pytorch基础:torch.load_state_dict()方法在加载时不会检查类型

Pytorch基础:torch.load_state_dict()方法在加载时不会检查类型

相关阅读

Pytorch基础icon-default.png?t=N7T8https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


        笔者在使用torch.nn.module的load_state_dict中出现了一个问题,一个被注册的张量在加载后居然没有变化,一开始以为是加载出现了问题,但发现其他参数加载成功,思索后发现是注册的张量的类型是整型而checkpoint中保存为浮点数类型,恰好注册时的默认值给的是0,而checkpoint中的浮点数又在0到1之间,因此出现了这个令人困惑的bug。

        下面首先复现这个bug。

  1. import torch
  2. import torch.nn as nn
  3. # 定义一个简单的线性模型,参数类型为整数
  4. class SimpleModel(nn.Module):
  5. def __init__(self):
  6. super(SimpleModel, self).__init__()
  7. self.register_buffer('test', torch.tensor(0)) # 注册一个整型张量
  8. # 创建一个简单模型实例
  9. model = SimpleModel()
  10. # 创建一个浮点数作为参数
  11. float_parameter = torch.tensor(0.6)
  12. # 将注册名指向另一个浮点型张量
  13. model.test = float_parameter
  14. # 保存模型
  15. torch.save(model.state_dict(), 'model.pth')
  16. # 直接使用原模型加载
  17. checkpoint = torch.load('model.pth')
  18. model.load_state_dict(checkpoint)
  19. # 打印加载后的参数
  20. print(model.test)
  21. # 直接使用新模型加载
  22. model_1 = SimpleModel()
  23. model_1.load_state_dict(checkpoint)
  24. # 打印加载后的参数
  25. print(model_1.test)
  1. 输出:
  2. tensor(0.6000)
  3. tensor(0)

        可以看到,当模型中注册的名字(test),指向了一个类型不符的张量后,并不会导致浮点型张量被截断为整型,这是因为此处是直接使用赋值号=,使名字指向了另一个张量。

        但使用load_state_dict()方法与使用赋值号是不同的,load_state_dict()方法的实现中,调用了_load_from_state_dict()方法,其中调用了copy_()方法,进行了原位(in-place)数据替换,这可能会进行截断,下面是原位替换的一个例子。

  1. import torch
  2. # 创建两个张量
  3. a = torch.tensor([[1, 2], [3, 4]])
  4. b = torch.tensor([[5.1, 6.1], [7.1, 8.1]])
  5. # 查看张量对象的id
  6. print(id(a))
  7. print(id(b))
  8. # 查看底层存储的内存地址
  9. print(a.storage().data_ptr())
  10. print(b.storage().data_ptr())
  11. # 将张量 b 中的值复制到张量 a 中
  12. a.copy_(b)
  13. # 打印复制后的结果
  14. print(a)
  15. # 查看张量对象的id
  16. print(id(a))
  17. print(id(b))
  18. # 查看底层存储的内存地址
  19. print(a.storage().data_ptr())
  20. print(b.storage().data_ptr())
  1. 输出:
  2. 2604425272672
  3. 2604426953808
  4. 2604511348096
  5. 2602930352832
  6. tensor([[5, 6],
  7. [7, 8]])
  8. 2604425272672
  9. 2604426953808
  10. 2604511348096
  11. 2602930352832

        在保存了模型的状态字典后,使用load_state_dict()方法加载后,也不会有任何截断问题,因为对于原模型而言,名字test指向的是一个浮点型张量,此时原位替换,类型吻合。但是对于一个新的模型,此时的test指向的是一个整型张量,此时原位替换,会发生截断。

        因此,在注册一个张量时,需要确保其在注册时和保存时的类型吻合,此处除了指形状,还有类型,否则可能会出现意想不到的bug。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/739390
推荐阅读
相关标签
  

闽ICP备14008679号