要想弄清楚np.where怎么使用,需要对矩阵中每个元素的位置表示方式有所了解,下面介绍一下它的两个主要用法
- np.where(condition, x, y);满足条件(condition),输出x,不满足输出y。
>>> aa = np.arange(10)
>>> np.where(aa,1,-1)
array([-1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) # 0为False,所以第一个输出-1
>>> np.where(aa > 5,1,-1)
array([-1, -1, -1, -1, -1, -1, 1, 1, 1, 1])
>>> np.where([[True,False], [True,True]], # 官网上的例子
[[1,2], [3,4]],
[[9,8], [7,6]])
array([[1, 8],
[3, 4]])
上面这个例子的条件为[[True,False], [True,False]],分别对应最后输出结果的四个值。第一个值从[1,9]中选,因为条件为True,所以是选1。第二个值从[2,8]中选,因为条件为False,所以选8,后面以此类推。类似的问题可以再看个例子:
>>> a = 10
>>> np.where([[a > 5,a > a = np.array([2,4,6,8,10])
>>> np.where(a > 5) # 返回索引
(array([2, 3, 4]),)
>>> a[np.where(a > 5)] # 等价于 a[a>5]
array([ 6, 8, 10])
>>> np.where([[0, 1], [1, 0]])
(array([0, 1]), array([1, 0]))
上面这个例子条件中[[0,1],[1,0]]的真值为两个1,各自的第一维坐标为[0,1],第二维坐标为[1,0] 。下面再举个三维的例子感受一下:
>>> a = np.arange(27).reshape(3,3,3)
>>> a
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]])
>>> np.where(a > 5)
(array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
array([2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2]),
array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]))
# 符合条件的元素为
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]]
np.where会输出每个元素的对应的坐标,因为原数组有三维,所以tuple中有三个数组。
- 最后加上一段我自己为了方便理解某个项目里面的一段关于np.where的代码特意加的内容,主要为了自己看,和上文不构成直接关系(哈哈)
>>>prob=np.arange(0,9).reshape(3,3)
>>>prob
array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>>pick=np.where(prob>4)
>>>print(pick)
>>>print(pick[0],pick[0].shape)
>>>print(pick[1],pick[1].shape)
>>>print(type(pick))
(array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))
[1 2 2 2] (4,)
[2 0 1 2] (4,)
>>>roi=np.arange(0,36).reshape(4,3,3)
>>>roi
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]],
[[27, 28, 29],
[30, 31, 32],
[33, 34, 35]]])
>>>for i in range(4):
res=roi[i][pick]
print(res,res.shape)
[5 6 7 8] (4,)
[14 15 16 17] (4,)
[23 24 25 26] (4,)
[32 33 34 35] (4,)