当前位置:   article > 正文

Pytorch 训练与测试时爆显存(out of memory)的一个解决方案_pytorch 缓存

pytorch 缓存

Pytorch 训练时有时候会因为加载的东西过多而爆显存,有些时候这种情况还可以使用cuda的清理技术进行修整,当然如果模型实在太大,那也没办法。

使用torch.cuda.empty_cache()删除一些不需要的变量代码示例如下:

  1. try:
  2. output = model(input)
  3. except RuntimeError as exception:
  4. if "out of memory" in str(exception):
  5. print("WARNING: out of memory")
  6. if hasattr(torch.cuda, 'empty_cache'):
  7. torch.cuda.empty_cache()
  8. else:
  9. raise exception

测试的时候爆显存有可能是忘记设置no_grad, 示例代码如下:

  1. with torch.no_grad():
  2. for ii,(inputs,filelist) in tqdm(enumerate(test_loader), desc='predict'):
  3. if opt.use_gpu:
  4. inputs = inputs.cuda()
  5. if len(inputs.shape) < 4:
  6. inputs = inputs.unsqueeze(1)
  7. else:
  8. if len(inputs.shape) < 4:
  9. inputs = torch.transpose(inputs, 1, 2)
  10. inputs = inputs.unsqueeze(1)

 

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号