赞
踩
torch.where()有两种用法,
1.当输入参数为三个时,即torch.where(condition, x, y),返回满足 x if condition else y的tensor,注意x,y必须为tensor
2.当输入参数为一个时,即torch.where(condition),返回满足condition的tensor索引的元组(tuple)
代码
- import torch
- import numpy as np
-
- # 初始化两个tensor
- x = torch.tensor([
- [1,2,3,0,6],
- [4,6,2,1,0],
- [4,3,0,1,1]
- ])
- y = torch.tensor([
- [0,5,1,4,2],
- [5,7,1,2,9],
- [1,3,5,6,6]
- ])
-
- # 寻找满足x中大于3的元素,否则得到y对应位置的元素
- arr0 = torch.where(x>=3, x, y) #输入参数为3个
-
- print(x, '\n', y)
- print(arr0, '\n', type(arr0))
结果
- >>> x
- tensor([[1, 2, 3, 0, 6],
- [4, 6, 2, 1, 0],
- [4, 3, 0, 1, 1]])
- >>> y
- tensor([[0, 5, 1, 4, 2],
- [5, 7, 1, 2, 9],
- [1, 3, 5, 6, 6]])
-
- >>> arr0
- tensor([[0, 5, 3, 4, 6],
- [4, 6, 1, 2, 9],
- [4, 3, 5, 6, 6]])
-
- >>> type(arr0)
- <class 'torch.Tensor'>
arr0的类型为<class 'torch.Tensor'>
以寻找tensor中为0的索引为例
代码
- import torch
- import numpy as np
- x = torch.tensor([
- [1,2,3,0,6],
- [4,6,2,1,0],
- [4,3,0,1,1]
- ])
- y = torch.tensor([
- [0,5,1,4,2],
- [5,7,1,2,9],
- [1,3,5,6,6]
- ])
-
- # 返回x中0元素的索引
- index0 = torch.where(x==0) # 输入参数为1个
-
- print(index0,'\n', type(index0))
结果
- >>> index0
- (tensor([0, 1, 2]), tensor([3, 4, 2]))
-
- >>> type(index0)
- <class 'tuple'>
其中[0, 1, 2]是0元素坐标的行索引,[3, 4, 2]是0元素坐标的列索引,注意,最终得到的是tuple类型的返回值,元组中包含了tensor
np.where()用法与torch.where()用法类似,也包括两种用法,但是不同的是输入值类型和返回值的类型
代码
- import torch
- import numpy as np
- x = torch.tensor([
- [1,2,3,0,6],
- [4,6,2,1,0],
- [4,3,0,1,1]
- ])
- y = torch.tensor([
- [0,5,1,4,2],
- [5,7,1,2,9],
- [1,3,5,6,6]
- ])
-
- arr1 = np.where(x>=3, x, y) # 输入参数为3个
-
- index0 = torch.where(x==0) # 输入参数为1个
-
- print(arr1,'\n',type(arr1))
- print(index1,'\n', type(index1))
-
结果
- >>> arr1
- [[0 5 3 4 6]
- [4 6 1 2 9]
- [4 3 5 6 6]]
-
- >>> type(arr1)
- <class 'numpy.ndarray'>
-
- >>> index1
- (array([0, 1, 2]), array([3, 4, 2]))
-
- >>> type(index1)
- <class 'tuple'>
注意,np.where()和torch.where()的返回值类型不同
寻找符合contion的元素索引
代码
- import torch
- import numpy as np
- x = torch.tensor([
- [1,2,3,0,6],
- [4,6,2,1,0],
- [4,3,0,1,1]
- ])
- y = torch.tensor([
- [0,5,1,4,2],
- [5,7,1,2,9],
- [1,3,5,6,6]
- ])
-
-
- index2 = np.argwhere(x==0) # 寻找元素为0的索引
-
- print(index2,'\n', type(index2))
结果
- >>> index2
- tensor([[0, 1, 2],
- [3, 4, 2]])
-
- >>> type(index2)
- <class 'torch.Tensor'>
注意返回值的类型
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。