jax.pure_callback
- jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=False, **kwargs)
Applies a functionally pure Python callable. Works under jit()/pmap()/etc.

pure_callback enables calling a Python function in JIT-ed JAX functions. The input callback will be passed NumPy arrays in place of JAX arrays and should also return NumPy arrays. Execution takes place on CPU, like any Python+NumPy function.

The callback is treated as functionally pure, meaning it has no side effects and its output value depends only on its argument values. As a consequence, it is safe to call it multiple times (e.g. when transformed by vmap() or pmap()), or not to call it at all when, e.g., the output of a jit-decorated function has no data dependence on its value. Pure callbacks may also be reordered if data dependence allows.

When pmap()-ed, the pure callback will be called several times (once for each element of the mapped axis). When vmap()-ed, the behavior depends on the value of the vectorized keyword argument. When vectorized is True, the callback is assumed to obey jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs]). Therefore, the callback will be called directly on batched inputs (where the batch axes are the leading dimensions), and it should return outputs that have the corresponding leading batch axes. If vectorized is False, callback will be mapped sequentially across the batched axis. For example, if callback = lambda x, y: np.matmul(x, y), then we are free to set vectorized=True because the np.matmul function handles arbitrary leading batch dimensions.

- Parameters:
  - callback (Callable[..., Any]) – A Python callable. The callable will be passed PyTrees of NumPy arrays as arguments, and should return a PyTree of NumPy arrays that matches result_shape_dtypes.
  - result_shape_dtypes (Any) – A PyTree with leaves that are objects with shape and dtype attributes, which represent the shapes and dtypes of the value of callback applied to args and kwargs.
  - *args (Any) – The positional arguments to the callback. Must be PyTrees of JAX types.
  - sharding (SingleDeviceSharding | None) – Optional sharding that specifies the device from which the callback should be invoked.
  - vectorized (bool) – A boolean that indicates whether or not callback is vectorized, meaning it can handle arrays with additional leading dimensions. If vectorized is True, when the callback is mapped via jax.vmap, it will be called directly on inputs with leading batch dimensions instead of executing callback on each mapped input individually. The callback should also return outputs batched across the leading axis. By default, vectorized is False.
  - **kwargs (Any) – The keyword arguments to the callback. Must be PyTrees of JAX types.
- Returns:
  The value of callback(*args, **kwargs).
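A minimal sketch of calling host-side NumPy code from a jit-compiled function. The helper name host_sinc is hypothetical, and jax.ShapeDtypeStruct is used as the result_shape_dtypes leaf because it carries the required shape and dtype attributes. The np.asarray call is a cheap safeguard in case the installed JAX version hands the callback an array type other than np.ndarray.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sinc(x):
    # Hypothetical host-side helper: runs as ordinary Python + NumPy on the CPU.
    x = np.asarray(x)
    # Cast back so the result dtype matches the declared result_shape_dtypes.
    return np.sinc(x).astype(x.dtype)


@jax.jit
def f(x):
    # Declare the shape and dtype of the callback's output for tracing.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sinc, out_spec, x)


x = jnp.linspace(-2.0, 2.0, 5)
y = f(x)  # same shape and dtype as x
```

If the callback returned a different dtype (for example float64 from a float32 input), it would no longer match the declared spec, which is why the sketch casts explicitly.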
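Because result_shape_dtypes is a PyTree, one callback can return several arrays at once. The sketch below uses a hypothetical host_stats helper that returns a dict of two scalars, described by a dict of jax.ShapeDtypeStruct leaves with the same keys; the structure returned by the callback has to match that spec.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_stats(x):
    # Hypothetical helper returning a PyTree (a dict) of NumPy arrays.
    x = np.asarray(x)
    return {
        "mean": np.asarray(np.mean(x), dtype=x.dtype),
        "max": np.asarray(np.max(x), dtype=x.dtype),
    }


@jax.jit
def stats(x):
    # The spec mirrors the callback's output structure leaf for leaf.
    spec = {
        "mean": jax.ShapeDtypeStruct((), x.dtype),
        "max": jax.ShapeDtypeStruct((), x.dtype),
    }
    return jax.pure_callback(host_stats, spec, x)


out = stats(jnp.arange(4.0))  # out["mean"] == 1.5, out["max"] == 3.0
```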
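The np.matmul case from the description can be used to check the vectorized contract: vmapping the callback over a batch should give the same result as calling it on each example and stacking. This page documents vectorized=True; newer JAX releases deprecate it in favor of a vmap_method argument, so this sketch inspects the signature and, when vmap_method exists, passes vmap_method="expand_dims" on the assumption that it behaves like vectorized=True when every input is batched. The helper names host_matmul, matmul, and BATCHING_KWARGS are hypothetical.

```python
import inspect

import jax
import jax.numpy as jnp
import numpy as np


def host_matmul(x, y):
    # np.matmul broadcasts over leading batch dimensions, so this callback
    # satisfies callback(xs) == jnp.stack([callback(x) for x in xs]).
    return np.matmul(np.asarray(x), np.asarray(y))


# Assumption: older JAX (as documented here) takes `vectorized=True`; newer
# releases replace it with `vmap_method`. Use whichever the install accepts.
if "vmap_method" in inspect.signature(jax.pure_callback).parameters:
    BATCHING_KWARGS = {"vmap_method": "expand_dims"}
else:
    BATCHING_KWARGS = {"vectorized": True}


def matmul(x, y):
    # Output spec for a single (unbatched) pair of matrices.
    out_spec = jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype)
    return jax.pure_callback(host_matmul, out_spec, x, y, **BATCHING_KWARGS)


xs = jnp.arange(24.0).reshape(3, 2, 4)
ys = jnp.arange(60.0).reshape(3, 4, 5)

# Under vmap, the callback receives the full (3, 2, 4) and (3, 4, 5) batches.
batched = jax.jit(jax.vmap(matmul))(xs, ys)
# Reference: apply the callback to each example separately and stack.
unbatched = jnp.stack([matmul(x, y) for x, y in zip(xs, ys)])
```

With a non-vectorized callback (for instance one built on np.dot, which does not treat leading axes as batch dimensions), leaving vectorized at False keeps the sequential per-example mapping and avoids silently wrong results.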