当前位置:   article > 正文

Pytorch 训练与测试时爆显存(out of memory)的一个解决方案_爆显存 英文

爆显存 英文

Pytorch 训练时有时候会因为加载的东西过多而爆显存,有些时候这种情况还可以使用cuda的清理技术进行修整,当然如果模型实在太大,那也没办法。

使用torch.cuda.empty_cache()删除一些不需要的变量代码示例如下:

  1. try:
  2. output = model(input)
  3. except RuntimeError as exception:
  4. if "out of memory" in str(exception):
  5. print("WARNING: out of memory")
  6. if hasattr(torch.cuda, 'empty_cache'):
  7. torch.cuda.empty_cache()
  8. else:
  9. raise exception

测试的时候爆显存有可能是忘记设置no_grad, 示例代码如下:

  1. with torch.no_grad():
  2. for ii,(inputs,filelist) in tqdm(enumerate(test_loader), desc='predict'):
  3. if opt.use_gpu:
  4. inputs = inputs.cuda()
  5. if len(inputs.shape) < 4:
  6. inputs = inputs.unsqueeze(1)
  7. else:
  8. if len(inputs.shape) < 4:
  9. inputs = torch.transpose(inputs, 1, 2)
  10. inputs = inputs.unsqueeze(1)

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/220594?site
推荐阅读
相关标签
  

闽ICP备14008679号