jax.lax.map

jax.lax.map(f, xs)
Map a function over leading array axes.
Like Python’s built-in map, except that inputs and outputs are in the form of stacked arrays. Consider using the vmap() transform instead, unless you need to apply a function element by element for reduced memory usage or heterogeneous computation with other control flow primitives.

When xs is an array type, the semantics of map() are given by this Python implementation:

    import numpy as np

    def map(f, xs):
        # Apply f to each slice of xs along the leading axis, then stack the results.
        return np.stack([f(x) for x in xs])
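The sketch below checks the semantics above, assuming a recent JAX install; the function `f` and the input values are illustrative, not part of the API. It compares `jax.lax.map` with the `np.stack` reference and with `jax.vmap`, and uses `jax.lax.cond` as an example of the per-element control flow mentioned above:

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    # Hypothetical per-element function with a control flow primitive.
    # Inside lax.map each step sees a scalar predicate, so lax.cond runs one
    # branch per element. Under vmap the predicate is batched, and lax.cond is
    # converted into a select that evaluates both branches for every element.
    return jax.lax.cond(x > 0, jnp.sqrt, jnp.negative, x)

xs = jnp.array([4.0, -2.0, 9.0, -1.0])

mapped = jax.lax.map(f, xs)               # one element per step
reference = np.stack([f(x) for x in xs])  # the Python semantics quoted above
vectorized = jax.vmap(f)(xs)              # vectorized alternative

# All three hold the values 2, 2, 3, 1.
```

The results agree; the difference is in execution. vmap usually runs faster because it vectorizes the whole batch, while map processes one element per step, which can lower peak memory and avoids evaluating both cond branches for every element.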
Like scan(), map() is implemented in terms of JAX primitives, so many of the same advantages over a Python loop apply: xs may be an arbitrary nested pytree type, and the mapped computation is compiled only once.

- Parameters:
  - f – a Python function to apply element-wise over the first axis or axes of xs.
  - xs – values over which to map along the leading axis.
- Returns:
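Mapped values: the outputs of f stacked along a new leading axis, with the same pytree structure as a single output of f.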
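As a further sketch of pytree inputs and outputs (the function `f` and the field names `"scale"` and `"values"` are hypothetical), `xs` below is a dict whose leaves all share a leading axis of length 3; they must, because every leaf is sliced in lockstep. Each call to `f` receives one slice of every leaf, and the mapped values come back with the structure of `f`'s output, each leaf stacked along a new leading axis:

```python
import jax
import jax.numpy as jnp

def f(row):
    # `row` holds one slice of every leaf: a scalar "scale" and a length-4 "values".
    scaled = row["scale"] * row["values"]
    # Returning a tuple: each leaf of the output is stacked along a new leading axis.
    return scaled, jnp.max(scaled)

# Hypothetical inputs: both leaves share a leading axis of length 3.
xs = {
    "scale": jnp.array([1.0, 2.0, 3.0]),
    "values": jnp.ones((3, 4)),
}

scaled, maxima = jax.lax.map(f, xs)
# scaled has shape (3, 4); maxima has shape (3,) and holds 1, 2, 3.
```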