当前位置:   article > 正文

torch.mean()方法

torch.mean

一、方法详解

含义:顾名思义,返回一个Tensor的均值

torch.mean(input, dim, keepdim=False)

 input:输入的张量
 dim:求均值的维度,若dim = 0,则沿行求均值,返回的形状是(1,列数);若dim=1,则沿列  求均值,返回的形状是(行数,1),默认不设置dim的时候,返回的是所有元素的平均值。
 keepdim:输出张量是否跟输入张量的另一个维度相同

需要注意的是 mean()函数只能在float格式的数据上处理

如果不是,要不就x = x.float()

或者定义的时候 dtype=torch.float

这个用言语很难说明白,我们直接通过案例来理解掌握!

  1. import torch
  2. from torch.autograd import Variable # torch 中 Variable 模块
  3. tensor = torch.FloatTensor([[1,2],[3,4]])
  4. print(tensor)
  5. out = tensor*tensor
  6. t_out = torch.mean(tensor*tensor,dim=0,keepdim=True)
  7. v_out = torch.mean(tensor*tensor,dim=1,keepdim=True)
  8. m_out = torch.mean(tensor*tensor)
  9. print(out)
  10. print(out.size())
  11. print(t_out)
  12. print(t_out.size())
  13. print(v_out)
  14. print(v_out.size())
  15. print(m_out)
  16. print(m_out.size())
  1. tensor([[1., 2.],
  2. [3., 4.]])
  3. tensor([[ 1., 4.],
  4. [ 9., 16.]])
  5. torch.Size([2, 2])
  6. tensor([[ 5., 10.]])
  7. torch.Size([1, 2])
  8. tensor([[ 2.5000],
  9. [12.5000]])
  10. torch.Size([2, 1])
  11. tensor(7.5000)
  12. torch.Size([])

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/441255
推荐阅读
相关标签
  

闽ICP备14008679号