赞
踩
PyTorch采用动态图机制,通过tensor(以前是variable)来构建图,tensor里面包含的梯度信息用于反向传播求导。但不是所有变量都应该包含梯度(毕竟东西多,占“面积”就多),否则就会造成网络越跑,所占显存越大的情况,那怎么办呢?
先看一段代码,
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_func(outputs, targets)
loss.backward()
optimizer.step()
print('loss:',loss)
print('loss.item():',loss.item())
print('loss.detach():',loss.detach())
train_loss += loss.item() # <----关键
...
输出:
loss: tensor(2.3391, device='cuda:0', grad_fn=<NllLossBackward>)
loss.item(): 2.3391051292419434
loss.detach(): tensor(2.3391, device='cuda:0')
很明显,loss.backward()在上面已经进行过了,下面去计算train_loss的时候就不要再带有梯度信息才合适。故有两种解决方案:
loss.detach()
来获取不需要梯度回传的部分。loss.item()
直接获得对应的python数据类型。https://www.zhihu.com/question/67209417/answer/344752405
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。