当前位置:   article > 正文

使用GPU训练模型遇到的问题_max_length影响显存吗

max_length影响显存吗

使用GPU训练模型,遇到显存不足的情况:开始报chunk xxx size 64000的错误。使用tensorflow框架来训练的。
仔细分析原因有两个:

  • 数据集padding依据的是整个训练数据集的max_seq_length,这样在一个批内的数据会造成额外的padding,占用显存;
  • 在训练时把整个训练数据先全部加载,造成显存占用多。

如果遇到第一种情况,即使使用CPU训练速度也非常慢。
对于第二种情况,要使用generator来解决。不要加载全部数据,要分批加载,根据一个批内的最大length来填充,同时也要限制最大length的长度。丢弃部分很长的数据。

而且,如果使用bert时,会对seq_length有限制。

tensorflow 1.12限制只使用CPU:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
  • 1
  • 2
声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号