当前位置:   article > 正文

topk(num,dim)_topk=[1, 3]

topk=[1, 3]

topk(num,dim=1)

>>> output=torch.randn(3,4)
>>> output
tensor([[-1.9291,  1.4127, -2.2464,  0.8932],
        [-0.4483, -0.3458,  0.8384,  1.9580],
        [-0.5633, -2.2806,  0.6278,  1.3552]])
在行上取一个最大值
>>> topkv,topki=output.topk(1,1)
>>> topkv
tensor([[1.4127],
        [1.9580],
        [1.3552]])
>>> topki
tensor([[1],
        [3],
        [3]])
在行上取前两个最大值
>>> topkv,topki=output.topk(2,1)
>>> topkv
tensor([[1.4127, 0.8932],
        [1.9580, 0.8384],
        [1.3552, 0.6278]])
>>> topki
tensor([[1, 3],
        [3, 2],
        [3, 2]])


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

topk(num,dim=0)

>>> output=torch.randn(3,4)
>>> output
tensor([[-1.9291,  1.4127, -2.2464,  0.8932],
        [-0.4483, -0.3458,  0.8384,  1.9580],
        [-0.5633, -2.2806,  0.6278,  1.3552]])
在列上取一个最大值        
>>> topkv,topki=output.topk(1,0)
>>> topkv
tensor([[-0.4483,  1.4127,  0.8384,  1.9580]])
>>> topki
tensor([[1, 0, 1, 1]])
在列上取两个最大值
>>> topkv,topki=output.topk(2,0)
>>> topkv
tensor([[-0.4483,  1.4127,  0.8384,  1.9580],
        [-0.5633, -0.3458,  0.6278,  1.3552]])
>>> topki
tensor([[1, 0, 1, 1],
        [2, 1, 2, 2]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/一键难忘520/article/detail/737608
推荐阅读
相关标签
  

闽ICP备14008679号