当前位置:   article > 正文

pytorch | pytorch改变tensor维度的方法

pytorch | pytorch改变tensor维度的方法

pytorch 的 Tensor 类有很多方法可以用来改变 tensor 的维度。这里介绍几种常用的方法:

  • view(shape):返回一个新的 tensor,它具有给定的形状。如果元素总数不变,则可以用它来改变 tensor 的维度。例如:
  1. import torch
  2. t = torch.tensor([
  3. [1, 2, 3],
  4. [4, 5, 6]
  5. ])
  6. print(t.shape) # torch.Size([2, 3])
  7. t_view = t.view(3, 2)
  8. print(t_view.shape) # torch.Size([3, 2])
  • unsqueeze(dim):返回一个新的 tensor,它的指定位置插入了一个新的维度。例如:
  1. import torch
  2. t = torch.tensor([
  3. [1, 2, 3],
  4. [4, 5, 6]
  5. ])
  6. print(t.shape) # torch.Size([2, 3])
  7. t_unsqueeze = t.unsqueeze(0)
  8. print(t_unsqueeze.shape) # torch.Size([1, 2, 3])
  9. t_unsqueeze = t.unsqueeze(1)
  10. print(t_unsqueeze.shape) # torch.Size([2, 1, 3])
  11. t_unsqueeze = t.unsqueeze(2)
  12. print(t_unsqueeze.shape) # torch.Size([2, 3, 1])
  • squeeze(dim):返回一个新的 tensor,它的指定位置的维度的大小为 1 的维度被删除。例如:
  1. import torch
  2. t = torch.tensor([
  3. [[1], [2], [3]],
  4. [[4], [5], [6]]
  5. ])
  6. print(t.shape) # torch.Size([2, 3, 1])
  7. t_squeeze = t.squeeze(2)
  8. print(t_squeeze.shape) # torch.Size([2, 3])
  9. t_squeeze = t.squeeze()
  10. print(t_squeeze.shape) # torch.Size([2, 3])
  • transpose(dim0, dim1):返回一个新的 tensor,它的排列被交换。例如:
  1. import torch
  2. t = torch.tensor([
  3. [1, 2, 3],
  4. [4, 5, 6]
  5. ])
  6. print(t.shape) # torch.Size([2, 3])
  7. t_transpose = t.transpose(0, 1)
  8. print(t_transpose.shape) # torch.Size([3, 2])
  9. t_transpose = t.transpose(1, 0)
  10. print(t_transpose.shape) # torch.Size([3, 2])

还有一些其他的方法,例如 permute() 和 contiguous(),可以用来改变 tensor 的维度。有关这些方法的更多信息,可以参考 pytorch 官方文档:https://pytorch.org/docs/stable/tensors.html。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/325418
推荐阅读
相关标签
  

闽ICP备14008679号