当前位置:   article > 正文

pytorch中torch.stack()用法虽简单,但不好理解

pytorch中torch.stack()用法虽简单,但不好理解

函数功能

沿一个新维度对输入一系列张量进行连接,序列中所有张量应为相同形状,stack 函数返回的结果会新增一个维度。也即是把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度上面进行堆叠。

参数列表

tensors :为一系列输入张量,类型为turple和List
dim :新增维度的(下标)位置,当dim = -1时默认最后一个维度;范围必须介于 0 到输入张量的维数之间,默认是dim=0,在第0维进行连接
返回值:输出新增维度后的张量

情况一:输入数据为1维数据

dim = 0 : 在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维)

  1. import torch
  2. a = torch.tensor([1, 2, 3])
  3. b = torch.tensor([11, 22, 33])
  4. #在第0维进行连接,相当于在行上进行组合,取a的一行,b的一行,构成一个新的tensor(输入张量为一维,输出张量为两维)
  5. c = torch.stack([a, b],dim=0)
  6. print(a)
  7. print(b)
  8. print(c.size())
  9. print(c)
  10. 输出:
  11. tensor([1, 2, 3])
  12. tensor([11, 22, 33])
  13. torch.Size([2, 3])
  14. tensor([[ 1, 2, 3],
  15. [11, 22, 33]])

dim = 1 :在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维)

  1. import torch
  2. a = torch.tensor([1, 2, 3])
  3. b = torch.tensor([11, 22, 33])
  4. print(a)
  5. print(b)
  6. #在第1维进行连接,相当于在对应行上面对列元素进行组合,取a的一列,b的一列,构成新的tensor的一行(输入张量为一维,输出张量为两维)
  7. c = torch.stack([a, b],dim=1)
  8. print(c.size())
  9. print(c)
  10. 输出:
  11. tensor([1, 2, 3])
  12. tensor([11, 22, 33])
  13. torch.Size([3, 2])
  14. tensor([[ 1, 11],
  15. [ 2, 22],
  16. [ 3, 33]])

情况二:输入数据为2维数据

dim=0:表示在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维),注意:此处输入张量维度为二维,因此dim最大只能为2。

  1. import torch
  2. a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  3. b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
  4. print(a)
  5. print(b)
  6. #在第0维进行连接,相当于在通道维度上进行组合
  7. #即取a的所有数据,作为新tensor的一个分量
  8. #取b的所有数据,作为新tensor的另一个分量
  9. #(输入张量为两维,输出张量为三维)
  10. c = torch.stack([a, b],dim=0)
  11. print(c.size())
  12. print(c)
  13. 输出:
  14. tensor([[1, 2, 3],
  15. [4, 5, 6],
  16. [7, 8, 9]])
  17. tensor([[11, 22, 33],
  18. [44, 55, 66],
  19. [77, 88, 99]])
  20. torch.Size([2, 3, 3])
  21. tensor([[[ 1, 2, 3],
  22. [ 4, 5, 6],
  23. [ 7, 8, 9]],
  24. [[11, 22, 33],
  25. [44, 55, 66],
  26. [77, 88, 99]]])

dim=1:表示在第1维进行连接,相当于对相应通道中每个行进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。

  1. import torch
  2. a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  3. b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
  4. print(a)
  5. print(b)
  6. #在第1维(行)进行连接,相当于对相应通道中每个行进行组合
  7. #取a的一行,b的一行,作为新tensor的第1行和第2行
  8. #原来a:3*3,b:3*3,新tensor:3*2*3
  9. c = torch.stack([a, b], 1)
  10. print(c.size())
  11. print(c)
  12. 输出:
  13. tensor([[1, 2, 3],
  14. [4, 5, 6],
  15. [7, 8, 9]])
  16. tensor([[11, 22, 33],
  17. [44, 55, 66],
  18. [77, 88, 99]])
  19. torch.Size([3, 2, 3])
  20. tensor([[[ 1, 2, 3],
  21. [11, 22, 33]],
  22. [[ 4, 5, 6],
  23. [44, 55, 66]],
  24. [[ 7, 8, 9],
  25. [77, 88, 99]]])

