jax.numpy.union1d
- jax.numpy.union1d(ar1, ar2, *, size=None, fill_value=None)
Find the union of two arrays.
LAX-backend implementation of
numpy.union1d().
Because the size of the output of union1d is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which must be specified statically for jnp.union1d to be used within some of JAX’s transformations.
Original docstring below.
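The following is a minimal sketch of using jnp.union1d under jax.jit, assuming a recent JAX release. Because size fixes the output shape, it is marked static with static_argnames; the helper name union_fixed is hypothetical. Without size, calling jnp.union1d inside a jitted function typically fails with a concretization error, because the output length would depend on traced values.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: `size` is static so the output shape is known at trace time.
@partial(jax.jit, static_argnames="size")
def union_fixed(a, b, size):
    return jnp.union1d(a, b, size=size)


a = jnp.array([1, 3, 5])
b = jnp.array([3, 4])
result = union_fixed(a, b, size=4)
print(result.tolist())  # [1, 3, 4, 5]
```

Each distinct size value triggers a separate compilation, so keeping it fixed across calls avoids retracing.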
Return the unique, sorted array of values that are in either of the two input arrays.
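A short sketch of the basic behavior (input values chosen arbitrarily): duplicates are dropped, the result is sorted, and inputs that are not 1D are flattened before the union is taken.

```python
import jax.numpy as jnp

# Values present in either input, deduplicated and sorted.
u = jnp.union1d(jnp.array([-1, 0, 1]), jnp.array([-2, 0, 2]))
print(u.tolist())  # [-2, -1, 0, 1, 2]

# A 2D input is flattened first, so the result is always 1D.
m = jnp.union1d(jnp.array([[4, 1], [1, 9]]), jnp.array([9, 0]))
print(m.tolist())  # [0, 1, 4, 9]
```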
- Parameters:
ar1 (array_like) – Input arrays. They are flattened if they are not already 1D.
ar2 (array_like) – Input arrays. They are flattened if they are not already 1D.
size (int, optional) – If specified, the first size elements of the result will be returned. If there are fewer elements than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to the minimum value of the union.
- Returns:
union1d – Unique, sorted union of the input arrays.
- Return type:
ndarray
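The size and fill_value parameters can be illustrated with a small sketch (values chosen arbitrarily, assuming a recent JAX release and the default padding described in the parameter list): a size larger than the union pads the result, while a smaller one keeps only the first size sorted values.

```python
import jax.numpy as jnp

a = jnp.array([3, 1])
b = jnp.array([1, 2])  # the full union is [1, 2, 3]

# size larger than the union: padded with the default fill_value (the union's minimum).
padded = jnp.union1d(a, b, size=5)
print(padded.tolist())  # [1, 2, 3, 1, 1]

# An explicit fill_value is used for the padding instead.
sentinel = jnp.union1d(a, b, size=5, fill_value=-1)
print(sentinel.tolist())  # [1, 2, 3, -1, -1]

# size smaller than the union: only the first `size` sorted elements are kept.
truncated = jnp.union1d(a, b, size=2)
print(truncated.tolist())  # [1, 2]
```

Using a sentinel such as -1 makes padded entries distinguishable from real values, which the default (a repeat of the minimum) does not.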