jax.numpy.where
- jax.numpy.where(condition, x=None, y=None, *, size=None, fill_value=None)
Return elements chosen from x or y depending on condition.
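As a minimal illustration of this selection rule, the sketch below uses arbitrary example arrays. The three-argument form picks each output element from x where condition is True and from y elsewhere, and the three inputs only need to broadcast to a common shape:

```python
import jax.numpy as jnp

condition = jnp.array([True, False, True, False])
x = jnp.array([1, 2, 3, 4])
y = jnp.array([10, 20, 30, 40])

# Element-wise: take x where condition is True, otherwise take y.
out = jnp.where(condition, x, y)  # values [1, 20, 3, 40]

# Broadcasting: a (3, 1) condition column and a scalar y are broadcast
# against the (3, 4) x, so whole rows are kept or replaced by -1.
mask = jnp.array([[True], [False], [True]])
grid = jnp.arange(12).reshape(3, 4)
selected = jnp.where(mask, grid, -1)
print(out)
print(selected)
```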
LAX-backend implementation of numpy.where().

At present, JAX does not support JIT-compilation of the single-argument form of jax.numpy.where() because its output shape is data-dependent. The three-argument form does not have a data-dependent shape and can be JIT-compiled successfully. Alternatively, the single-argument form can be JIT-compiled if you use the optional size keyword to statically specify the expected size of the output.

Special care is needed when the x or y input to jax.numpy.where() could have a value of NaN. Specifically, when a gradient is taken with jax.grad() (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

Original docstring below.
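A common way the NaN-gradient problem shows up is a branch that is only mathematically valid where condition holds, such as a logarithm guarded by x > 0. The sketch below follows the "double where" idea described in the linked FAQ: an inner where first substitutes a harmless value into the risky branch so its derivative stays finite. The function names are hypothetical, chosen for illustration only.

```python
import jax
import jax.numpy as jnp

def log_or_zero(x):
    # At x == 0 the forward value is fine (0.0), but the unselected
    # jnp.log(x) branch has derivative 1/x; combining it with the zero
    # cotangent that where sends to that branch gives 0/0 = NaN.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)

def log_or_zero_safe(x):
    # "Double where": replace inputs the log branch should never see with
    # 1.0, so jnp.log and its derivative are finite for every input.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    return jnp.where(x > 0.0, jnp.log(safe_x), 0.0)

print(jax.grad(log_or_zero)(0.0))       # NaN gradient
print(jax.grad(log_or_zero_safe)(0.0))  # finite gradient
print(jax.grad(log_or_zero_safe)(2.0))  # matches d/dx log(x) = 1/x
```

The outer where still decides which value is returned; the inner where only protects the gradient computation, so the forward results of the two functions agree.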
Note
When only condition is provided, this function is a shorthand for np.asarray(condition).nonzero(). Using nonzero directly should be preferred, as it behaves correctly for subclasses. The rest of this documentation covers only the case where all three arguments are provided.
- Parameters:
condition (array_like, bool) – Where True, yield x, otherwise yield y.
x (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.
y (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.
size (int, optional) – Only referenced when x and y are None. If specified, the indices of the first size elements of the result will be returned. If there are fewer elements than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
- Returns:
out – An array with elements from x where condition is True, and elements from y elsewhere.
- Return type:
ndarray
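To show how the JIT restriction and the size and fill_value parameters interact, here is a minimal sketch; the helper names select and first_true_indices are hypothetical. The three-argument form compiles directly, while the single-argument form needs a static size under jax.jit and, like nonzero, returns a tuple with one index array per dimension.

```python
import jax
import jax.numpy as jnp

@jax.jit
def select(condition, x, y):
    # Three-argument form: the output shape is the broadcast shape of the
    # inputs, known at trace time, so jit works without extra arguments.
    return jnp.where(condition, x, y)

@jax.jit
def first_true_indices(condition):
    # Single-argument form: the number of True entries depends on the data,
    # so a static size is required under jit. Unused slots are padded with
    # fill_value (-1 here instead of the default 0).
    return jnp.where(condition, size=3, fill_value=-1)

cond = jnp.array([False, True, False, True, False])
picked = select(cond, 1, 0)         # values [0, 1, 0, 1, 0]
(idx,) = first_true_indices(cond)   # one index array for the single axis
print(picked)
print(idx)                          # indices [1, 3] padded to length 3
```

Using -1 as fill_value is a deliberate choice in this sketch: with the default of zero, padding would be indistinguishable from a genuine match at index 0.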