当前位置:   article > 正文

关于模型训练中显存占用过大的或直接报显存爆炸的解决方法_百川模型怎么恢复显存

百川模型怎么恢复显存

模型训练显存爆炸解决方法

在模型训练中,应该理解梯度、反向传播、图层、显存这些概念,在模型训练过程中,一般会分为训练+验证+测试 ,在这些过程中,一般在训练过程中会比较占用显存,因为涉及到反向传播,需要大量的梯度,这些数据又存放在显存中。
在今天模型的训练中,突然发现可以训练,但是在验证过程中出现显存爆炸炸,提示我显存不足,我就很纳闷,一直在找问题,终于发现了:

在我的训练代码中:

   for epoch in range(0, epoch_num):
        net.train()
        for i, data in enumerate(train_dataloader):
            ite_num = ite_num + 1
            inputs, labels = data['image'], data['maskl']
            inputs = inputs.type(torch.FloatTensor)#注意第1行
            labels = labels.type(torch.FloatTensor)#注意第2行

            # wrap them in Variable:
            if torch.cuda.is_available():
                input, label = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)#注意第3行
            else:
                input, label = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)#注意第4行

            optimizer.zero_grad() 
            pred = net(input)
            loss = dice_bce_loss_fusion(pred,label)
            loss.backward() 
            optimizer.step() 
            running_loss += loss.data.item()#注意第5行
            del d0,loss 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

**分宜以上代码,我通过print(loss.requires_grad)发现loss是有梯度的,但是我在累加的时候用了loss.data.item(),这样就减小了显存的占用量,且使用了del loss,但是我的验证代码中,没有使用loss.data.item()**和d0,导致在验证时出现显存爆炸。
在模型训练过程中,应该知道训练和验证的交替:

# evaluate model:
model.eval()#切换到验证
with torch.no_grad():#注意这一行,使得在内的所有参数均没有梯度,加快模型的训练与验证
    ...
    out_data = model(data)#数据是没有梯度的

# training step
model.train()#恢复模型训练
    ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

以上就是解决在验证时报出显存爆炸的解决方法,特此记录一波。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/564986
推荐阅读