jax.tree_util.tree_unflatten
- jax.tree_util.tree_unflatten(treedef, leaves)
Reconstructs a pytree from the treedef and the leaves.
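The following minimal sketch assumes a recent JAX release. It obtains a treedef with jax.tree_util.tree_structure() and then fills that structure with new leaves. The leaves are consumed in flattening order, and dict keys are visited in sorted order. The sample tree and leaf values are purely illustrative.

```python
import jax

# An illustrative pytree: a dict holding a tuple and a list (3 leaves total).
tree = {"a": (1, 2), "b": [3]}

# tree_structure keeps only the structure (a PyTreeDef) and drops the leaves.
treedef = jax.tree_util.tree_structure(tree)

# Place new leaves into the same structure, in flattening order
# ("a" before "b", then positions within each container).
rebuilt = jax.tree_util.tree_unflatten(treedef, ["x", "y", "z"])
print(rebuilt)  # {'a': ('x', 'y'), 'b': ['z']}
```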
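The inverse of tree_flatten().
- Parameters:
  - treedef (PyTreeDef): the PyTreeDef describing the structure to reconstruct.
  - leaves (Iterable): the iterable of leaves to use for reconstruction. The iterable must match the leaves of the treedef.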
- Return type:
  Any
- Returns:
  The reconstructed pytree, containing the leaves placed in the structure described by treedef.
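The following minimal round-trip sketch assumes a recent JAX release. Flattening a pytree and unflattening its leaves with the same treedef restores the original value. Passing transformed leaves rebuilds the same structure around new values. The params dictionary is a hypothetical example.

```python
import jax

# Hypothetical pytree of parameters.
params = {"w": [1.0, 2.0], "b": 0.5}

# tree_flatten returns the leaves (dict keys visited in sorted order: "b", then "w")
# together with the PyTreeDef describing the structure.
leaves, treedef = jax.tree_util.tree_flatten(params)

# Unflattening the unchanged leaves is the inverse of tree_flatten.
restored = jax.tree_util.tree_unflatten(treedef, leaves)

# Unflattening transformed leaves keeps the structure but swaps the values.
doubled = jax.tree_util.tree_unflatten(treedef, [2 * x for x in leaves])
print(doubled)  # {'b': 1.0, 'w': [2.0, 4.0]}
```

For a simple per-leaf transformation like this one, jax.tree_util.tree_map performs the flatten, map, and unflatten steps in a single call. The explicit form is useful when the leaves must be processed together, for example as one flat list.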