当前位置:   article > 正文

tensor.clone() 和 tensor.detach()

tensor.clone()

1 tensor.clone()

返回tensor的拷贝,返回的新tensor和原来的tensor具有同样的大小和数据类型

  • 原tensor的requires_grad=True
  • clone()返回的tensor是中间节点,梯度会流向原tensor,即返回的tensor的梯度会叠加在原tensor上
  1. >>> import torch
  2. >>> a = torch.tensor(1.0, requires_grad=True)
  3. >>> b = a.clone()
  4. >>> id(a), id(b) # a和b不是同一个对象
  5. (140191154302240, 140191145593424)
  6. >>> a.data_ptr(), b.data_ptr() # 也不指向同一块内存地址
  7. (94724518544960, 94724519185792)
  8. >>> a.requires_grad, b.requires_grad # 但b的requires_grad属性和a的一样,同样是True
  9. (True, True)
  10. >>> c = a * 2
  11. >>> c.backward()
  12. >>> a.grad
  13. tensor(2.)
  14. >>> d = b * 3
  15. >>> d.backward()
  16. >>> b.grad # b的梯度值为None,因为是中间节点,梯度值不会被保存
  17. >>> a.grad # b的梯度叠加在a上
  18. tensor(5.)
  • 原tensor的requires_grad=False
  1. >>> import torch
  2. >>> a = torch.tensor(1.0)
  3. >>> b = a.clone()
  4. >>> id(a), id(b) # a和b不是同一个对象
  5. (140191169099168, 140191154762208)
  6. >>> a.data_ptr(), b.data_ptr() # 也不指向同一块内存地址
  7. (94724519502912, 94724519533952)
  8. >>> a.requires_grad, b.requires_grad # 但b的requires_grad属性和a的一样,同样是False
  9. (False, False)
  10. >>> b.requires_grad_()
  11. >>> c = b * 2
  12. >>> c.backward()
  13. >>> b.grad
  14. tensor(2.)
  15. >>> a.grad # None

2 tensor.detach()

从计算图中脱离出来。

返回一个新的tensor,新的tensor和原来的tensor共享数据内存,但不涉及梯度计算,即requires_grad=False。修改其中一个tensor的值,另一个也会改变,因为是共享同一块内存,但如果对其中一个tensor执行某些内置操作,则会报错,例如resize_、resize_as_、set_、transpose_。

  1. >>> import torch
  2. >>> a = torch.rand((3, 4), requires_grad=True)
  3. >>> b = a.detach()
  4. >>> id(a), id(b) # a和b不是同一个对象了
  5. (140191157657504, 140191161442944)
  6. >>> a.data_ptr(), b.data_ptr() # 但指向同一块内存地址
  7. (94724518609856, 94724518609856)
  8. >>> a.requires_grad, b.requires_grad # b的requires_grad为False
  9. (True, False)
  10. >>> b[0][0] = 1
  11. >>> a[0][0] # 修改b的值,a的值也会改变
  12. tensor(1., grad_fn=<SelectBackward>)
  13. >>> b.resize_((4, 3)) # 报错
  14. RuntimeError: set_sizes_contiguous is not allowed on a Tensor created from .data or .detach().

3. tensor.clone().detach() 还是 tensor.detach().clone()

两者的结果是一样的,即返回的tensor和原tensor在梯度上或者数据上没有任何关系,一般用前者。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/272346
推荐阅读
相关标签
  

闽ICP备14008679号