easydel.utils.rngs_utils
Utilities for creating and splitting JAX PRNG keys.
- class easydel.utils.rngs_utils.GenerateRNG(seed: int = 0)
Bases: object
An infinite generator of JAX PRNGKeys derived from a single seed, useful when code needs a fresh key at each step.
- property rng: PRNGKey
Provides access to the next PRNGKey without advancing the generator.
- Returns
The next PRNGKey in the sequence.
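The reference does not say how the generator produces its keys. The sketch below assumes the usual JAX pattern: each step splits one internal key into a new internal state and an output key, so no key is handed out twice. KeyStream and its peek property are hypothetical names, not the EasyDeL implementation; peek only mirrors the documented contract of the rng property (look at the next key without advancing).

```python
import itertools

import jax
import jax.numpy as jnp


class KeyStream:
    """Hypothetical stand-in for GenerateRNG (not the EasyDeL implementation).

    Assumption: every step splits the internal key into (new_state, output)
    and hands out only the output, so no key is ever reused.
    """

    def __init__(self, seed: int = 0):
        self._key = jax.random.PRNGKey(seed)

    def __iter__(self) -> "KeyStream":
        return self

    def __next__(self) -> jax.Array:
        # Keep the first half as the new state, return the second half.
        self._key, out = jax.random.split(self._key)
        return out

    @property
    def peek(self) -> jax.Array:
        # Same split as __next__, but the new state is discarded,
        # so the stream does not advance.
        _, out = jax.random.split(self._key)
        return out


stream = KeyStream(seed=0)
upcoming = stream.peek   # look ahead without advancing
first = next(stream)     # advances the stream
assert jnp.array_equal(upcoming, first)

# Draw three fresh keys, e.g. one per layer initialisation.
keys = list(itertools.islice(stream, 3))
weights = [jax.random.normal(k, (2, 2)) for k in keys]
```

Splitting rather than reusing a key matters in JAX: calling a sampler twice with the same key returns the same numbers, so every draw needs its own key.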
- class easydel.utils.rngs_utils.JaxRNG(rng: PRNGKey)
Bases: object
A wrapper around JAX’s PRNGKey that simplifies key splitting.
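No methods are listed for JaxRNG here. Judging only from the return annotation of next_rng (a single key, a tuple of keys, or a dict), one plausible call convention is sketched below. KeySplitter and the way it interprets its argument are assumptions for illustration, not the confirmed EasyDeL API.

```python
from typing import Dict, Optional, Sequence, Tuple, Union

import jax

Keys = Union[jax.Array, Tuple[jax.Array, ...], Dict[str, jax.Array]]


class KeySplitter:
    """Hypothetical stand-in for JaxRNG: holds one key and splits it on demand.

    Assumed call convention (inferred from next_rng's return type):
      splitter()              -> one key
      splitter(3)             -> tuple of 3 keys
      splitter(["a", "b"])    -> {"a": key, "b": key}
    """

    def __init__(self, rng: jax.Array):
        self.rng = rng

    def __call__(self, keys: Optional[Union[int, Sequence[str]]] = None) -> Keys:
        if keys is None:
            # Two-way split: new internal state plus one key for the caller.
            self.rng, out = jax.random.split(self.rng)
            return out
        n = keys if isinstance(keys, int) else len(keys)
        # (n + 1)-way split: row 0 becomes the new state, rows 1..n are returned.
        parts = jax.random.split(self.rng, n + 1)
        self.rng = parts[0]
        if isinstance(keys, int):
            return tuple(parts[1:])
        return dict(zip(keys, parts[1:]))


splitter = KeySplitter(jax.random.PRNGKey(42))
dropout_key = splitter()
init_key, noise_key = splitter(2)
named = splitter(["params", "dropout"])  # e.g. a Flax-style rngs dict
```

Keeping row 0 of every split as the new internal state means the wrapper never returns a key it will later split again, which is what makes the three call forms safe to mix.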
- easydel.utils.rngs_utils.next_rng(*args, **kwargs) → Union[PRNGKey, Tuple[PRNGKey, ...], dict]
Provides access to the global JaxRNG and splits the key based on arguments.
This function forwards any arguments it receives to the __call__ method of the module-level jax_utils_rng instance, giving callers a convenient way to draw and split keys from the global random number generator without holding a reference to it.
- Parameters
*args – Positional arguments passed to the jax_utils_rng instance’s __call__ method.
**kwargs – Keyword arguments passed to the jax_utils_rng instance’s __call__ method.
- Returns
The split PRNGKey(s) from the global jax_utils_rng instance.
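Because next_rng mutates module-level state, keys are best drawn in ordinary Python code and passed into jitted functions as arguments. A minimal sketch under stated assumptions: _GlobalSplitter, _global_rng and next_rng_sketch are hypothetical stand-ins for JaxRNG, jax_utils_rng and next_rng, and only the no-argument and integer-count forms are modelled.

```python
from typing import Optional, Tuple, Union

import jax
import jax.numpy as jnp


class _GlobalSplitter:
    """Hypothetical stand-in for the global jax_utils_rng instance."""

    def __init__(self, seed: int = 0):
        self.rng = jax.random.PRNGKey(seed)

    def __call__(self, num: Optional[int] = None) -> Union[jax.Array, Tuple[jax.Array, ...]]:
        if num is None:
            # One fresh key; the other half of the split becomes the new state.
            self.rng, out = jax.random.split(self.rng)
            return out
        parts = jax.random.split(self.rng, num + 1)
        self.rng = parts[0]
        return tuple(parts[1:])


# Module-level state, analogous to jax_utils_rng (assumption).
_global_rng = _GlobalSplitter(seed=0)


def next_rng_sketch(*args, **kwargs):
    # Forward every argument unchanged, as next_rng is documented to do.
    return _global_rng(*args, **kwargs)


@jax.jit
def add_noise(x: jax.Array, key: jax.Array) -> jax.Array:
    # The key arrives as an argument, so each call can use a fresh one.
    return x + jax.random.normal(key, x.shape)


x = jnp.zeros(3)
a = add_noise(x, next_rng_sketch())  # key drawn outside jit
b = add_noise(x, next_rng_sketch())  # a different key, so different noise
k1, k2 = next_rng_sketch(2)
```

Calling a stateful helper like this inside a jax.jit-compiled function would run it only while tracing, and storing the traced key back into the global object leaks a tracer; passing keys in explicitly avoids both problems.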