dim=2:表示在第2维进行连接,相当于对相应行中每个列元素进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。

  1. import torch
  2. a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  3. b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
  4. print(a)
  5. print(b)
  6. #在第2维进行连接,相当于对相应行中每个列元素进行组合
  7. #针对每行,取a、b的第一列数据,构成tensor的第一行
  8. #针对每行,取a、b的第二列数据,构成tensor的第二行
  9. #,针对每行取a、b的第三列数据,构成tensor的第三行
  10. #原来a:3*3,b:3*3,新tensor:3*3*2
  11. c = torch.stack([a, b], 2)
  12. print(c.size())
  13. print(c)
  14. 输出:
  15. tensor([[1, 2, 3],
  16. [4, 5, 6],
  17. [7, 8, 9]])
  18. tensor([[11, 22, 33],
  19. [44, 55, 66],
  20. [77, 88, 99]])
  21. torch.Size([3, 3, 2])
  22. tensor([[[ 1, 11],
  23. [ 2, 22],
  24. [ 3, 33]],
  25. [[ 4, 44],
  26. [ 5, 55],
  27. [ 6, 66]],
  28. [[ 7, 77],
  29. [ 8, 88],
  30. [ 9, 99]]])

情况三:输入数据为3维数据

dim=0:表示在第0维进行连接,相当于在通道维进行拼接。注意:此处输入张量维度为三维,因此dim最大只能为3。

  1. import torch
  2. a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
  3. b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
  4. print(a)
  5. print(b)
  6. #表示在第0维进行连接,取整个a作为新tensor的一个分量,取整个b作为新tensor的一个分量
  7. c = torch.stack([a, b], 0)
  8. print(c)
  9. 输出:
  10. tensor([[[ 1, 2, 3],
  11. [ 4, 5, 6],
  12. [ 7, 8, 9]],
  13. [[10, 20, 30],
  14. [40, 50, 60],
  15. [70, 80, 90]]])
  16. torch.Size([2, 3, 3])
  17. tensor([[[ 11, 22, 33],
  18. [ 44, 55, 66],
  19. [ 77, 88, 99]],
  20. [[110, 220, 330],
  21. [440, 550, 660],
  22. [770, 880, 990]]])
  23. torch.Size([2, 3, 3])
  24. torch.Size([2, 2, 3, 3])
  25. tensor([[[[ 1, 2, 3],
  26. [ 4, 5, 6],
  27. [ 7, 8, 9]],
  28. [[ 10, 20, 30],
  29. [ 40, 50, 60],
  30. [ 70, 80, 90]]],
  31. [[[ 11, 22, 33],
  32. [ 44, 55, 66],
  33. [ 77, 88, 99]],
  34. [[110, 220, 330],
  35. [440, 550, 660],
  36. [770, 880, 990]]]])

dim=1:表示在第1维进行连接,取各自的第1维度数据,进行拼接。注意:此处输入张量维度为三维,因此dim最大只能为3。 

  1. import torch
  2. a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
  3. b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
  4. print(a)
  5. print(a.size())
  6. print(b)
  7. print(b.size())
  8. #表示在第1维进行连接,取a的第一维数据[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  9. #取b的第一维数据[[11, 22, 33], [44, 55, 66], [77, 88, 99]]作为新tensor的一个分量
  10. #取a的第一维数据[[10, 20, 30], [40, 50, 60], [70, 80, 90]]
  11. #取b的第一维数据[[110, 220, 330], [440, 550, 660], [770, 880, 990]]作为新tensor的另一个分量
  12. c = torch.stack([a, b], 1)
  13. print(c.size())
  14. print(c)
  15. 输出:
  16. tensor([[[ 1, 2, 3],
  17. [ 4, 5, 6],
  18. [ 7, 8, 9]],
  19. [[10, 20, 30],
  20. [40, 50, 60],
  21. [70, 80, 90]]])
  22. torch.Size([2, 3, 3])
  23. tensor([[[ 11, 22, 33],
  24. [ 44, 55, 66],
  25. [ 77, 88, 99]],
  26. [[110, 220, 330],
  27. [440, 550, 660],
  28. [770, 880, 990]]])
  29. torch.Size([2, 3, 3])
  30. torch.Size([2, 2, 3, 3])
  31. tensor([[[[ 1, 2, 3],
  32. [ 4, 5, 6],
  33. [ 7, 8, 9]],
  34. [[ 11, 22, 33],
  35. [ 44, 55, 66],
  36. [ 77, 88, 99]]],
  37. [[[ 10, 20, 30],
  38. [ 40, 50, 60],
  39. [ 70, 80, 90]],
  40. [[110, 220, 330],
  41. [440, 550, 660],
  42. [770, 880, 990]]]])

