当前位置:   article > 正文

pytorch中torch.meshgrid()函数理解及举例说明

pytorch中torch.meshgrid()函数理解及举例说明

说明:

函数的功能是生成网格,可以用于生成坐标。

函数输入:

输入两个一维tensor数据,且两个tensor数据类型相同,也可以输入三个一维tensor数据

函数输出:

输出两个tensor数据(两个tensor的行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数)或者三个tensor数据(三个tensor第一维度大小为第一个输入张量的元素个数,第二维度大小为第二个输入张量的元素个数,第三维度为第三个输入张量元素个数)

报错:

当两个输入tensor数据类型不同或维度不是一维时会报错。

结果理解:

输入两个一维张量的元素个数分别为n1,n2,则输出两个张量是二维的,且行和列个数均为n1,n2,输出第一个张量行相同(对应第一个输入张量),输出第二个张量列相同(对应第二个输入张量),其中第一个输出张量填充第一个输入张量中的元素,各行元素相同;第二个输出张量填充第二个输入张量中的元素,各列元素相同

若输入是三个一维张量,元素个数分别为n1,n2,n3,则输出的三个张量都是三维的,且输出的三个张量的三个维度均相等,分别为n1,n2,n3。

输入为两个张量:

  1. import torch
  2. import torch.nn as nn
  3. a1 = torch.tensor([1,3])
  4. b1 = torch.tensor([2,4,6])
  5. x1,y1 = torch.meshgrid(a1,b1)
  6. print(x1)
  7. print(y1)
  8. 输出:
  9. tensor([[1, 1, 1],
  10. [3, 3, 3]])
  11. tensor([[2, 4, 6],
  12. [2, 4, 6]])

输入为三个张量:

  1. import torch
  2. import torch.nn as nn
  3. a2 = torch.tensor([1,3])
  4. b2 = torch.tensor([2,4,6])
  5. c2 = torch.tensor([7,8,9,10])
  6. x2,y2,z2 = torch.meshgrid(a2,b2,c2)
  7. print(x2)
  8. print(x2.shape)
  9. print(y2)
  10. print(y2.shape)
  11. print(z2)
  12. print(z2.shape)
  13. 输出:
  14. tensor([[[1, 1, 1, 1],
  15. [1, 1, 1, 1],
  16. [1, 1, 1, 1]],
  17. [[3, 3, 3, 3],
  18. [3, 3, 3, 3],
  19. [3, 3, 3, 3]]])
  20. torch.Size([2, 3, 4])
  21. tensor([[[2, 2, 2, 2],
  22. [4, 4, 4, 4],
  23. [6, 6, 6, 6]],
  24. [[2, 2, 2, 2],
  25. [4, 4, 4, 4],
  26. [6, 6, 6, 6]]])
  27. torch.Size([2, 3, 4])
  28. tensor([[[ 7, 8, 9, 10],
  29. [ 7, 8, 9, 10],
  30. [ 7, 8, 9, 10]],
  31. [[ 7, 8, 9, 10],
  32. [ 7, 8, 9, 10],
  33. [ 7, 8, 9, 10]]])
  34. torch.Size([2, 3, 4])

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/441229
推荐阅读
相关标签
  

闽ICP备14008679号