
Usage example of the torch.autograd.grad() function in PyTorch

torch.autograd.grad()

Contents

1. Function explanation

2. Code example (y = x^2)


1. Function explanation

If the input is x and the output is y, this function computes the derivative (gradient) of y with respect to x: result = dy/dx.
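For a quick feel of this, here is a minimal sketch on a scalar input (the value 2.0 and the cube function are just for illustration): for y = x^3 the derivative is dy/dx = 3x^2, and grad() returns it as a one-element tuple.

    import torch

    x = torch.tensor(2.0, requires_grad=True)
    y = x ** 3
    dydx = torch.autograd.grad(outputs=y, inputs=x)  # y is a scalar, so grad_outputs can be omitted
    print(dydx)  # (tensor(12.),), i.e. dy/dx = 3 * 2**2 = 12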

    def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False,
             only_inputs=True, allow_unused=False):
        r"""Computes and returns the sum of gradients of outputs w.r.t. the inputs.

        ``grad_outputs`` should be a sequence of length matching ``output``
        containing the pre-computed gradients w.r.t. each of the outputs. If an
        output doesn't require_grad, then the gradient can be ``None``).

        If ``only_inputs`` is ``True``, the function will only return a list of gradients
        w.r.t the specified inputs. If it's ``False``, then gradient w.r.t. all remaining
        leaves will still be computed, and will be accumulated into their ``.grad``
        attribute.

        Arguments:
            outputs (sequence of Tensor): outputs of the differentiated function.
            inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be
                returned (and not accumulated into ``.grad``).
            grad_outputs (sequence of Tensor): Gradients w.r.t. each output.
                None values can be specified for scalar Tensors or ones that don't require
                grad. If a None value would be acceptable for all grad_tensors, then this
                argument is optional. Default: None.
            retain_graph (bool, optional): If ``False``, the graph used to compute the grad
                will be freed. Note that in nearly all cases setting this option to ``True``
                is not needed and often can be worked around in a much more efficient
                way. Defaults to the value of ``create_graph``.
            create_graph (bool, optional): If ``True``, graph of the derivative will
                be constructed, allowing to compute higher order derivative products.
                Default: ``False``.
            allow_unused (bool, optional): If ``False``, specifying inputs that were not
                used when computing outputs (and therefore their grad is always zero)
                is an error. Defaults to ``False``.
        """
        if not only_inputs:
            warnings.warn("only_inputs argument is deprecated and is ignored now "
                          "(defaults to True). To accumulate gradient for other "
                          "parts of the graph, please use torch.autograd.backward.")

        outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
        inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
        if grad_outputs is None:
            grad_outputs = [None] * len(outputs)
        elif isinstance(grad_outputs, torch.Tensor):
            grad_outputs = [grad_outputs]
        else:
            grad_outputs = list(grad_outputs)

        grad_outputs = _make_grads(outputs, grad_outputs)
        if retain_graph is None:
            retain_graph = create_graph

        return Variable._execution_engine.run_backward(
            outputs, grad_outputs, retain_graph, create_graph,
            inputs, allow_unused)
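Two of the arguments documented above are worth a small illustration. The sketch below (tensors and values are made up for this example, not taken from the original post) shows that grad_outputs acts as the weight vector v of a vector-Jacobian product v·(dy/dx), and that allow_unused=True lets you request gradients for inputs that never entered the computation, in which case None is returned for them.

    import torch

    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    z = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)  # never used to compute y
    y = x ** 2

    v = torch.tensor([1.0, 0.5, 0.0])  # grad_outputs: weights of the vector-Jacobian product
    gx, gz = torch.autograd.grad(outputs=y, inputs=(x, z),
                                 grad_outputs=v, allow_unused=True)
    print(gx)  # tensor([2., 2., 0.]) = v * 2x
    print(gz)  # None, because z did not contribute to y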

2. Code example (y = x^2)

    import torch

    x = torch.randn(3, 4).requires_grad_(True)
    # (On recent PyTorch releases, writing into a leaf tensor that requires grad raises an
    # error; if that happens, fill in the values first and call requires_grad_() afterwards.)
    for i in range(3):
        for j in range(4):
            x[i][j] = i + j
    y = x ** 2
    print(x)
    print(y)

    weight = torch.ones(y.size())
    print(weight)

    # First derivative: (x**2)' = 2*x
    dydx = torch.autograd.grad(outputs=y,
                               inputs=x,
                               grad_outputs=weight,
                               retain_graph=True,
                               create_graph=True,
                               only_inputs=True)
    print(dydx[0])

    # Second derivative: (2*x)' = 2
    d2ydx2 = torch.autograd.grad(outputs=dydx[0],
                                 inputs=x,
                                 grad_outputs=weight,
                                 retain_graph=True,
                                 create_graph=True,
                                 only_inputs=True)
    print(d2ydx2[0])

x is:

    tensor([[0., 1., 2., 3.],
            [1., 2., 3., 4.],
            [2., 3., 4., 5.]], grad_fn=<CopySlices>)

y = x squared:

    tensor([[ 0.,  1.,  4.,  9.],
            [ 1.,  4.,  9., 16.],
            [ 4.,  9., 16., 25.]], grad_fn=<PowBackward0>)

weight:

    tensor([[1., 1., 1., 1.],
            [1., 1., 1., 1.],
            [1., 1., 1., 1.]])

dydx is the first derivative dy/dx = 2x; what grad() returns is this derivative multiplied element-wise by weight (all ones here):

    tensor([[ 0.,  2.,  4.,  6.],
            [ 2.,  4.,  6.,  8.],
            [ 4.,  6.,  8., 10.]], grad_fn=<ThMulBackward>)

d2ydx2 is the second derivative d²y/dx² = (2x)' = 2; again, the returned result is multiplied element-wise by weight:

    tensor([[2., 2., 2., 2.],
            [2., 2., 2., 2.],
            [2., 2., 2., 2.]], grad_fn=<ThMulBackward>)
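As a quick sanity check (a small sketch reusing x, dydx and d2ydx2 from the code above; it is not part of the original run), both results match the analytic derivatives 2x and 2:

    print(torch.allclose(dydx[0], 2 * x))                      # True
    print(torch.allclose(d2ydx2[0], torch.full_like(x, 2.0)))  # True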

Pretty simple, isn't it?
