赞
踩
torch.no_grad() 是 PyTorch 中的一个上下文管理器,用于在进入该上下文时禁用梯度计算。这在你只关心评估模型,而不是训练模型时非常有用,因为它可以显著减少内存使用并加速计算。
当你在 torch.no_grad() 上下文管理器中执行张量操作时,PyTorch 不会为这些操作计算梯度。这意味着不会在 .grad 属性中累积梯度,并且操作会更快地执行。
import torch
# 创建一个需要梯度的张量
x = torch.tensor([1.0], requires_grad=True)
# 使用 no_grad() 上下文管理器
with torch.no_grad():
y = x * 2
y.backward()
print(x.grad)
输出:
RuntimeError Traceback (most recent call last) Cell In[52], line 11 7 with torch.no_grad(): 8 y = x * 2 ---> 11 y.backward() 13 print(x.grad) File E:\anaconda\lib\site-packages\torch\_tensor.py:396, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs) 387 if has_torch_function_unary(self): 388 return handle_torch_function( 389 Tensor.backward, 390 (self,), (...) 394 create_graph=create_graph, 395 inputs=inputs) --> 396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) File E:\anaconda\lib\site-packages\torch\autograd\__init__.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs) 168 retain_graph = create_graph 170 # The reason we repeat same the comment below is that 171 # some Python versions print out the first line of a multi-line function 172 # calls in the traceback and some print out the last line --> 173 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass 174 tensors, grad_tensors_, retain_graph, create_graph, inputs, 175 allow_unreachable=True, accumulate_grad=True) RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
输出错误,因为使用了with torch.no_grad():。
import torch
# 创建一个需要梯度的张量
x = torch.tensor([1.0], requires_grad=True)
# 使用 no_grad() 上下文管理器
y = x * 2
y.backward()
print(x.grad)
输出:
tensor([2.])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。