jax.lax.reshape
- jax.lax.reshape(operand, new_sizes, dimensions=None)
Wraps XLA’s Reshape operator.
For inserting/removing dimensions of size 1, prefer using
lax.squeeze/lax.expand_dims. These preserve information about axis identity that may be useful for advanced transformation rules.
- Parameters:
operand (Union[Array,ndarray,bool_,number,bool,int,float,complex]) – array to be reshaped.
new_sizes (Sequence[Union[int,Any]]) – sequence of integers specifying the resulting shape. The size of the final array must match the size of the input.
dimensions (Optional[Sequence[int]]) – optional sequence of integers specifying the permutation order of the input shape. If specified, its length must match the number of dimensions of operand, that is, len(operand.shape).
- Returns:
reshaped array.
- Return type:
Array
Examples
Simple reshaping from one to two dimensions:
>>> import jax.numpy as jnp
>>> from jax.lax import reshape
>>> x = jnp.arange(6)
>>> y = reshape(x, (2, 3))
>>> y
Array([[0, 1, 2],
       [3, 4, 5]], dtype=int32)
Reshaping back to one dimension:
>>> reshape(y, (6,))
Array([0, 1, 2, 3, 4, 5], dtype=int32)
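Inserting and removing a dimension of size 1. The note at the top of this page recommends lax.expand_dims and lax.squeeze for this case; the following is a minimal sketch of that recommendation, with variable names chosen only for illustration:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6)

# Add a leading axis of size 1 with expand_dims instead of reshape.
expanded = lax.expand_dims(x, (0,))      # shape (1, 6)

# Remove that axis again with squeeze instead of reshape.
squeezed = lax.squeeze(expanded, (0,))   # shape (6,)

# The same resulting shape via reshape, shown only for comparison.
via_reshape = lax.reshape(x, (1, 6))     # shape (1, 6)
```

Both routes produce arrays with the same shape and values; the difference lies in which axis information is kept for transformation rules, as described above.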
Reshaping to one dimension with permutation of dimensions:
>>> reshape(y, (6,), (1, 0))
Array([0, 3, 1, 4, 2, 5], dtype=int32)
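One way to read the dimensions argument is as a transpose applied before the reshape. The sketch below checks that reading on the same 2 x 3 input as above; it illustrates the observable behavior for this input and is not a description of how the operator is implemented internally:

```python
import jax.numpy as jnp
from jax import lax

y = jnp.arange(6).reshape(2, 3)

# Reshape with a permutation of the input dimensions.
permuted = lax.reshape(y, (6,), (1, 0))

# Assumed equivalent: transpose by the same permutation, then reshape.
two_step = lax.reshape(lax.transpose(y, (1, 0)), (6,))

# The two results agree element-wise for this input.
same = bool((permuted == two_step).all())
```

Here the transpose turns the rows [0, 1, 2] and [3, 4, 5] into columns, so flattening reads the elements in the order 0, 3, 1, 4, 2, 5.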