赞
踩
(1)torch.argmax(input, dim=None, keepdim=False)
返回指定维度最大值的序号;
(2)dim
给定的定义是:the demention to reduce.也就是把dim
这个维度的,变成这个维度的最大值的index。
# -*- coding: utf-8 -*- """ Created on Fri Jan 7 15:05:09 2022 @author: 86493 """ import torch a=torch.tensor([ [ [1, 5, 5, 2], [9, -6, 2, 8], [-3, 7, -9, 1] ], [ [-1, 7, -5, 2], [9, 6, 2, 8], [3, 7, 9, 1] ]]) b=torch.argmax(a,dim=1) print(a) print(a.shape) print(b)
(1)这个例子,tensor(2, 3, 4)
,因为是dim=1
,即将第二维度去掉,变成tensor(2, 4)
,将每一个3x4数组,变成1x4数组。
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
如上所示的3×4矩阵,取每一列的最大值对应的下标,a[0]中第一列的最大值的行标为1, 第二列的最大值的行标为2,第三列的最大值行标为0,第4列的最大值行标为1,所以最后输出[1, 2, 0, 1],取每一列的最大值,结果为:
tensor([[[ 1, 5, 5, 2],
[ 9, -6, 2, 8],
[-3, 7, -9, 1]],
[[-1, 7, -5, 2],
[ 9, 6, 2, 8],
[ 3, 7, 9, 1]]])
torch.Size([2, 3, 4])
tensor([[1, 2, 0, 1],
[1, 0, 2, 1]])
(1)如果改成dim=2
,即将第三维去掉,即取每一行的最大值对应的下标,结果为tensor(2, 3)
。
import torch a=torch.tensor([ [ [1, 5, 5, 2], [9, -6, 2, 8], [-3, 7, -9, 1] ], [ [-1, 7, -5, 2], [9, 6, 2, 8], [3, 7, 9, 1] ]]) b=torch.argmax(a,dim=2) print(b) print(a.shape) """ tensor([[2, 0, 1], [1, 0, 2]]) torch.Size([2, 3, 4]) """
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。