Source code for easydel.utils.rngs_utils
# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for JAX."""
from typing import Tuple, Union
from jax import random as jrandom
[docs]def next_rng(
*args,
**kwargs,
) -> Union[jrandom.PRNGKey, Tuple[jrandom.PRNGKey, ...], dict]:
"""Provides access to the global JaxRNG and splits the key based on arguments.
This function wraps the global `jax_utils_rng` instance and calls its `__call__` method,
passing through any arguments provided. This provides a convenient way to access and
split the global random number generator key.
Args:
*args: Positional arguments passed to the `jax_utils_rng` instance's `__call__` method.
**kwargs: Keyword arguments passed to the `jax_utils_rng` instance's `__call__` method.
Returns:
The split PRNGKey(s) from the global `jax_utils_rng` instance.
"""
global jax_utils_rng
return jax_utils_rng(*args, **kwargs)
[docs]class JaxRNG:
"""A wrapper around JAX's PRNGKey that simplifies key splitting."""
def __init__(self, rng: jrandom.PRNGKey):
"""Initializes the JaxRNG with a PRNGKey.
Args:
rng: A JAX PRNGKey.
"""
self.rng = rng
[docs] @classmethod
def from_seed(cls, seed: int) -> "JaxRNG":
"""Creates a JaxRNG instance from a seed.
Args:
seed: The seed to use for the random number generator.
Returns:
A JaxRNG instance.
"""
return cls(jrandom.PRNGKey(seed))
def __call__(
self, keys: Union[int, Tuple[str, ...]] = None
) -> Union[jrandom.PRNGKey, Tuple[jrandom.PRNGKey, ...], dict]:
"""Splits the internal PRNGKey and returns new keys.
Args:
keys: If None, returns a single split key and updates the internal RNG.
If an int, splits the key into `keys + 1` parts, updates the internal RNG,
and returns the last `keys` parts as a tuple.
If a tuple of strings, splits the key into `len(keys) + 1` parts,
updates the internal RNG, and returns a dictionary mapping the strings
to their corresponding key parts.
Returns:
The split PRNGKey(s) based on the `keys` argument.
"""
if keys is None:
self.rng, split_rng = jrandom.split(self.rng)
return split_rng
elif isinstance(keys, int):
split_rngs = jrandom.split(self.rng, num=keys + 1)
self.rng = split_rngs[0]
return tuple(split_rngs[1:])
else:
split_rngs = jrandom.split(self.rng, num=len(keys) + 1)
self.rng = split_rngs[0]
return {key: val for key, val in zip(keys, split_rngs[1:])} # noqa:B905
[docs]class GenerateRNG:
"""An infinite generator of JAX PRNGKeys, useful for iterating over seeds."""
def __init__(self, seed: int = 0):
"""Initializes the generator with a starting seed.
Args:
seed: The seed to use for the initial PRNGKey.
"""
self.seed = seed
self._rng = jrandom.PRNGKey(seed)
def __next__(self) -> jrandom.PRNGKey:
"""Generates and returns the next PRNGKey in the sequence.
Returns:
The next PRNGKey derived from the internal state.
"""
self._rng, key = jrandom.split(self._rng)
return key
@property
def rng(self) -> jrandom.PRNGKey:
"""Provides access to the next PRNGKey without advancing the generator.
Returns:
The next PRNGKey in the sequence.
"""
return next(self)