当前位置:   article > 正文

torch.mean()的使用方法

torch.mean()的使用方法

对一个三维数组的每一维度进行操作

1,dim=0

  1. a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7]).view(2, 2, 2)
  2. print(a)
  3. mean = torch.mean(a, 0)
  4. print(mean, mean.shape)

输出结果:

tensor([[[0., 1.],

             [2., 3.]],

             [[4., 5.],

              [6., 7.]]])

tensor([[2., 3.],

            [4., 5.]]) torch.Size([2, 2])

2,dim=1

  1. a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7]).view(2, 2, 2)
  2. print(a)
  3. mean = torch.mean(a, 1)
  4. print(mean, mean.shape)

输出结果

tensor(

[[[0., 1.],

[2., 3.]],

[[4., 5.],

[6., 7.]]])

tensor(

[[1., 2.],

[5., 6.]]) torch.Size([2, 2])

3,dim=2

  1. a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7]).view(2, 2, 2)
  2. print(a)
  3. mean = torch.mean(a, 2)
  4. print(mean, mean.shape)

输出结果

tensor(

[[[0., 1.],

[2., 3.]],

[[4., 5.],

[6., 7.]]])

tensor(

[[0.5000, 2.5000],

[4.5000, 6.5000]]) torch.Size([2, 2])

补充,如果在函数中添加了True,表示要和原来数的维度一致,不够的用维度1来添加,如下

  1. a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7]).view(2, 2, 2)
  2. print(a)
  3. mean = torch.mean(a, 2, True)
  4. print(mean, mean.shape)
  5. tensor([[[0., 1.],
  6. [2., 3.]],
  7. [[4., 5.],
  8. [6., 7.]]])
  9. tensor([[[0.5000],
  10. [2.5000]],
  11. [[4.5000],
  12. [6.5000]]]) torch.Size([2, 2, 1])

补充多维度变化

  1. a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7,8,9,10,11,12,13,14,15]).view(2, 2, 2,2)
  2. print(a)
  3. mean = torch.mean(a, 0, True)
  4. print(mean, mean.shape)
  5. tensor([[[[ 0., 1.],
  6. [ 2., 3.]],
  7. [[ 4., 5.],
  8. [ 6., 7.]]],
  9. [[[ 8., 9.],
  10. [10., 11.]],
  11. [[12., 13.],
  12. [14., 15.]]]])
  13. tensor([[[[ 4., 5.],
  14. [ 6., 7.]],
  15. [[ 8., 9.],
  16. [10., 11.]]]]) torch.Size([1, 2, 2, 2])
  1. a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7,8,9,10,11,12,13,14,15]).view(2, 2, 2,2)
  2. print(a)
  3. mean = torch.mean(a, 1, True)
  4. print(mean, mean.shape)
  5. tensor([[[[ 0., 1.],
  6. [ 2., 3.]],
  7. [[ 4., 5.],
  8. [ 6., 7.]]],
  9. [[[ 8., 9.],
  10. [10., 11.]],
  11. [[12., 13.],
  12. [14., 15.]]]])
  13. tensor([[[[ 2., 3.],
  14. [ 4., 5.]]],
  15. [[[10., 11.],
  16. [12., 13.]]]]) torch.Size([2, 1, 2, 2])
  1. a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7,8,9,10,11,12,13,14,15]).view(2, 2, 2,2)
  2. print(a)
  3. mean = torch.mean(a, 2, True)
  4. print(mean, mean.shape)
  5. tensor([[[[ 0., 1.],
  6. [ 2., 3.]],
  7. [[ 4., 5.],
  8. [ 6., 7.]]],
  9. [[[ 8., 9.],
  10. [10., 11.]],
  11. [[12., 13.],
  12. [14., 15.]]]])
  13. tensor([[[[ 1., 2.]],
  14. [[ 5., 6.]]],
  15. [[[ 9., 10.]],
  16. [[13., 14.]]]]) torch.Size([2, 2, 1, 2])
  1. a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7,8,9,10,11,12,13,14,15]).view(2, 2, 2,2)
  2. print(a)
  3. mean = torch.mean(a, 3, True)
  4. print(mean, mean.shape)
  5. tensor([[[[ 0., 1.],
  6. [ 2., 3.]],
  7. [[ 4., 5.],
  8. [ 6., 7.]]],
  9. [[[ 8., 9.],
  10. [10., 11.]],
  11. [[12., 13.],
  12. [14., 15.]]]])
  13. tensor([[[[ 0.5000],
  14. [ 2.5000]],
  15. [[ 4.5000],
  16. [ 6.5000]]],
  17. [[[ 8.5000],
  18. [10.5000]],
  19. [[12.5000],
  20. [14.5000]]]]) torch.Size([2, 2, 2, 1])
  1. a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7,8,9,10,11,12,13,14,15,0, 1, 2, 3, 4, 5,6,7,8,9,10,11,12,13,14,15]).view(2, 2, 2,2,2)
  2. print(a)
  3. mean = torch.mean(a, 3, True)
  4. print(mean, mean.shape)
  5. tensor([[[[[ 0., 1.],
  6. [ 2., 3.]],
  7. [[ 4., 5.],
  8. [ 6., 7.]]],
  9. [[[ 8., 9.],
  10. [10., 11.]],
  11. [[12., 13.],
  12. [14., 15.]]]],
  13. [[[[ 0., 1.],
  14. [ 2., 3.]],
  15. [[ 4., 5.],
  16. [ 6., 7.]]],
  17. [[[ 8., 9.],
  18. [10., 11.]],
  19. [[12., 13.],
  20. [14., 15.]]]]])
  21. tensor([[[[[ 1., 2.]],
  22. [[ 5., 6.]]],
  23. [[[ 9., 10.]],
  24. [[13., 14.]]]],
  25. [[[[ 1., 2.]],
  26. [[ 5., 6.]]],
  27. [[[ 9., 10.]],
  28. [[13., 14.]]]]]) torch.Size([2, 2, 2, 1, 2])

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/441225
推荐阅读
相关标签
  

闽ICP备14008679号