dim=2:表示在第2维进行连接,取各自的第2维度数据,进行拼接。注意:此处输入张量维度为三维,因此dim最大只能为3。

  1. import torch
  2. a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
  3. b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
  4. print(a)
  5. print(a.size())
  6. print(b)
  7. print(b.size())
  8. #表示在第1维进行连接,取a的第2维数据[1, 2, 3]
  9. #取b的第2维数据[11, 22, 33]作为新tensor的一个分量
  10. #取a的第2维数据[4, 5, 6]
  11. #取b的第2维数据[44, 55, 66]作为新tensor的一个分量
  12. #取a的第2维数据[4, 5, 6]
  13. #取b的第2维数据[44, 55, 66]作为新tensor的一个分量
  14. #取a的第2维数据[7, 8, 9]
  15. #取b的第2维数据[77, 88, 99]作为新tensor的一个分量
  16. #取a的第2维数据[10, 20, 30]
  17. #取b的第2维数据[110, 220, 330]作为新tensor的一个分量
  18. #取a的第2维数据[40, 50, 60]
  19. #取b的第2维数据[440, 550, 660]作为新tensor的一个分量
  20. #取a的第2维数据[70, 80, 90]
  21. #取b的第2维数据[770, 880, 990]作为新tensor的一个分量
  22. c = torch.stack([a, b], 2)
  23. print(c.size())
  24. print(c)
  25. 输出:
  26. tensor([[[ 1, 2, 3],
  27. [ 4, 5, 6],
  28. [ 7, 8, 9]],
  29. [[10, 20, 30],
  30. [40, 50, 60],
  31. [70, 80, 90]]])
  32. torch.Size([2, 3, 3])
  33. tensor([[[ 11, 22, 33],
  34. [ 44, 55, 66],
  35. [ 77, 88, 99]],
  36. [[110, 220, 330],
  37. [440, 550, 660],
  38. [770, 880, 990]]])
  39. torch.Size([2, 3, 3])
  40. torch.Size([2, 3, 2, 3])
  41. tensor([[[[ 1, 2, 3],
  42. [ 11, 22, 33]],
  43. [[ 4, 5, 6],
  44. [ 44, 55, 66]],
  45. [[ 7, 8, 9],
  46. [ 77, 88, 99]]],
  47. [[[ 10, 20, 30],
  48. [110, 220, 330]],
  49. [[ 40, 50, 60],
  50. [440, 550, 660]],
  51. [[ 70, 80, 90],
  52. [770, 880, 990]]]])

dim=3:表示在第3维进行连接,取各自的第3维度数据,进行拼接。注意:此处输入张量维度为三维,因此dim最大只能为3。

  1. import torch
  2. a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
  3. b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
  4. print(a)
  5. print(a.size())
  6. print(b)
  7. print(b.size())
  8. #针对第二维数据,在每个第二维度相同的情况下,取各自的列数据,构成新tensor的一行
  9. c = torch.stack([a, b], 3)
  10. print(c.size())
  11. print(c)
  12. 输出:
  13. tensor([[[ 1, 2, 3],
  14. [ 4, 5, 6],
  15. [ 7, 8, 9]],
  16. [[10, 20, 30],
  17. [40, 50, 60],
  18. [70, 80, 90]]])
  19. torch.Size([2, 3, 3])
  20. tensor([[[ 11, 22, 33],
  21. [ 44, 55, 66],
  22. [ 77, 88, 99]],
  23. [[110, 220, 330],
  24. [440, 550, 660],
  25. [770, 880, 990]]])
  26. torch.Size([2, 3, 3])
  27. torch.Size([2, 3, 3, 2])
  28. tensor([[[[ 1, 11],
  29. [ 2, 22],
  30. [ 3, 33]],
  31. [[ 4, 44],
  32. [ 5, 55],
  33. [ 6, 66]],
  34. [[ 7, 77],
  35. [ 8, 88],
  36. [ 9, 99]]],
  37. [[[ 10, 110],
  38. [ 20, 220],
  39. [ 30, 330]],
  40. [[ 40, 440],
  41. [ 50, 550],
  42. [ 60, 660]],
  43. [[ 70, 770],
  44. [ 80, 880],
  45. [ 90, 990]]]])

总结:m个序列数据,在某个维度k进行拼接,该维度大小为n,则拼接后形成了*n*m*大小,具体拼接过程是取m个序列数据,k-1维(设k-1维大小为x,从x=1开始取)相同情况下的第1个数据,构成新tensor的一个行;第二个数据...,第三个数据...构成tensor的新行;然后从x=2开始执行同样的操作

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/317104
推荐阅读
相关标签
  

闽ICP备14008679号