当前位置:   article > 正文

Pytorch中torch.nn.Softmax的dim参数含义

nn.softmax

涉及到多维tensor时,对softmax的参数dim总是很迷,下面用一个例子说明

import torch.nn as nn

m = nn.Softmax(dim=0)
n = nn.Softmax(dim=1)
k = nn.Softmax(dim=2)
input = torch.randn(2, 2, 3)
print(input)
print(m(input))
print(n(input))
print(k(input))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

输出:
input

tensor([[[ 0.5450, -0.6264,  1.0446],
         [ 0.6324,  1.9069,  0.7158]],

        [[ 1.0092,  0.2421, -0.8928],
         [ 0.0344,  0.9723,  0.4328]]])
  • 1
  • 2
  • 3
  • 4
  • 5

dim=0

tensor([[[0.3860, 0.2956, 0.8741],
         [0.6452, 0.7180, 0.5703]],

        [[0.6140, 0.7044, 0.1259],
         [0.3548, 0.2820, 0.4297]]])
  • 1
  • 2
  • 3
  • 4
  • 5

dim=0时,在第0维上sum=1,即:
[0][0][0]+[1][0][0]=0.3860+0.6140=1
[0][0][1]+[1][0][1]=0.2956+0.7044=1
… …

dim=1

tensor([[[0.4782, 0.0736, 0.5815],
         [0.5218, 0.9264, 0.4185]],

        [[0.7261, 0.3251, 0.2099],
         [0.2739, 0.6749, 0.7901]]])
  • 1
  • 2
  • 3
  • 4
  • 5

dim=1时,在第1维上sum=1,即:
[0][0][0]+[0][1][0]=0.4782+0.5218=1
[0][0][1]+[0][1][1]=0.0736+0.9264=1
… …

dim=2

tensor([[[0.3381, 0.1048, 0.5572],
         [0.1766, 0.6315, 0.1919]],

        [[0.6197, 0.2878, 0.0925],
         [0.1983, 0.5065, 0.2953]]])
  • 1
  • 2
  • 3
  • 4
  • 5

dim=2时,在第2维上sum=1,即:
[0][0][0]+[0][0][1]+[0][0][2]=0.3381+0.1048+0.5572=1.0001(四舍五入问题)
[0][1][0]+[0][1][1]+[0][1][2]=0.1766+0.6315+0.1919=1
… …

用图表示223的张量如下:
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/196968
推荐阅读
相关标签
  

闽ICP备14008679号