当前位置:   article > 正文

【Pytorch】torch.cat() Pytorch中对于张量拼接函数的用法和示例以及dim=-1的应用场景_torch.cat([qpos, padding], dim=-1)

torch.cat([qpos, padding], dim=-1)

目录

背景

张量例子

dim取值范围在[0, len(inputs[0])]时

dim取值为-1时


背景

遇到一个张量合并问题,对于dim=-1的合并情况不清楚,在寻找博客案例时都没有提到dim=-1的情况,所以自己写了几个案例补充一下对于维度为-1的情况时怎样的

张量例子

对于理解torch.cat()和torch.stack()直接用例子理解是最好用的。特别注意dim维度的理解

给出以下两个张量的例子:

  1. x1 = torch.tensor([[11,21,31],[34,51,23]], dtype=torch.int)
  2. x2 = torch.tensor([[34,56,67],[62,91,56]], dtype=torch.int)

dim取值范围在[0, len(inputs[0])]时

其实乍一看这个len(inputs[0])非常难理解,所以不要这么去记。

先来看例子:

1.对x1,x2这两个张量进行合并,注意此时这两个张量的维数都为2并且其形状都为(2,3)

1):在维数等于0时进行合并

x3 = torch.cat((x1,x2), dim=0)
  1. x3: tensor([[11, 21, 31],
  2. [34, 51, 23],
  3. [34, 56, 67],
  4. [62, 91, 56]], dtype=torch.int32)
  5. x3 shape: torch.Size([4, 3])

2):在维数等于1时进行合并

x3 = torch.cat((x1,x2), dim=1)
  1. x3: tensor([[11, 21, 31, 34, 56, 67],
  2. [34, 51, 23, 62, 91, 56]], dtype=torch.int32)
  3. x3 shape: torch.Size([2, 6])

3):在维数等于2时进行合并

x3 = torch.cat((x1,x2), dim=2)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

在dim=2时出现报错

-------------------------------------------------------------------------------------------------------------------------------

看到这里可能还是不明白为什么在维度等于2时出现了报错,这个维度究竟是怎么计算的

所以我们再次扩展一下x1和x2使其成为三维张量

  1. x1 = torch.tensor([[[11,21,31],[34,51,23]],[[76,56,89],[35,74,62]]], dtype=torch.int)
  2. x2 = torch.tensor([[[34,56,67],[62,91,56]],[[67,35,98],[38,79,25]]], dtype=torch.int)

那么现在这个两个张量的形状为(2,2,3)

2. 对于x1和x2再次进行合并

  1. x3 = torch.cat((x1,x2), dim=0)
  2. x3: tensor([[[11, 21, 31],
  3. [34, 51, 23]],
  4. [[76, 56, 89],
  5. [35, 74, 62]],
  6. [[34, 56, 67],
  7. [62, 91, 56]],
  8. [[67, 35, 98],
  9. [38, 79, 25]]], dtype=torch.int32)
  10. x3 shape: torch.Size([4, 2, 3])
  11. x3 = torch.cat((x1,x2), dim=1)
  12. x3: tensor([[[11, 21, 31],
  13. [34, 51, 23],
  14. [34, 56, 67],
  15. [62, 91, 56]],
  16. [[76, 56, 89],
  17. [35, 74, 62],
  18. [67, 35, 98],
  19. [38, 79, 25]]], dtype=torch.int32)
  20. x3 shape: torch.Size([2, 4, 3])
  21. x3 = torch.cat((x1,x2), dim=2)
  22. x3: tensor([[[11, 21, 31, 34, 56, 67],
  23. [34, 51, 23, 62, 91, 56]],
  24. [[76, 56, 89, 67, 35, 98],
  25. [35, 74, 62, 38, 79, 25]]], dtype=torch.int32)
  26. x3 shape: torch.Size([2, 2, 6])
  27. x3 = torch.cat((x1,x2), dim=3)
  28. # IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

可以观察到这次在dim为2时并没有出现报错,因为这一次输入的x1和x2的形状大小为(2,2,3)有三个维度的信息,所以dim可以取值0,1,2

总结:dim指代的是以哪一个维度信息进行合并,在Pytorch中维度信息的表示方式是从外向内的。这一过程可以形象的理解为拨香蕉。(可以回看张量例子中的x1,x2的形状)所以当dim=0时指代的是按照最外层的list进行拼接,当dim=1时指代按照第二层list进行拼接,以此类推。所以输入的张量维数决定dim的取值范围。再会看定义的范围就非常好理解了

dim取值为-1时

先说结论:dim=-1与dim=len(inputs[0])一样,即适用最里面的list进行合并。在python中-1的用法类似,均指代最后一个,所以dim=-1的情况就不难理解了

接下来说一下dim=-1的优点:

· 不需要知道输入具体的形状,适用于想要对于dim=len(inputs[0])进行扩展的情况

· 特别适合末端形状维度的扩展,举个例子⬇

仍然使用最开始的x1张量:

x1 = torch.tensor([[11,21,31],[34,51,23]], dtype=torch.int)

另外增加一个零向量

x0 = torch.zeros(x1.shape[0], 1)

它打印出来的结果为

  1. x0:
  2. tensor([[0.],[0.]])
  3. torch.Size([2, 1])

将x0和x1合并:

  1. x3 = torch.cat((x1,x0), dim=-1)
  2. x3: tensor([[11., 21., 31., 0.],
  3. [34., 51., 23., 0.]])
  4. x3 shape: torch.Size([2, 4])

可以看到形状从(2,3)扩充为(2,4)

在很多的数据处理中,这一步十分重要

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/黑客灵魂/article/detail/750852
推荐阅读
相关标签
  

闽ICP备14008679号