赞
踩
pytorch 的 Tensor 类有很多方法可以用来改变 tensor 的维度。这里介绍几种常用的方法:
- import torch
-
- t = torch.tensor([
- [1, 2, 3],
- [4, 5, 6]
- ])
- print(t.shape) # torch.Size([2, 3])
-
- t_view = t.view(3, 2)
- print(t_view.shape) # torch.Size([3, 2])
- import torch
-
- t = torch.tensor([
- [1, 2, 3],
- [4, 5, 6]
- ])
- print(t.shape) # torch.Size([2, 3])
-
- t_unsqueeze = t.unsqueeze(0)
- print(t_unsqueeze.shape) # torch.Size([1, 2, 3])
-
- t_unsqueeze = t.unsqueeze(1)
- print(t_unsqueeze.shape) # torch.Size([2, 1, 3])
-
- t_unsqueeze = t.unsqueeze(2)
- print(t_unsqueeze.shape) # torch.Size([2, 3, 1])
- import torch
-
- t = torch.tensor([
- [[1], [2], [3]],
- [[4], [5], [6]]
- ])
- print(t.shape) # torch.Size([2, 3, 1])
-
- t_squeeze = t.squeeze(2)
- print(t_squeeze.shape) # torch.Size([2, 3])
-
- t_squeeze = t.squeeze()
- print(t_squeeze.shape) # torch.Size([2, 3])
- import torch
-
- t = torch.tensor([
- [1, 2, 3],
- [4, 5, 6]
- ])
- print(t.shape) # torch.Size([2, 3])
-
- t_transpose = t.transpose(0, 1)
- print(t_transpose.shape) # torch.Size([3, 2])
-
- t_transpose = t.transpose(1, 0)
- print(t_transpose.shape) # torch.Size([3, 2])
还有一些其他的方法,例如 permute() 和 contiguous(),可以用来改变 tensor 的维度。有关这些方法的更多信息,可以参考 pytorch 官方文档:https://pytorch.org/docs/stable/tensors.html。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。