赞
踩
目录
遇到一个张量合并问题,对于dim=-1的合并情况不清楚,在寻找博客案例时都没有提到dim=-1的情况,所以自己写了几个案例补充一下对于维度为-1的情况时怎样的
对于理解torch.cat()和torch.stack()直接用例子理解是最好用的。特别注意dim维度的理解
给出以下两个张量的例子:
- x1 = torch.tensor([[11,21,31],[34,51,23]], dtype=torch.int)
- x2 = torch.tensor([[34,56,67],[62,91,56]], dtype=torch.int)
其实乍一看这个len(inputs[0])非常难理解,所以不要这么去记。
先来看例子:
1.对x1,x2这两个张量进行合并,注意此时这两个张量的维数都为2并且其形状都为(2,3)
1):在维数等于0时进行合并
x3 = torch.cat((x1,x2), dim=0)
- x3: tensor([[11, 21, 31],
- [34, 51, 23],
- [34, 56, 67],
- [62, 91, 56]], dtype=torch.int32)
- x3 shape: torch.Size([4, 3])
2):在维数等于1时进行合并
x3 = torch.cat((x1,x2), dim=1)
- x3: tensor([[11, 21, 31, 34, 56, 67],
- [34, 51, 23, 62, 91, 56]], dtype=torch.int32)
- x3 shape: torch.Size([2, 6])
3):在维数等于2时进行合并
x3 = torch.cat((x1,x2), dim=2)
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
在dim=2时出现报错
-------------------------------------------------------------------------------------------------------------------------------
看到这里可能还是不明白为什么在维度等于2时出现了报错,这个维度究竟是怎么计算的
所以我们再次扩展一下x1和x2使其成为三维张量
- x1 = torch.tensor([[[11,21,31],[34,51,23]],[[76,56,89],[35,74,62]]], dtype=torch.int)
- x2 = torch.tensor([[[34,56,67],[62,91,56]],[[67,35,98],[38,79,25]]], dtype=torch.int)
那么现在这个两个张量的形状为(2,2,3)
2. 对于x1和x2再次进行合并
- x3 = torch.cat((x1,x2), dim=0)
- x3: tensor([[[11, 21, 31],
- [34, 51, 23]],
-
- [[76, 56, 89],
- [35, 74, 62]],
-
- [[34, 56, 67],
- [62, 91, 56]],
-
- [[67, 35, 98],
- [38, 79, 25]]], dtype=torch.int32)
- x3 shape: torch.Size([4, 2, 3])
-
- x3 = torch.cat((x1,x2), dim=1)
- x3: tensor([[[11, 21, 31],
- [34, 51, 23],
- [34, 56, 67],
- [62, 91, 56]],
-
- [[76, 56, 89],
- [35, 74, 62],
- [67, 35, 98],
- [38, 79, 25]]], dtype=torch.int32)
- x3 shape: torch.Size([2, 4, 3])
-
- x3 = torch.cat((x1,x2), dim=2)
- x3: tensor([[[11, 21, 31, 34, 56, 67],
- [34, 51, 23, 62, 91, 56]],
-
- [[76, 56, 89, 67, 35, 98],
- [35, 74, 62, 38, 79, 25]]], dtype=torch.int32)
- x3 shape: torch.Size([2, 2, 6])
-
- x3 = torch.cat((x1,x2), dim=3)
- # IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
可以观察到这次在dim为2时并没有出现报错,因为这一次输入的x1和x2的形状大小为(2,2,3)有三个维度的信息,所以dim可以取值0,1,2
总结:dim指代的是以哪一个维度信息进行合并,在Pytorch中维度信息的表示方式是从外向内的。这一过程可以形象的理解为拨香蕉。(可以回看张量例子中的x1,x2的形状)所以当dim=0时指代的是按照最外层的list进行拼接,当dim=1时指代按照第二层list进行拼接,以此类推。所以输入的张量维数决定dim的取值范围。再会看定义的范围就非常好理解了
先说结论:dim=-1与dim=len(inputs[0])一样,即适用最里面的list进行合并。在python中-1的用法类似,均指代最后一个,所以dim=-1的情况就不难理解了
接下来说一下dim=-1的优点:
· 不需要知道输入具体的形状,适用于想要对于dim=len(inputs[0])进行扩展的情况
· 特别适合末端形状维度的扩展,举个例子⬇
仍然使用最开始的x1张量:
x1 = torch.tensor([[11,21,31],[34,51,23]], dtype=torch.int)
另外增加一个零向量
x0 = torch.zeros(x1.shape[0], 1)
它打印出来的结果为
- x0:
- tensor([[0.],[0.]])
- torch.Size([2, 1])
将x0和x1合并:
- x3 = torch.cat((x1,x0), dim=-1)
-
- x3: tensor([[11., 21., 31., 0.],
- [34., 51., 23., 0.]])
- x3 shape: torch.Size([2, 4])
可以看到形状从(2,3)扩充为(2,4)
在很多的数据处理中,这一步十分重要
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。