当前位置:   article > 正文

PyTorch Tensor类:clone和detach的用法_pytorch tensor clone

pytorch tensor clone

本来是准备分析一下Tensor类的源码的,但是看了看发现这个类的源码实现基本都是在C++上,目前精力有限,所以就算了。现在打算分析一下Tensor中可能比较难用的方法,比如clone,detach。

这些方法之所以难用主要还是因为Tensor支持自动微分,也就是说每个Tensor不止能表示这个Tensor对应的值,还可以表示以这个Tensor为根结点的前向计算图。

Clone方法

我们先看PyTorch的官方文档torch.clone — PyTorch 1.10.0 documentation

 很多人可能对clone是可微的这句话不是很懂,其实就是论文里偶尔能见到的identify操作,我这里举个例子。

  1. import torch
  2. x = torch.tensor([1.0],requires_grad = True)
  3. y = x.clone()
  4. y.backward()
  5. print("x.grad:",x.grad)
  6. """
  7. 输出结果:
  8. x.grad: tensor([1.])
  9. """

可以看到我们对y进行反向传播,x的梯度为1。所以实际上y = x.clone()类似于数学表达式中的y=x,但是在python里如果让y=x,是不会给y创建新的内存空间的,这就需要clone了。

或者也可以y = x+0和y=x*1。

  1. import torch
  2. x = torch.tensor([1.0],requires_grad = True)
  3. y = x+0
  4. y.backward()
  5. print("x.grad:",x.grad)
  6. """
  7. 输出结果:
  8. x.grad: tensor([1.])
  9. """

可以看的,和clone的效果是一致的。

下面以计算图的形式来表述:

(当前计算图中结点代表操作(叶子节点除外),边代表张量)

Detach方法

detach也是一个和计算图关联比较紧密的tensor方法。

还是先看一下PyTorch的官方文档torch.Tensor.detach — PyTorch 1.10.0 documentation

Returns a new Tensor, detached from the current graph.The result will never require gradient.Returned Tensor shares the same storage with the original one. In-place modifications on either of them will be seen, and may trigger errors in correctness checks.

也就是说detach的效果其实就是将一个Tensor的requires_grad置为False。那么一个Tensor的requires_grad为False对这个计算图会有什么影响呢?从效果上来看其实就是以该Tensor为根节点的计算子图在反向传播中不再会得到梯度。

这里继续看一个例子。

  1. import torch
  2. #L = (X+Y)×Z
  3. #dL/dX = Z, dL/dY = Z, dL/dZ = X+Y
  4. X = torch.tensor([1.0],requires_grad = True)
  5. Y = torch.tensor([2.0],requires_grad = True)
  6. Z = torch.tensor([3.0],requires_grad = True)
  7. K = X+Y
  8. L = K*Z
  9. L.backward()
  10. print("X.grad:",X.grad)
  11. print("Y.grad:",Y.grad)
  12. print("Z.grad:",Z.grad)
  13. """
  14. 输出结果:
  15. X.grad: tensor([3.])
  16. Y.grad: tensor([3.])
  17. Z.grad: tensor([3.])
  18. """

这是一个正常反向传播的情况下的梯度值,接下来我将利用detach使梯度无法反传到X和Y。

  1. import torch
  2. #L = (X+Y)×Z
  3. #dL/dX = Z, dL/dY = Z, dL/dZ = X+Y
  4. X = torch.tensor([1.0],requires_grad = True)
  5. Y = torch.tensor([2.0],requires_grad = True)
  6. Z = torch.tensor([3.0],requires_grad = True)
  7. K = X+Y
  8. L = K.detach()*Z
  9. L.backward()
  10. print("X.grad:",X.grad)
  11. print("Y.grad:",Y.grad)
  12. print("Z.grad:",Z.grad)
  13. """
  14. 输出结果:
  15. X.grad: None
  16. Y.grad: None
  17. Z.grad: tensor([3.])
  18. """

从计算图上看是这样的(虚线代表无法到达)

 也就是说由于K.detach,导致当前计算中K的requires_grad为False,因此dL/dK这个梯度无法传播到K,从而导致dL/dX和dL/dY都无法被计算。

但是需要注意的是K.detach()并不会修改K本身的requires_grad属性,因此K本身还是可以接收梯度的。

下面是另一个例子:

  1. import torch
  2. #L = (X+Y)×3×Z
  3. #dL/dX = 3Z, dL/dY = 3Z, dL/dZ = 3(X+Y)
  4. X = torch.tensor([1.0],requires_grad = True)
  5. Y = torch.tensor([2.0],requires_grad = True)
  6. Z = torch.tensor([3.0],requires_grad = True)
  7. K = X+Y
  8. L = K.detach()*Z*K
  9. L.backward()
  10. print("X.grad:",X.grad)
  11. print("Y.grad:",Y.grad)
  12. print("Z.grad:",Z.grad)
  13. """
  14. 输出结果:
  15. X.grad: tensor([9.])
  16. Y.grad: tensor([9.])
  17. Z.grad: tensor([9.])
  18. """

这里我在计算L的时候同时用了K.detach()和K,K.detach带来的梯度虽然无法回传,但是K的梯度是可以回传的,所以X和Y依然存在梯度。

所以如果我们想让从K中出去的所有计算都无法得到梯度,那么应该先使用K = K.detach()。

所以这么来看,detach这个方法其实可以理解为在本次计算中让某个变量变为常数。

总结

目前对于clone和detach的原理分析的比较清楚了,但是具体的使用场景还需要在实践中继续观察,后续看到了我也会补充上来的。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/272364
推荐阅读
相关标签
  

闽ICP备14008679号