当前位置:   article > 正文

【Pytorch】训练中跳过问题样本,解决显存爆炸\波动问题

【Pytorch】训练中跳过问题样本,解决显存爆炸\波动问题

        最近在训练模型时发现数据集部分数据存在问题,计算得到的loss非常大,这不利于模型的优化,所以就考虑通将loss过大的样本直接跳过,不使用这些数据进行优化。

        考虑让问题数据不在反向传播,用以下语句:

  1. loss = nll_with_covariances(
  2. xy_future_gt, coordinates, probas, data["target/future/valid"].squeeze(-1),
  3. covariance_matrices) * loss_coeff
  4. if loss>1e5 and step>100:
  5. del data
  6. torch.cuda.empty_cache()
  7. continue
  8. else:
  9. train_losses.append(loss.item())
  10. loss.backward()
  11. optimizer.step()

        就发生了显存不断上下波动的情况,按照原来的batchsize还发生显存不足的问题,尝试使用torch.cuda.empty_cache()依然不能解决。于是就想搞清楚pytorch训练中显存的管理机制。从知乎上看到一篇很详细的讲解,pytorch占用显存有四部分组成:模型定义、前向传播过程、反向传播过程、参数更新过程。 PyTorch显存机制分析

        我分析反向传播结束后会清除前向传播数据,可以使模型再次前向传播而显存不发生变化,之前直接跳过反向传播,导致有两个batchsize的前向传播数据挤在显存中,导致显存猛增最后不足。

        torch.cuda.empty_cache()只能清除缓存区的数据,不能解决根本问题,会导致预留给下一批次位置被清除,导致训练速度变慢,由平均1.1s增加到1.25s。

       之后在pytorch论坛上看到和我相似的问题,他的解决方法就是对问题样本仍进行反向传播,但是不更新参数,所以修改后的代码为:How to skip backward if the loss is very small in training

  1. loss = nll_with_covariances(
  2. xy_future_gt, coordinates, probas, data["target/future/valid"].squeeze(-1),
  3. covariance_matrices) * loss_coeff
  4. train_losses.append(loss.item())
  5. loss.backward()
  6. if loss>1e5 and step>100:
  7. optimizer.zero_grad()
  8. del data
  9. continue
  10. else:
  11. optimizer.step()

         修改后的模型训练中显存就始终保持稳定了,不会像之前一样发生剧烈的显存波动,这时的batchsize可以进一步调大,不用担心训练中显存爆炸中断了。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/351404?site
推荐阅读
相关标签
  

闽ICP备14008679号