赞
踩
一、方法详解
含义:顾名思义,返回一个Tensor的均值
torch.mean(input, dim, keepdim=False)
input:输入的张量
dim:求均值的维度,若dim = 0,则沿行求均值,返回的形状是(1,列数);若dim=1,则沿列 求均值,返回的形状是(行数,1),默认不设置dim的时候,返回的是所有元素的平均值。
keepdim:输出张量是否跟输入张量的另一个维度相同
需要注意的是 mean()函数只能在float格式的数据上处理
如果不是,要不就x = x.float()
或者定义的时候 dtype=torch.float
这个用言语很难说明白,我们直接通过案例来理解掌握!
- import torch
- from torch.autograd import Variable # torch 中 Variable 模块
-
- tensor = torch.FloatTensor([[1,2],[3,4]])
- print(tensor)
- out = tensor*tensor
- t_out = torch.mean(tensor*tensor,dim=0,keepdim=True)
- v_out = torch.mean(tensor*tensor,dim=1,keepdim=True)
- m_out = torch.mean(tensor*tensor)
- print(out)
- print(out.size())
- print(t_out)
- print(t_out.size())
- print(v_out)
- print(v_out.size())
- print(m_out)
- print(m_out.size())

- tensor([[1., 2.],
- [3., 4.]])
- tensor([[ 1., 4.],
- [ 9., 16.]])
- torch.Size([2, 2])
- tensor([[ 5., 10.]])
- torch.Size([1, 2])
- tensor([[ 2.5000],
- [12.5000]])
- torch.Size([2, 1])
- tensor(7.5000)
- torch.Size([])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。