当前位置:   article > 正文

Pytorch: torch.mean()

Pytorch: torch.mean()

PyTorch中,函数 torch.mean 用于计算张量的平均值(均值)。其可以对整个张量计算平均值,也可以沿某个或多个维度计算平均值。这个操作对于正则化数据、在神经网络中进行层间规范化等场合特别有用。

例子如下:

import torch

# 创建一个张量
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

# 计算所有元素的平均值
mean_val = torch.mean(x)
print(mean_val)  # 输出: 2.5

# 按列计算平均值(沿着第0维)
mean_col = torch.mean(x, dim=0)
print(mean_col)  # 输出: tensor([2., 3.])

# 按行计算平均值(沿着第1维)
mean_row = torch.mean(x, dim=1)
print(mean_row)  # 输出: tensor([1.5, 3.5])

# 计算平均值时保持维度
mean_val_keepdim = torch.mean(x, dim=1, keepdim=True)
print(mean_val_keepdim)  # 输出: tensor([[1.5], [3.5]])

# 计算多维度的平均值
x_multi_dim = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
mean_multi_dim = torch.mean(x_multi_dim, dim=(1, 2))
print(mean_multi_dim)  # 输出: tensor([2.5, 6.5])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

在这个例子中,torch.mean 使用了几个参数:

  • x:要计算均值的输入张量。
  • dim:指定计算均值的维度。可以是一个整数或一个整数元组。如果未指定,将计算所有元素的平均值。
  • keepdim:当设置为 True 时,输出张量的维度会与原张量保持一致,否则,求均值运算的维度将被降维。默认值为 False。

torch.mean 在训练深度学习模型时非常有用,因为它可以用来归一化层的激活值或是计算损失函数等。

dim 参数在很多操作函数中用到,它表示张量中的一个或多个维度。当你对张量进行操作,如计算平均值、总和、最大值、最小值等,你可以指定 dim 参数来选择沿哪个维度进行计算。
在多维张量的情况下,dim 的理解至关重要:

  • 0维(dim=0)通常是指向“沿着列”的操作,或者说是“沿着批次(batch)的大小”的方向。如果你有一个形状为 (n, m)的二维张量,对这个张量做 dim=0 的操作,会对每列中的 n 个元素进行计算。结果会是一个长度为 m的一维张量,每个元素都是之前那一列的计算结果。

  • 1维(dim=1)对于二维张量而言是指向“沿着行”的操作。如果有一个形状 (n, m) 的张量, dim=1 的操作会对每行中的 m 个元素进行计算。结果是一个长度为 n 的一维张量,每个元素是之前那一行的计算结果。

对于更高维度的张量, dim 参数的理解方式与前面类似,你可以将 dim 理解为张量的轴,这些轴从0开始编号,对应于张量的不同维度。
在使用 dim 参数时,你可以传入一个整数来指定单一维度,或者传入一个整数元组来指定多个维度。指定多个维度时,会同时沿着所有选中的维度进行操作。

下面是一些关于 dim 参数使用的例子:

import torch

# 创建一个3x4的2维张量
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])

# 计算整个张量的平均值
mean_all = torch.mean(x)  # 不指定dim,计算所有元素

# 计算每一列的平均值,沿着0维进行
mean_dim0 = torch.mean(x, dim=0)

# 计算每一行的平均值,沿着1维进行
mean_dim1 = torch.mean(x, dim=1)

# 创建一个2x3x4的3维张量
y = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 
                  [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]])

# 计算每个2x4矩阵的平均值,沿着1维进行
mean_y_dim1 = torch.mean(y, dim=1)

# 计算沿着0维和2维的平均值,这会把每个1x3x1的切片平均
mean_y_dim02 = torch.mean(y, dim=(0, 2))

# 打印所有结果
print("mean_all:", mean_all.item())  # 输出单个标量值
print("mean_dim0:", mean_dim0)       # 输出一维张量,长度为列数
print("mean_dim1:", mean_dim1)       # 输出一维张量,长度为行数
print("mean_y_dim1:", mean_y_dim1)   # 输出2x4两个平均值
print("mean_y_dim02:", mean_y_dim02) # 输出长度为3的一维张量,是每个 "管道" 的平均值
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/289665
推荐阅读
相关标签
  

闽ICP备14008679号