赞
踩
没想到还有这样的坑。
在远程服务器使用PyTorch训练好模型后,使用torch.save(local_model, file_name, _use_new_zipfile_serialization=False)
将模型保存,原以为这种保存方法可以将模型结构和参数一起保存。但是本地torch.load
报错:ModuleNotFoundError: No module named '***'
(并没有显式导入模型结构)。
经过查阅,虽然是保存了模型结构,但仍然需要将“目录结构得和保存时一模一样”,具体来说就是模型结构的定义文件需要和训练时一致,其他部分文件不需要。
原因(转自pytorch加载模型遇到问题ModuleNotFoundError: No module named ‘models‘):训练时采用第二种方式保存整个模型以便于在其他地方调用测试,而该方式保存模型会使序列化的数据保存到特定的类,并且依赖该类文件的特定的目录结构,该路径在加载时使用。因此,在上面项目中调用其他地方保存的模型时由于缺少models路径而找不到models模块。
还不能理解这种设计的原因。我使用torch.save保存模型而非模型参数的原因就在于我希望我的模型可以在新的环境下方便的使用,但是…(网上有评论说是出于安全的考虑,并不认同)
解决方法:使用torch.jit.save
保存模型为TorchScript,已测试可行:
#TorchScript是PyTorch模型(nn.Module的子类)的中间表示,可以在高性能环境(例如C )中运行
# 模型保存,input为和模型输入维度相同
model = torch.jit.trace(model, input)
torch.jit.save(model,save_path)
# 模型载入
model = torch.jit.load(save_path)
注:torch1.6及之后开始,Pytorch就更换了保存模型文件的方法,低版本读取高版本保存的模型时会报错。在保存时使用_use_new_zipfile_serialization=False
,指定不使用新的zip保存方法。
附:PyTorch JIT简介
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。