赞
踩
对一个三维数组的每一维度进行操作
1,dim=0
- a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7]).view(2, 2, 2)
- print(a)
- mean = torch.mean(a, 0)
- print(mean, mean.shape)
输出结果:
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
tensor([[2., 3.],
[4., 5.]]) torch.Size([2, 2])
2,dim=1
- a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7]).view(2, 2, 2)
- print(a)
- mean = torch.mean(a, 1)
- print(mean, mean.shape)
输出结果
tensor(
[[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
tensor(
[[1., 2.],
[5., 6.]]) torch.Size([2, 2])
3,dim=2
- a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7]).view(2, 2, 2)
- print(a)
- mean = torch.mean(a, 2)
- print(mean, mean.shape)
输出结果
tensor(
[[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
tensor(
[[0.5000, 2.5000],
[4.5000, 6.5000]]) torch.Size([2, 2])
补充,如果在函数中添加了True,表示要和原来数的维度一致,不够的用维度1来添加,如下
-
- a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7]).view(2, 2, 2)
- print(a)
- mean = torch.mean(a, 2, True)
- print(mean, mean.shape)
- tensor([[[0., 1.],
- [2., 3.]],
-
- [[4., 5.],
- [6., 7.]]])
- tensor([[[0.5000],
- [2.5000]],
-
- [[4.5000],
- [6.5000]]]) torch.Size([2, 2, 1])
补充多维度变化
-
- a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7,8,9,10,11,12,13,14,15]).view(2, 2, 2,2)
- print(a)
- mean = torch.mean(a, 0, True)
- print(mean, mean.shape)
- tensor([[[[ 0., 1.],
- [ 2., 3.]],
-
- [[ 4., 5.],
- [ 6., 7.]]],
-
-
- [[[ 8., 9.],
- [10., 11.]],
-
- [[12., 13.],
- [14., 15.]]]])
- tensor([[[[ 4., 5.],
- [ 6., 7.]],
-
- [[ 8., 9.],
- [10., 11.]]]]) torch.Size([1, 2, 2, 2])
-
- a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7,8,9,10,11,12,13,14,15]).view(2, 2, 2,2)
- print(a)
- mean = torch.mean(a, 1, True)
- print(mean, mean.shape)
- tensor([[[[ 0., 1.],
- [ 2., 3.]],
-
- [[ 4., 5.],
- [ 6., 7.]]],
-
-
- [[[ 8., 9.],
- [10., 11.]],
-
- [[12., 13.],
- [14., 15.]]]])
- tensor([[[[ 2., 3.],
- [ 4., 5.]]],
-
-
- [[[10., 11.],
- [12., 13.]]]]) torch.Size([2, 1, 2, 2])
- a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7,8,9,10,11,12,13,14,15]).view(2, 2, 2,2)
- print(a)
- mean = torch.mean(a, 2, True)
- print(mean, mean.shape)
-
- tensor([[[[ 0., 1.],
- [ 2., 3.]],
-
- [[ 4., 5.],
- [ 6., 7.]]],
-
-
- [[[ 8., 9.],
- [10., 11.]],
-
- [[12., 13.],
- [14., 15.]]]])
- tensor([[[[ 1., 2.]],
-
- [[ 5., 6.]]],
-
-
- [[[ 9., 10.]],
-
- [[13., 14.]]]]) torch.Size([2, 2, 1, 2])
-
- a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7,8,9,10,11,12,13,14,15]).view(2, 2, 2,2)
- print(a)
- mean = torch.mean(a, 3, True)
- print(mean, mean.shape)
- tensor([[[[ 0., 1.],
- [ 2., 3.]],
-
- [[ 4., 5.],
- [ 6., 7.]]],
-
-
- [[[ 8., 9.],
- [10., 11.]],
-
- [[12., 13.],
- [14., 15.]]]])
- tensor([[[[ 0.5000],
- [ 2.5000]],
-
- [[ 4.5000],
- [ 6.5000]]],
-
-
- [[[ 8.5000],
- [10.5000]],
-
- [[12.5000],
- [14.5000]]]]) torch.Size([2, 2, 2, 1])
-
- a = torch.Tensor([0, 1, 2, 3, 4, 5,6,7,8,9,10,11,12,13,14,15,0, 1, 2, 3, 4, 5,6,7,8,9,10,11,12,13,14,15]).view(2, 2, 2,2,2)
- print(a)
- mean = torch.mean(a, 3, True)
- print(mean, mean.shape)
- tensor([[[[[ 0., 1.],
- [ 2., 3.]],
-
- [[ 4., 5.],
- [ 6., 7.]]],
-
-
- [[[ 8., 9.],
- [10., 11.]],
-
- [[12., 13.],
- [14., 15.]]]],
-
-
-
- [[[[ 0., 1.],
- [ 2., 3.]],
-
- [[ 4., 5.],
- [ 6., 7.]]],
-
-
- [[[ 8., 9.],
- [10., 11.]],
-
- [[12., 13.],
- [14., 15.]]]]])
- tensor([[[[[ 1., 2.]],
-
- [[ 5., 6.]]],
-
-
- [[[ 9., 10.]],
-
- [[13., 14.]]]],
-
-
-
- [[[[ 1., 2.]],
-
- [[ 5., 6.]]],
-
-
- [[[ 9., 10.]],
-
- [[13., 14.]]]]]) torch.Size([2, 2, 2, 1, 2])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。