赞
踩
Hello,大家好!下面是对torch.load函数的翻译~
torch.
load
(f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>, **pickle_load_args)从文件加载用torch.save()保存的对象。
torch.load()使用Python的拆封(Unpickling:从存储的字符串文件中提取原始Python对象的过程,叫做unpickling)能力,但是会特别处理张量的存储空间。他们首先在CPU上反序列化,然后移动到保存他们的设备上。如果这个过程失败了(例如,因为运行系统没有某些设备),会导致例外的发生。然而,可以通过map_location参数将存储空间动态映射到另一套设备上。
如果map_location是可调用的,它将为每个序列化存储空间调用一次,并含有两个参数:storage和location。storage参数将会是CPU上的存储空间的初始反序列化。每一个序列化的存储空间都有一个与之相关的location标签,该标签表示了它存储的设备,并且是传给map_location的第二个参数。CPU张量的location标签是'cpu'
,CUDA张量的location标签为'cuda:device_id'
(e.g. 'cuda:2'
)。map_location应该返回None或存储空间。如果map_location返回的是存储空间,它将被用作最后的反序列化对象,并且该对象已经被移到正确的设备上。否则,torch.load()就会恢复默认模式,就好像map_location没有指定一样。
如果map_location是torch.device对象或包含设备标签的字符串,那么它指明了所有张量将要被加载的位置。
相反,如果map_location是一个字典,它将被用作将出现在文件(keys)中的位置标签重新映射到那些明确指明存储空间(values)位置的标签。
用户扩展可以使用torch.serialization.register_package()
来注册自己的位置标签、标记和反序列化方法。
但你在一个包含GPU张量的文件上调用torch.load()时,那些张量将会被默认的加载到GPU。当加载模型检查点时,可以通过调用torch.load(.., map_location='cpu'),然后load_state_dict(),来避免GPU RAM的激增。
默认情况下,我们将字节字符串解码为utf-8。这是为了避免在python3加载由python2保存的文件时出现的普遍的错误UnicodeDecodeError: 'ascii' codec can't decode byte 0x...。如果这个默认值是错误的话,你可以使用额外的encoding关键字参数来明确这些对象是如何被加载的,如encoding='latin1'用latin1将他们解码到字符串,encoding='bytes'将他们保持为字节数组,从而后续可以被byte_array.decode(...)解码。
例子
- >>> torch.load('tensors.pt')
- # Load all tensors onto the CPU
- >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
- # Load all tensors onto the CPU, using a function
- >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
- # Load all tensors onto GPU 1
- >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
- # Map tensors from GPU 1 to GPU 0
- >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
- # Load tensor from io.BytesIO object
- >>> with open('tensor.pt', 'rb') as f:
- buffer = io.BytesIO(f.read())
- >>> torch.load(buffer)
- # Load a module with 'ascii' encoding for unpickling
- >>> torch.load('module.pt', encoding='ascii')
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。