赞
踩
x=torch.arange(15).view(5,3)
x=x.float()
x_mean=torch.mean(x,dim=0,keepdim=True)(表示每一列的平均数)
x_mean0=torch.mean(x,dim=1,keepdim=True)(表示每一行的平均数)
x_mean:6 7 8
x_mean0:
1
4
7
10
13