The NumPy function np.where provides a way to choose, element by element, between two values depending on a condition:
import numpy as np
arr = np.array([1, 2, 3, 4])
# Where arr > 3 holds, take arr * 2; elsewhere keep the original element.
result = np.where(arr > 3, arr * 2, arr)
In the example above, np.where multiplies by 2 every element of the array that satisfies the condition element > 3. Elements that do not satisfy the condition are left unchanged. The result array now holds the values 1, 2, 3 and 8.
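As a minimal sketch of the behavior described above, the three-argument call can be compared with a plain Python conditional expression applied to each element. The name expected is introduced here for illustration only and is not part of NumPy.

```python
import numpy as np

arr = np.array([1, 2, 3, 4])

# Three-argument form: pick arr * 2 where the condition holds, arr elsewhere.
result = np.where(arr > 3, arr * 2, arr)

# Illustrative pure-Python equivalent (hypothetical helper name "expected").
expected = [x * 2 if x > 3 else x for x in arr.tolist()]

print(result.tolist())  # [1, 2, 3, 8]
print(expected)         # [1, 2, 3, 8]
```

Note that arr * 2 is computed for the whole array before np.where selects from it; the condition only decides which of the two inputs each output element is taken from.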
It is also possible to call np.where with only the condition parameter set:
import numpy as np
arr = np.array([1, 2, 3, 4])
# Only the condition is passed: the call returns indices, not values.
result = np.where(arr > 2)
Even though this is perfectly valid NumPy code, it may not yield the expected result.
When only the condition parameter is provided, np.where behaves like np.asarray(condition).nonzero() or np.nonzero(condition). Both of these return the indices of the elements that satisfy the condition, not the values themselves. This means the result variable now holds a tuple whose first element is an array of the indices where the condition arr > 2 was satisfied: (array([2, 3]),).
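The following sketch illustrates what the single-argument call returns for the same array, and how the returned indices relate to the values. The variable names condition, indices, via_nonzero and values are chosen here for illustration.

```python
import numpy as np

arr = np.array([1, 2, 3, 4])
condition = arr > 2

# Single-argument form: a tuple with one index array per dimension.
indices = np.where(condition)
print(type(indices).__name__, len(indices))  # tuple 1
print(indices[0].tolist())                   # [2, 3]

# np.nonzero on the same condition yields the same indices.
via_nonzero = np.nonzero(condition)
print(via_nonzero[0].tolist())               # [2, 3]

# The values themselves are obtained by indexing the array with the result.
values = arr[indices]
print(values.tolist())                       # [3, 4]
```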
If the intention is to find the indices of the elements that satisfy a certain condition, it is preferable to call np.asarray(condition).nonzero() or np.nonzero(condition) directly.
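A short sketch of the explicit form, using a small two-dimensional array made up for this example. With more than one dimension, np.nonzero returns one index array per axis, which makes the "indices, not values" intent easier to see than a bare np.where(condition) call.

```python
import numpy as np

# Hypothetical 2-D input, chosen only for illustration.
matrix = np.array([[1, 5],
                   [7, 2]])

# Explicitly ask for indices: one array of row indices, one of column indices.
rows, cols = np.nonzero(matrix > 3)
print(rows.tolist(), cols.tolist())  # [0, 1] [1, 0]

# Index back into the array to recover the matching values.
matches = matrix[rows, cols]
print(matches.tolist())              # [5, 7]
```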