easydel.utils.rngs_utils


Utilities for creating and splitting JAX PRNG keys.

class easydel.utils.rngs_utils.GenerateRNG(seed: int = 0)

Bases: object

An infinite generator of JAX PRNGKeys, useful for iterating over seeds.

property rng: PRNGKey

Provides access to the next PRNGKey without advancing the generator.

Returns

The next PRNGKey in the sequence.
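
As a rough illustration of the "infinite generator of keys" idea, the sketch below carries one key as state and splits it on every step. `KeyStream` is a hypothetical stand-in written only with `jax.random`; it is not EasyDeL's implementation, and `GenerateRNG` may differ in its details.

```python
import jax


class KeyStream:
    """Hypothetical stand-in for an infinite PRNGKey generator (not EasyDeL's code)."""

    def __init__(self, seed: int = 0):
        # The carried key is the only state; it is derived from the seed.
        self._key = jax.random.PRNGKey(seed)

    def __iter__(self):
        return self

    def __next__(self):
        # Split the carried key: keep one half as the new state, hand out the other.
        self._key, subkey = jax.random.split(self._key)
        return subkey


stream = KeyStream(seed=0)
first, second = next(stream), next(stream)

# Each key is meant to be consumed once, e.g. for a single sampling call.
sample = jax.random.normal(first, (2,))
```

Because the stream is fully determined by the seed, constructing a second `KeyStream(seed=0)` reproduces the same sequence of keys.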

class easydel.utils.rngs_utils.JaxRNG(rng: PRNGKey)

Bases: object

A wrapper around JAX’s PRNGKey that simplifies key splitting.

classmethod from_seed(seed: int) → JaxRNG

Creates a JaxRNG instance from a seed.

Parameters

seed – The seed to use for the random number generator.

Returns

A JaxRNG instance.
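
The constructor-from-seed pattern can be sketched as follows. `SeededKey` and its `split` method are hypothetical names used only for illustration; the sketch assumes nothing about `JaxRNG` beyond what is documented above (it wraps a PRNGKey and can be built from an integer seed).

```python
import jax


class SeededKey:
    """Hypothetical wrapper holding a PRNGKey (illustrative, not EasyDeL's JaxRNG)."""

    def __init__(self, rng):
        self.rng = rng

    @classmethod
    def from_seed(cls, seed: int) -> "SeededKey":
        # Build the initial key from an integer seed, then wrap it.
        return cls(jax.random.PRNGKey(seed))

    def split(self):
        # Advance the stored key and return a fresh subkey for one-time use.
        self.rng, subkey = jax.random.split(self.rng)
        return subkey


wrapper = SeededKey.from_seed(42)
k1 = wrapper.split()
k2 = wrapper.split()
```

Keeping the split inside the wrapper means callers never reuse a key by accident: every call to `split` hands back a subkey and replaces the stored one.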

easydel.utils.rngs_utils.next_rng(*args, **kwargs) → Union[PRNGKey, Tuple[PRNGKey, ...], dict]

Provides access to the global JaxRNG and splits the key based on arguments.

This function wraps the global jax_utils_rng instance and calls its __call__ method, passing through any arguments provided. This provides a convenient way to access and split the global random number generator key.

Parameters
  • *args – Positional arguments passed to the jax_utils_rng instance’s __call__ method.

  • **kwargs – Keyword arguments passed to the jax_utils_rng instance’s __call__ method.

Returns

The split PRNGKey(s) from the global jax_utils_rng instance.
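
The pass-through design described above might look roughly like the sketch below: a module-level splitter object plus a thin function that forwards its arguments. The names `_KeySplitter`, `_global_splitter`, and `next_key` are hypothetical, and the three argument forms (no argument, an integer count, a sequence of names) are an assumption inferred from the documented return type, not confirmed behavior of `next_rng`.

```python
import jax


class _KeySplitter:
    """Hypothetical callable key wrapper (illustrative, not EasyDeL's JaxRNG)."""

    def __init__(self, rng):
        self.rng = rng

    def __call__(self, keys=None):
        if keys is None:
            # No argument: return a single fresh subkey.
            self.rng, subkey = jax.random.split(self.rng)
            return subkey
        if isinstance(keys, int):
            # Integer: return that many subkeys as a tuple.
            parts = jax.random.split(self.rng, num=keys + 1)
            self.rng = parts[0]
            return tuple(parts[1:])
        # Sequence of names: return a dict mapping each name to its own subkey.
        names = list(keys)
        parts = jax.random.split(self.rng, num=len(names) + 1)
        self.rng = parts[0]
        return {name: parts[i + 1] for i, name in enumerate(names)}


# Module-level instance shared by every caller (assumed seed of 0 for the sketch).
_global_splitter = _KeySplitter(jax.random.PRNGKey(0))


def next_key(*args, **kwargs):
    # Forward everything to the shared splitter's __call__.
    return _global_splitter(*args, **kwargs)


single = next_key()
pair = next_key(2)
named = next_key(["params", "dropout"])
```

A shared global splitter is convenient for scripts, but it makes the key sequence depend on call order, so code that needs strict reproducibility may prefer to pass keys explicitly.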