jax.extend.random.threefry_prng_impl

jax.extend.random.threefry_prng_impl = PRNGImpl(key_shape=(2,), seed=<function threefry_seed>, split=<function threefry_split>, random_bits=<function threefry_random_bits>, fold_in=<function threefry_fold_in>, name='threefry2x32', tag='fry')
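The fields shown in the repr above can be read back as attributes of the object. A minimal sketch, assuming a JAX version in which the `jax.extend.random` module exposes `threefry_prng_impl` as documented here (the alias `jex_random` is only a local name chosen for this example):

```python
import jax.extend.random as jex_random

impl = jex_random.threefry_prng_impl

# Metadata fields listed in the repr.
print(impl.key_shape)  # (2,)
print(impl.name)       # threefry2x32
print(impl.tag)        # fry

# The remaining fields are the key operations themselves.
for op in (impl.seed, impl.split, impl.random_bits, impl.fold_in):
    assert callable(op)
```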

Specifies PRNG key shape and operations.

A PRNG implementation is determined by a key type K and a collection of functions that operate on such keys. The key type K is an array type with element type uint32 and shape specified by key_shape. The type signature of each operation is:

seed :: int[] -> K
fold_in :: K -> int[] -> K
split[shape] :: K -> K[*shape]
random_bits[shape, bit_width] :: K -> uint<bit_width>[*shape]
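The signatures above can be exercised by calling the implementation's functions directly on raw uint32 key arrays. This is a sketch that assumes the positional argument order used below (`split(key, shape)`, `random_bits(key, bit_width, shape)`, `fold_in(key, data)`), which mirrors the signatures but is an internal detail that may differ between JAX versions. At this level a key K is a plain uint32 array of shape key_shape, so the raw output of `split` carries a trailing key_shape axis.

```python
import jax.numpy as jnp
import jax.extend.random as jex_random

impl = jex_random.threefry_prng_impl

# seed :: int[] -> K   (K is a uint32 array of shape impl.key_shape)
key = impl.seed(jnp.int32(42))

# split[shape] :: K -> K[*shape]; the raw result has shape (*shape, *key_shape)
subkeys = impl.split(key, (3,))

# fold_in :: K -> int[] -> K; the data argument is a scalar integer array
folded = impl.fold_in(key, jnp.uint32(7))

# random_bits[shape, bit_width] :: K -> uint<bit_width>[*shape]
bits = impl.random_bits(key, 32, (4,))

print(key.shape, key.dtype)    # (2,) uint32
print(subkeys.shape)           # (3, 2)
print(folded.shape)            # (2,)
print(bits.shape, bits.dtype)  # (4,) uint32
```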

A PRNG implementation is adapted to an array-like object of keys K by the PRNGKeyArray class, which should be created via the random_seed function.
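In everyday code the same operations are usually reached through typed key arrays rather than raw uint32 arrays. A minimal sketch, assuming a JAX version that provides `jax.random.key` (with its keyword-only `impl` argument), `jax.random.bits`, and `jax.random.key_data`; here the public `jax.random.key` function stands in for the internal random_seed function, and the string `"threefry2x32"` (this implementation's name) selects it.

```python
import jax
import jax.numpy as jnp

# A scalar typed key backed by the threefry2x32 implementation.
key = jax.random.key(42, impl="threefry2x32")

# These public functions correspond to the split, fold_in and random_bits operations.
subkeys = jax.random.split(key, 3)
folded = jax.random.fold_in(key, 7)
bits = jax.random.bits(key, (4,), jnp.uint32)

# Typed keys hide the trailing key_shape axis; key_data exposes the raw uint32 words.
raw = jax.random.key_data(subkeys)

print(key.shape, subkeys.shape)  # () (3,)
print(raw.shape, raw.dtype)      # (3, 2) uint32
print(jax.random.key_data(key))  # [ 0 42]
```

Because the key type is carried by the array itself, code that consumes these keys does not need to know key_shape; only `key_data` reveals the underlying (…, 2) uint32 layout.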