jax.sharding module#
Classes#
- class jax.sharding.Sharding#
Describes how a jax.Array is laid out across devices.
- property addressable_devices: set[jaxlib.xla_extension.Device]#
The set of devices in the Sharding that are addressable by the current process.
- addressable_devices_indices_map(global_shape)[source]#
A mapping from addressable devices to the slice of array data each contains.
addressable_devices_indices_map contains the part of devices_indices_map that applies to the addressable devices.
- Parameters:
global_shape (Shape) –
- Return type:
Mapping[Device, Index | None]
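A minimal sketch of what this mapping looks like, using SingleDeviceSharding since it can be constructed on any backend. For a single device, the per-device "slice" covering the whole array is a tuple of slice(None) objects, one per dimension:

```python
import jax
from jax.sharding import SingleDeviceSharding

# One device holds the entire array, so its index is a full slice
# in every dimension.
sharding = SingleDeviceSharding(jax.devices()[0])
index_map = sharding.addressable_devices_indices_map((8, 2))
for device, index in index_map.items():
    print(device, index)
```

With multi-device shardings, the values would instead be tuples of bounded slices describing each device's shard.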
- property device_set: set[jaxlib.xla_extension.Device]#
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- Parameters:
global_shape (Shape) –
- Return type:
Mapping[Device, Index | None]
- is_equivalent_to(other, ndim)[source]#
Returns True if two shardings are equivalent.
Two shardings are equivalent if they place the same logical array shards on the same devices.
For example, a NamedSharding may be equivalent to a PositionalSharding if both place the same shards of the array on the same devices.
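A hedged sketch of the NamedSharding/PositionalSharding equivalence described above, under the assumption that both shardings split dimension 0 of a 1-D array across all local devices in the same order:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PositionalSharding
from jax.sharding import PartitionSpec as P

devices = np.array(jax.devices())
# Shard dim 0 over the mesh axis 'x' (named form) ...
named = NamedSharding(Mesh(devices, ('x',)), P('x'))
# ... and over the device positions (positional form).
positional = PositionalSharding(devices)
print(named.is_equivalent_to(positional, ndim=1))
```

Both shardings place the same shards on the same devices, so they compare as equivalent even though they are expressed differently.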
- property is_fully_addressable: bool#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding.
is_fully_addressable is equivalent to "is_local" in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
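A quick sketch of both boolean properties, using SingleDeviceSharding: with one device holding the whole array, the sharding is fully replicated (that device has a complete copy), and in single-process JAX it is also fully addressable.

```python
import jax
from jax.sharding import SingleDeviceSharding

s = SingleDeviceSharding(jax.devices()[0])
# True: the single device holds a complete copy of the data.
print(s.is_fully_replicated)
# True in single-process JAX: every device is addressable.
print(s.is_fully_addressable)
```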
- class jax.sharding.XLACompatibleSharding#
Bases: Sharding
A Sharding that describes shardings expressible to XLA.
Subclasses of XLACompatibleSharding work with all JAX APIs and transformations that use XLA.
- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- is_equivalent_to(other, ndim)[source]#
Returns True if two shardings are equivalent.
Two shardings are equivalent if they place the same logical array shards on the same devices.
For example, a NamedSharding may be equivalent to a PositionalSharding if both place the same shards of the array on the same devices.
- Parameters:
self (XLACompatibleSharding) –
other (XLACompatibleSharding) –
ndim (int) –
- Return type:
bool
- class jax.sharding.SingleDeviceSharding#
Bases: XLACompatibleSharding
A Sharding that places its data on a single device.
- Parameters:
device – A single Device.
Example
>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0])
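A short usage sketch: a SingleDeviceSharding can be passed to jax.device_put to commit an array to one device, after which the array reports that sharding.

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

# Commit an array to the first device; the result carries the sharding.
sharding = SingleDeviceSharding(jax.devices()[0])
x = jax.device_put(jnp.arange(4), sharding)
print(type(x.sharding).__name__)
```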
- property device_set: set[jaxlib.xla_extension.Device]#
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- property is_fully_addressable: bool#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding.
is_fully_addressable is equivalent to "is_local" in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- class jax.sharding.NamedSharding#
Bases: XLACompatibleSharding
A NamedSharding expresses sharding using named axes.
A NamedSharding is a pair of a Mesh of devices and a PartitionSpec which describes how to shard an array across that mesh.
A Mesh is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, e.g. 'x' or 'y'.
A PartitionSpec is a tuple whose elements can be None, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. For example, PartitionSpec('x', 'y') says that the first dimension of data is sharded across the x axis of the mesh, and the second dimension is sharded across the y axis of the mesh.
The Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) has more details and diagrams that explain how Mesh and PartitionSpec are used.
- Parameters:
mesh – A jax.sharding.Mesh object.
spec – A jax.sharding.PartitionSpec object.
Example
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
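A hedged end-to-end sketch that adapts to however many devices are available (the doctest above assumes exactly 8): build a 1-D mesh with axis 'x' and shard dimension 0 of an array across it with device_put. The array length of 8 is an assumption chosen to be divisible by common device counts.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

# A 1-D mesh over all local devices, named 'x'.
devices = np.array(jax.devices())
mesh = Mesh(devices, ('x',))
# Shard dimension 0 of the array across mesh axis 'x'.
sharding = NamedSharding(mesh, P('x'))
x = jax.device_put(jnp.arange(8.0), sharding)
print(x.sharding.spec)
```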
- property addressable_devices: set[jaxlib.xla_extension.Device]#
The set of devices in the Sharding that are addressable by the current process.
- property device_set: set[jaxlib.xla_extension.Device]#
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- property is_fully_addressable: bool#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding.
is_fully_addressable is equivalent to "is_local" in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- class jax.sharding.PositionalSharding(devices, *, memory_kind=None)[source]#
Bases: XLACompatibleSharding
A sharding strategy that arranges data based on device positions.
This strategy enables efficient data distribution across devices, making it suitable for parallel processing on different hardware platforms.
Example
>>> from jax.sharding import PositionalSharding
>>> sharding = PositionalSharding(jax.devices())
>>> print(sharding.ndim)
1
- Parameters:
devices (Sequence[xc.Device] | np.ndarray) –
memory_kind (str | None) –
- property T: PositionalSharding#
Create a new PositionalSharding instance with a transposed data layout.
T is a property and takes no arguments; it reverses the order of the layout's dimensions, analogous to numpy.ndarray.T. For an explicit axis permutation, use transpose().
Example
- Transpose a two-dimensional PositionalSharding instance:
>>> transposed_sharding = sharding.T
- Returns:
A new PositionalSharding instance with the transposed data layout.
- Return type:
PositionalSharding
- property device_set: set[jaxlib.xla_extension.Device]#
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- property is_fully_addressable: bool#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding.
is_fully_addressable is equivalent to "is_local" in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- replicate(axis=None, keepdims=True)[source]#
Create a new PositionalSharding instance with replicated data along a specified axis.
Replicating a sharding instance involves creating additional copies of the data to distribute it along a specific axis. This method allows you to replicate the data within the sharding strategy, which can be useful for parallel processing or data redundancy.
- Parameters:
axis – The axis along which data replication is performed. If not specified, replication is applied across all dimensions.
keepdims – Whether to keep the dimensions of the sharding consistent or not. When set to True, the replicated dimensions are retained; otherwise, they are collapsed.
Example
- Replicate a PositionalSharding instance along the first axis while keeping dimensions:
>>> replicated_sharding = sharding.replicate(axis=0, keepdims=True)
- Returns:
A new PositionalSharding instance with the data replicated along the specified axis and dimensionality adjusted according to the keepdims parameter.
- Return type:
PositionalSharding
- reshape(*shape)[source]#
Returns a new PositionalSharding instance with a reshaped data layout.
- Parameters:
*shape – New shape dimensions.
- Returns:
A new PositionalSharding instance.
- Return type:
PositionalSharding
Example
- Reshape a PositionalSharding instance:
>>> new_sharding = sharding.reshape(2, 1)
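The reshape and replicate methods above compose naturally. A hedged sketch that works on any device count: arrange all local devices as an (n, 1) grid, a common layout for sharding a 2-D array along its rows while replicating its columns.

```python
import jax
from jax.sharding import PositionalSharding

# Reshape the flat device list into an (n, 1) grid.
n = len(jax.devices())
sharding = PositionalSharding(jax.devices()).reshape(n, 1)
print(sharding.shape)

# Replicate along axis 1; keepdims=True (the default) retains the axis.
replicated = sharding.replicate(axis=1)
print(replicated.shape)
```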
- transpose(*axes)[source]#
Create a new PositionalSharding instance with a transposed data layout.
Transposing a sharding instance involves rearranging the dimensions of the data layout. This method allows you to change the order of dimensions for efficient data processing or compatibility with other operations.
- Parameters:
*axes – Axes permutation for the transpose operation as individual arguments.
Example
- Transpose a PositionalSharding instance, swapping the first and second dimensions:
>>> transposed_sharding = sharding.transpose(1, 0)
- Returns:
A new PositionalSharding instance with the transposed data layout.
- Return type:
PositionalSharding
- with_memory_kind(kind)[source]#
Create a new PositionalSharding instance with a specified memory kind.
Memory kind refers to the type of memory or storage associated with the sharding strategy. This method allows you to customize the memory kind used by the sharding for efficient memory management and allocation.
- Parameters:
kind (str) – The memory kind to associate with the sharding. Common memory kinds include “HBM” (High-Bandwidth Memory) and “DDR” (Double Data Rate memory).
Example
- Create a new PositionalSharding instance with a specific memory kind, such as "HBM" (High-Bandwidth Memory):
>>> sharding_with_mem_kind = sharding.with_memory_kind("HBM")  # doctest: +SKIP
- Returns:
A new PositionalSharding instance with the specified memory kind.
- Return type:
PositionalSharding
- class jax.sharding.PmapSharding#
Bases: XLACompatibleSharding
Describes a sharding used by jax.pmap().
- classmethod default(shape, sharded_dim=0, devices=None)[source]#
Creates a PmapSharding which matches the default placement used by jax.pmap().
- Parameters:
shape (Shape) – The shape of the input array.
sharded_dim (int) – Dimension the input array is sharded on. Defaults to 0.
devices (Sequence[xc.Device] | None) – Optional sequence of devices to use. If omitted, the implicit device order used by pmap is used, which is the order of jax.local_devices().
- Return type:
PmapSharding
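A hedged sketch of how default() relates to jax.pmap: the output of a pmap over the leading axis carries a PmapSharding, and default() constructs the matching placement for the same shape.

```python
import jax
import jax.numpy as jnp
from jax.sharding import PmapSharding

# pmap maps over the leading axis, one shard per local device.
shape = (len(jax.local_devices()), 4)
out = jax.pmap(lambda v: v * 2)(jnp.ones(shape))

# default() builds the sharding pmap would use for this shape.
default = PmapSharding.default(shape)
print(type(out.sharding).__name__)
```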
- property device_set: set[jaxlib.xla_extension.Device]#
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- is_equivalent_to(other, ndim)[source]#
Returns True if two shardings are equivalent.
Two shardings are equivalent if they place the same logical array shards on the same devices.
For example, a NamedSharding may be equivalent to a PositionalSharding if both place the same shards of the array on the same devices.
- Parameters:
self (PmapSharding) –
other (PmapSharding) –
ndim (int) –
- Return type:
bool
- property is_fully_addressable: bool#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding.
is_fully_addressable is equivalent to "is_local" in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- class jax.sharding.GSPMDSharding#
Bases: XLACompatibleSharding
- property device_set: set[jaxlib.xla_extension.Device]#
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- property is_fully_addressable: bool#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding.
is_fully_addressable is equivalent to "is_local" in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- class jax.sharding.PartitionSpec(*partitions)[source]#
Tuple describing how to partition an array across a mesh of devices.
Each element is either None, a string, or a tuple of strings. See the documentation of jax.sharding.NamedSharding for more details.
This class exists so JAX's pytree utilities can distinguish a partition specification from tuples that should be treated as pytrees.
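A brief sketch of the element semantics described above: a PartitionSpec iterates like a tuple, where a string names a mesh axis to shard over and None marks a replicated dimension.

```python
from jax.sharding import PartitionSpec as P

# Dim 0 sharded over mesh axis 'x'; dim 1 replicated (None).
spec = P('x', None)
print(tuple(spec))
print(spec[0])
```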
- class jax.sharding.Mesh(devices: np.ndarray | Sequence[xc.Device], axis_names: str | Sequence[MeshAxisName])[source]#
Declare the hardware resources available in the scope of this manager.
In particular, all axis_names become valid resource names inside the managed block and can be used e.g. in the in_axis_resources argument of jax.experimental.pjit.pjit(). Also see JAX's multi-process programming model (https://jax.readthedocs.io/en/latest/multi_process.html) and the Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).
If you are compiling in multiple threads, make sure that the with Mesh context manager is inside the function that the threads will execute.
- Parameters:
devices – A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).
axis_names – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
Example
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...     out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...     out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...     out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...     out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
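The doctests above assume exactly 8 devices. As a hedged sketch that works on any device count, a Mesh can also be inspected directly: its axis_names and shape expose the axis-to-size mapping that named shardings refer to. The trailing size-1 axis below is an assumption made so the example runs even on a single device.

```python
import jax
import numpy as np
from jax.sharding import Mesh

# A 2-D mesh built from however many devices are available; the
# trailing axis has size 1 so this works on a single device too.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ('x', 'y'))
print(mesh.axis_names)
print(dict(mesh.shape))  # maps each axis name to its size
```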