当前位置:   article > 正文

torch.where(),np.where()的两种用法,以及np.argwhere()寻找张量(tensor)和数组中为0的索引

torch.where

1.torch.where()

torch.where()有两种用法,

1.当输入参数为三个时,即torch.where(condition, x, y),返回满足 x if condition else y的tensor,注意x,y必须为tensor

2.当输入参数为一个时,即torch.where(condition),返回满足condition的tensor索引的元组(tuple)

代码示例

torch.where(condition, x, y)

代码

  1. import torch
  2. import numpy as np
  3. # 初始化两个tensor
  4. x = torch.tensor([
  5. [1,2,3,0,6],
  6. [4,6,2,1,0],
  7. [4,3,0,1,1]
  8. ])
  9. y = torch.tensor([
  10. [0,5,1,4,2],
  11. [5,7,1,2,9],
  12. [1,3,5,6,6]
  13. ])
  14. # 寻找满足x中大于3的元素,否则得到y对应位置的元素
  15. arr0 = torch.where(x>=3, x, y) #输入参数为3个
  16. print(x, '\n', y)
  17. print(arr0, '\n', type(arr0))

结果

  1. >>> x
  2. tensor([[1, 2, 3, 0, 6],
  3. [4, 6, 2, 1, 0],
  4. [4, 3, 0, 1, 1]])
  5. >>> y
  6. tensor([[0, 5, 1, 4, 2],
  7. [5, 7, 1, 2, 9],
  8. [1, 3, 5, 6, 6]])
  9. >>> arr0
  10. tensor([[0, 5, 3, 4, 6],
  11. [4, 6, 1, 2, 9],
  12. [4, 3, 5, 6, 6]])
  13. >>> type(arr0)
  14. <class 'torch.Tensor'>

arr0的类型为<class 'torch.Tensor'>

torch.where(condition)

以寻找tensor中为0的索引为例

代码

  1. import torch
  2. import numpy as np
  3. x = torch.tensor([
  4. [1,2,3,0,6],
  5. [4,6,2,1,0],
  6. [4,3,0,1,1]
  7. ])
  8. y = torch.tensor([
  9. [0,5,1,4,2],
  10. [5,7,1,2,9],
  11. [1,3,5,6,6]
  12. ])
  13. # 返回x中0元素的索引
  14. index0 = torch.where(x==0) # 输入参数为1个
  15. print(index0,'\n', type(index0))

结果

  1. >>> index0
  2. (tensor([0, 1, 2]), tensor([3, 4, 2]))
  3. >>> type(index0)
  4. <class 'tuple'>

其中[0, 1, 2]是0元素坐标的行索引,[3, 4, 2]是0元素坐标的列索引,注意,最终得到的是tuple类型的返回值,元组中包含了tensor

2.np.where()

np.where()用法与torch.where()用法类似,也包括两种用法,但是不同的是输入值类型和返回值的类型

代码示例

np.where(condition, x, y)和np.where(condition),输入x,y可以为非tensor

代码

  1. import torch
  2. import numpy as np
  3. x = torch.tensor([
  4. [1,2,3,0,6],
  5. [4,6,2,1,0],
  6. [4,3,0,1,1]
  7. ])
  8. y = torch.tensor([
  9. [0,5,1,4,2],
  10. [5,7,1,2,9],
  11. [1,3,5,6,6]
  12. ])
  13. arr1 = np.where(x>=3, x, y) # 输入参数为3个
  14. index0 = torch.where(x==0) # 输入参数为1个
  15. print(arr1,'\n',type(arr1))
  16. print(index1,'\n', type(index1))

结果

  1. >>> arr1
  2. [[0 5 3 4 6]
  3. [4 6 1 2 9]
  4. [4 3 5 6 6]]
  5. >>> type(arr1)
  6. <class 'numpy.ndarray'>
  7. >>> index1
  8. (array([0, 1, 2]), array([3, 4, 2]))
  9. >>> type(index1)
  10. <class 'tuple'>

注意,np.where()和torch.where()的返回值类型不同

3.np.argwhere(condition)

寻找符合contion的元素索引

代码示例

代码

  1. import torch
  2. import numpy as np
  3. x = torch.tensor([
  4. [1,2,3,0,6],
  5. [4,6,2,1,0],
  6. [4,3,0,1,1]
  7. ])
  8. y = torch.tensor([
  9. [0,5,1,4,2],
  10. [5,7,1,2,9],
  11. [1,3,5,6,6]
  12. ])
  13. index2 = np.argwhere(x==0) # 寻找元素为0的索引
  14. print(index2,'\n', type(index2))

结果

  1. >>> index2
  2. tensor([[0, 1, 2],
  3. [3, 4, 2]])
  4. >>> type(index2)
  5. <class 'torch.Tensor'>

注意返回值的类型

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号