当前位置:   article > 正文

【item() detach()用法】神经网络训练显存越来越大的原因之一_loss.detach()

loss.detach()

1 显存变大的原因

PyTorch采用动态图机制,通过tensor(以前是variable)来构建图,tensor里面包含的梯度信息用于反向传播求导。但不是所有变量都应该包含梯度(毕竟东西多,占“面积”就多),否则就会造成网络越跑,所占显存越大的情况,那怎么办呢?

2 loss.item()和loss.detach()解决问题

先看一段代码,

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_func(outputs, targets)
        loss.backward()
        optimizer.step()
        
        print('loss:',loss)
        print('loss.item():',loss.item())
        print('loss.detach():',loss.detach())
        train_loss += loss.item()			# <----关键
        ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

输出:

loss: tensor(2.3391, device='cuda:0', grad_fn=<NllLossBackward>)
loss.item(): 2.3391051292419434
loss.detach(): tensor(2.3391, device='cuda:0')
  • 1
  • 2
  • 3

很明显,loss.backward()在上面已经进行过了,下面去计算train_loss的时候就不要再带有梯度信息才合适。故有两种解决方案:

  • 使用loss.detach()来获取不需要梯度回传的部分。
    detach()通过重新声明一个变量,指向原变量的存放位置,但是requires_grad变为False。
  • 使用loss.item()直接获得对应的python数据类型。
    建议: 把除了loss.backward()之外的loss调用都改成loss.item()

3 感谢链接

https://www.zhihu.com/question/67209417/answer/344752405
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/494897
推荐阅读
相关标签
  

闽ICP备14008679号