当前位置:   article > 正文

pytorch中unsqueeze用法说明

pytorch中unsqueeze用法说明

在指定的位置插入一个维度,有两个参数,input是输入的tensor,dim是要插到的维度

需要注意的是dim的范围是[-input.dim()-1, input.dim()+1),是一个左闭右开的区间,当dim为负值时,会自动转换为dim = dim+input.dim()+1,类似于使用负数对python列表进行切片。

  1. import torch
  2. a = torch.randn(2,5)
  3. print(a)
  4. print("")
  5. b = a.unsqueeze(0)
  6. print(b.shape)
  7. print("")
  8. c = a.unsqueeze(a.dim())
  9. print(c.shape)
  10. 输出:
  11. tensor([[-0.4734, 0.4115, -0.9415, -1.1280, -0.1065],
  12. [ 0.1613, 1.2594, 1.1261, 1.3881, 0.1112]])
  13. torch.Size([1, 2, 5])
  14. torch.Size([2, 5, 1])

以上是二维数据情况:

首先生成了一个二维矩阵,其大小为[2,5]

然后,在0维度上插入一个维度,可以看到现在新矩阵a的形状变为[1,2,5],第0维度的大小默认是1

最后,在最后一个维度上插入一个维度,形状变为[2, 5, 1]

  1. a=torch.rand(2,3,2)
  2. print("")
  3. print("torch.unsqueeze(a,3) size: {}".format(torch.unsqueeze(a,3).size()))
  4. print("")
  5. print("torch.unsqueeze(a,2) size: {}".format(torch.unsqueeze(a,2).size()))
  6. print("")
  7. print("torch.unsqueeze(a,1) size: {}".format(torch.unsqueeze(a,1).size()))
  8. print("")
  9. print("torch.unsqueeze(a,0) size: {}".format(torch.unsqueeze(a,0).size()))
  10. print("")
  11. print("torch.unsqueeze(a,-1) size: {}".format(torch.unsqueeze(a,-1).size()))
  12. print("")
  13. print("torch.unsqueeze(a,-2) size: {}".format(torch.unsqueeze(a,-2).size()))
  14. print("")
  15. print("torch.unsqueeze(a,-3) size: {}".format(torch.unsqueeze(a,-3).size()))
  16. print("")
  17. print("torch.unsqueeze(a,-4) size: {}".format(torch.unsqueeze(a,-4).size()))
  18. 输出:
  19. torch.unsqueeze(a,3) size: torch.Size([2, 3, 2, 1])
  20. torch.unsqueeze(a,2) size: torch.Size([2, 3, 1, 2])
  21. torch.unsqueeze(a,1) size: torch.Size([2, 1, 3, 2])
  22. torch.unsqueeze(a,0) size: torch.Size([1, 2, 3, 2])
  23. torch.unsqueeze(a,-1) size: torch.Size([2, 3, 2, 1])
  24. torch.unsqueeze(a,-2) size: torch.Size([2, 3, 1, 2])
  25. torch.unsqueeze(a,-3) size: torch.Size([2, 1, 3, 2])
  26. torch.unsqueeze(a,-4) size: torch.Size([1, 2, 3, 2])

对于三维数据input.dim() = 3,因此dim的范围是[-4, 4)

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号