赞
踩
以shape为(2, 3, 3)的tensor为例,理解x.mean或torch.mean的用法
import torch
x = torch.arange(18, dtype=torch.float32).view(2, 3, 3)
dim_0 = x.mean(dim=0)
print(dim_0)
输出结果:
可以看出,dim=0时,在x的通道维度进行了求平均,以dim_0[0][0]为例,dim_0[0][0] = (x[0][0][0]+x[1][0][0])/2。
dim_1 = x.mean(dim=1)
print(dim_1)
输出结果:
可以看出,当dim=1时,在x的行方向(或者说高H方向)进行求平均,以dim_1[0][0]为例,dim_1[0][0] = (x[0][0][0]+x[0][1][0]+x[0][2][0])/3。
dim_2 = x.mean(dim=2)
print(dim_2)
可以看出,当dim=2时,在x的列方向(或者说宽W方向)进行求平均。
dim_1_2 = x.mean(dim=[1, 2])
print(dim_1_2)
输出结果:
在行方向求平均的结果上再沿列方向求平均
dim_0_1 = x.mean(dim=[0, 1])
print(dim_0_1)
输出结果:
![dim=0, 1
在沿通道方向求平均的结果上再沿行方向求平均
比较keepdim=False与keepdim=True时的输出结果:
可以看出,keepdim=True时维度保持不变
x.mean(dim)与torch.mean(x,dim)能够实现相同的效果。
以上均为个人理解,如有错误或不妥之处,欢迎大家批评指正!!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。