easydel.utils.__init__

class easydel.utils.__init__.DataClassArgumentParser(dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs: Any)

Bases: ArgumentParser

A subclass of argparse.ArgumentParser that automatically generates arguments based on dataclass type hints.

It supports additional argparse features (like sub-groups) and can also load configuration from dictionaries, JSON files, or YAML files.
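The mapping from dataclass fields to command-line flags can be sketched with the standard library alone. This is a minimal illustration, not EasyDeL's implementation: `build_parser`, `parse_into_dataclasses` and `TrainConfig` are hypothetical names, and only plain `int`, `float` and `str` fields are handled (`bool`, `Optional` and container types would need extra logic).

```python
import argparse
import dataclasses
from typing import Any, Sequence, Tuple


@dataclasses.dataclass
class TrainConfig:
    # Hypothetical dataclass; the field names are illustrative only.
    learning_rate: float = 3e-4
    epochs: int = 1
    model_name: str = "tiny"


def build_parser(*dataclass_types: type) -> argparse.ArgumentParser:
    """Register one --<field> flag per dataclass field, typed by its annotation."""
    parser = argparse.ArgumentParser()
    for dtype in dataclass_types:
        for field in dataclasses.fields(dtype):
            # Only plain int/float/str annotations work directly as argparse converters.
            kwargs: dict = {"type": field.type}
            if field.default is not dataclasses.MISSING:
                kwargs["default"] = field.default
            elif field.default_factory is not dataclasses.MISSING:
                kwargs["default"] = field.default_factory()
            else:
                kwargs["required"] = True
            parser.add_argument(f"--{field.name}", **kwargs)
    return parser


def parse_into_dataclasses(
    parser: argparse.ArgumentParser, dataclass_types: Sequence[type], argv: Sequence[str]
) -> Tuple[Any, ...]:
    """Parse argv, then split the flat namespace back into one instance per dataclass."""
    values = vars(parser.parse_args(argv))
    return tuple(
        dtype(**{f.name: values[f.name] for f in dataclasses.fields(dtype)})
        for dtype in dataclass_types
    )


parser = build_parser(TrainConfig)
(cfg,) = parse_into_dataclasses(parser, [TrainConfig], ["--learning_rate", "0.01", "--epochs", "3"])
```

Fields without a default become required flags, which mirrors how a dataclass without a default cannot be instantiated without that value.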

dataclass_types: Iterable[DataClassType]
parse_args_into_dataclasses(args: Optional[List[str]] = None, return_remaining_strings: bool = False, look_for_args_file: bool = True, args_filename: Optional[str] = None, args_file_flag: Optional[str] = None) → Tuple[Any, ...]

Parse command-line arguments into instances of the specified dataclass types.

Optionally, this method can also look for an external “.args” file or a command-line flag that points to one, and prepend its content to the command-line arguments.

Raises

ValueError – If there are any unknown arguments (and return_remaining_strings is False).
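The “.args” file behavior amounts to prepending extra tokens before normal parsing. A minimal sketch under two assumptions: the file holds whitespace-separated arguments, and its path is passed in directly (the automatic lookup controlled by look_for_args_file and args_file_flag is not reproduced). `read_args_file` and `parse_with_args_file` are hypothetical helpers.

```python
import argparse
import os
import tempfile
from pathlib import Path
from typing import List, Optional


def read_args_file(args_filename: Optional[str]) -> List[str]:
    """Return the whitespace-separated tokens of an args file, or [] if there is none."""
    if args_filename is None or not Path(args_filename).exists():
        return []
    return Path(args_filename).read_text().split()


def parse_with_args_file(
    parser: argparse.ArgumentParser, argv: List[str], args_filename: Optional[str] = None
) -> argparse.Namespace:
    # File tokens go first, so flags repeated on the command line override them.
    full_argv = read_args_file(args_filename) + list(argv)
    namespace, remaining = parser.parse_known_args(full_argv)
    if remaining:
        # Same exception type as documented when return_remaining_strings is False.
        raise ValueError(f"Unknown arguments: {remaining}")
    return namespace


parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=1)

with tempfile.TemporaryDirectory() as tmp:
    args_path = os.path.join(tmp, "train.args")
    Path(args_path).write_text("--epochs 5\n")
    from_file = parse_with_args_file(parser, [], args_path)
    overridden = parse_with_args_file(parser, ["--epochs", "7"], args_path)
```

The override works because argparse's default `store` action keeps the last value seen for a repeated flag.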

parse_dict(args: Dict[str, Any], allow_extra_keys: bool = False) → Tuple[Any, ...]

Parse a dictionary of configuration values into dataclass instances.

Parameters
  • args – Dictionary containing configuration values.

  • allow_extra_keys – If False, raises an exception if unknown keys are present.
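A sketch of how one flat dictionary can be routed to several dataclasses, with allow_extra_keys deciding whether leftover keys are an error. The function and dataclass names are hypothetical, and raising `ValueError` specifically is an assumption (the text above only says an exception is raised).

```python
import dataclasses
from typing import Any, Dict, Sequence, Tuple


@dataclasses.dataclass
class ModelArgs:
    # Hypothetical dataclasses used only for this illustration.
    hidden_size: int = 128
    num_layers: int = 2


@dataclasses.dataclass
class OptimArgs:
    learning_rate: float = 1e-3


def parse_dict_sketch(
    dataclass_types: Sequence[type], args: Dict[str, Any], allow_extra_keys: bool = False
) -> Tuple[Any, ...]:
    unused = set(args)
    outputs = []
    for dtype in dataclass_types:
        names = {f.name for f in dataclasses.fields(dtype) if f.init}
        # Each dataclass receives only the keys that match its own fields.
        outputs.append(dtype(**{k: v for k, v in args.items() if k in names}))
        unused -= names
    if unused and not allow_extra_keys:
        raise ValueError(f"Unknown keys: {sorted(unused)}")
    return tuple(outputs)


model_args, optim_args = parse_dict_sketch(
    [ModelArgs, OptimArgs], {"hidden_size": 256, "learning_rate": 3e-4}
)
```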

parse_json_file(json_file: Union[str, PathLike], allow_extra_keys: bool = False) → Tuple[Any, ...]

Load a JSON file and parse it into dataclass instances.

parse_yaml_file(yaml_file: Union[str, PathLike], allow_extra_keys: bool = False) → Tuple[Any, ...]

Load a YAML file and parse it into dataclass instances.
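Assuming the file loaders read the file into a dictionary and then validate keys the same way as parse_dict, the JSON case looks roughly like this. `parse_json_file_sketch` and `DataArgs` are hypothetical; a YAML variant would swap `json.loads` for a YAML parser such as PyYAML's `yaml.safe_load`, which is not in the standard library and is omitted here.

```python
import dataclasses
import json
import os
import tempfile
from pathlib import Path


@dataclasses.dataclass
class DataArgs:
    # Hypothetical dataclass for illustration.
    dataset_name: str = "demo"
    max_length: int = 512


def parse_json_file_sketch(dataclass_type, json_file, allow_extra_keys=False):
    # Load the mapping first; validation then matches the dictionary case.
    data = json.loads(Path(json_file).read_text())
    names = {f.name for f in dataclasses.fields(dataclass_type)}
    extra = set(data) - names
    if extra and not allow_extra_keys:
        raise ValueError(f"Unknown keys in {json_file}: {sorted(extra)}")
    return (dataclass_type(**{k: v for k, v in data.items() if k in names}),)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "data.json")
    Path(path).write_text(json.dumps({"max_length": 1024}))
    (data_args,) = parse_json_file_sketch(DataArgs, path)
```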

class easydel.utils.__init__.GenerateRNG(seed: int = 0)

Bases: object

An infinite generator of JAX PRNGKeys, useful for iterating over seeds.

property rng: PRNGKey

Provides access to the next PRNGKey without advancing the generator.

Returns

The next PRNGKey in the sequence.
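A generator of this kind is typically built by carrying one key and splitting it on every draw. The sketch below assumes that design; `KeyStream` is a hypothetical class, not GenerateRNG itself.

```python
import jax


class KeyStream:
    """Hypothetical stand-in for GenerateRNG: an endless iterator of fresh keys."""

    def __init__(self, seed: int = 0):
        self._rng = jax.random.PRNGKey(seed)

    def __iter__(self):
        return self

    def __next__(self):
        # Keep one half of the split as the carried state and hand out the other.
        self._rng, key = jax.random.split(self._rng)
        return key


stream = KeyStream(seed=0)
k1 = next(stream)
k2 = next(stream)
x = jax.random.normal(k1, (3,))
```

Because each draw is derived deterministically from the seed, two streams with the same seed yield the same sequence of keys.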

class easydel.utils.__init__.JaxRNG(rng: PRNGKey)

Bases: object

A wrapper around JAX’s PRNGKey that simplifies key splitting.

classmethod from_seed(seed: int) → JaxRNG

Creates a JaxRNG instance from a seed.

Parameters

seed – The seed to use for the random number generator.

Returns

A JaxRNG instance.
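Key-splitting wrappers usually hide the bookkeeping of replacing the carried key after every `jax.random.split`. A minimal sketch of that pattern, assuming a call-to-split interface; `KeySplitter` and its `__call__(num)` signature are hypothetical and may differ from JaxRNG's actual methods.

```python
import jax


class KeySplitter:
    """Hypothetical wrapper in the spirit of JaxRNG: each call returns fresh keys."""

    def __init__(self, rng):
        self.rng = rng

    @classmethod
    def from_seed(cls, seed: int) -> "KeySplitter":
        return cls(jax.random.PRNGKey(seed))

    def __call__(self, num: int = 1):
        # Split into num + 1 keys; the first replaces the internal state.
        keys = jax.random.split(self.rng, num + 1)
        self.rng = keys[0]
        return keys[1] if num == 1 else tuple(keys[1:])


splitter = KeySplitter.from_seed(42)
dropout_key = splitter()
init_key, noise_key = splitter(2)
```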

class easydel.utils.__init__.LazyModule(name: str, module_file: str, import_structure: Dict[FrozenSet[str], Dict[str, Set[str]]], module_spec: ModuleSpec = None, extra_objects: Dict[str, object] = None)

Bases: module

class easydel.utils.__init__.Timer(name)

Bases: object

elapsed_time(reset=True)
reset()
start()
stop()

class easydel.utils.__init__.Timers(use_wandb, tensorboard_writer: Any)

Bases: object

log(names, normalizer=1.0, reset=True)
timed(name, log=True, reset=True)
write(names, iteration, normalizer=1.0, reset=False)

easydel.utils.__init__.cache_compiles(tag: Optional[str] = None, static_argnames: Optional[List[str]] = None)

easydel.utils.__init__.capture_time()

A context manager that measures elapsed time.

Returns

Callable that returns the current elapsed time while the context is active, or the final elapsed time after the context exits.

Example

    with capture_time() as get_time:
        # Do some work
        print(f"Current time: {get_time()}")

    print(f"Final time: {get_time()}")
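One way to get the documented behavior (a live reading inside the block, a frozen one afterwards) is a generator-based context manager that records the exit time. `capture_time_sketch` is a hypothetical name; EasyDeL's own implementation may differ.

```python
import time
from contextlib import contextmanager


@contextmanager
def capture_time_sketch():
    """Yield a callable reporting elapsed seconds; it stops advancing once the block exits."""
    start = time.perf_counter()
    end = None

    def get_time() -> float:
        # Inside the block, measure up to "now"; after exit, up to the recorded end.
        return (end if end is not None else time.perf_counter()) - start

    try:
        yield get_time
    finally:
        end = time.perf_counter()


with capture_time_sketch() as get_time:
    time.sleep(0.01)
    inside = get_time()

final = get_time()
time.sleep(0.01)
after_pause = get_time()
```

`time.perf_counter` is used because it is monotonic and has the highest available resolution for measuring short intervals.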

easydel.utils.__init__.check_bool_flag(name: str, default: bool = True) → bool
easydel.utils.__init__.cjit(fn: Callable, static_argnums: Optional[Tuple[int]] = None, static_argnames: Optional[Tuple[str]] = None)
easydel.utils.__init__.compile_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)

Compiles a JAX function with optional sharding and mesh configuration.

Parameters
  • func – The JAX function to compile.

  • func_input_args – Input arguments for the function.

  • func_input_kwargs – Input keyword arguments for the function.

  • mesh – Optional JAX mesh for distributed execution.

  • in_shardings – Optional input sharding specifications.

  • out_shardings – Optional output sharding specifications.

  • static_argnums – Indices of static arguments.

  • donate_argnums – Indices of arguments to donate.

Returns

Compiled JAX function.
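Ahead-of-time compilation in JAX goes through `jax.jit(...).lower(...).compile()`. The sketch below assumes compile_function wraps that path and forwards the sharding and argument options to `jax.jit`; `compile_function_sketch` is hypothetical, and it assumes `mesh` is a `jax.sharding.Mesh`, which can be entered as a context manager.

```python
import contextlib

import jax
import jax.numpy as jnp


def compile_function_sketch(func, func_input_args, func_input_kwargs, mesh=None,
                            in_shardings=None, out_shardings=None,
                            static_argnums=None, donate_argnums=None):
    # Forward only the options that were given, so jax.jit keeps its own defaults.
    options = {
        "in_shardings": in_shardings,
        "out_shardings": out_shardings,
        "static_argnums": static_argnums,
        "donate_argnums": donate_argnums,
    }
    jit_kwargs = {name: value for name, value in options.items() if value is not None}
    # Lower and compile inside the mesh context when one is supplied.
    context = mesh if mesh is not None else contextlib.nullcontext()
    with context:
        lowered = jax.jit(func, **jit_kwargs).lower(*func_input_args, **func_input_kwargs)
        return lowered.compile()


def scaled_sum(x, scale):
    return jnp.sum(x) * scale


compiled = compile_function_sketch(scaled_sum, (jnp.ones(4), jnp.float32(2.0)), {})
result = compiled(jnp.arange(4.0), jnp.float32(2.0))
```

The returned executable is specialized to the shapes and dtypes seen during lowering, so calling it with differently shaped inputs raises an error instead of triggering a recompilation.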

easydel.utils.__init__.get_cache_dir() → Path

easydel.utils.__init__.get_logger(name: str, level: Optional[int] = None) → LazyLogger

Creates a lazy logger that is initialized only when it is first used.

Parameters
  • name (str) – The name of the logger.

  • level (Optional[int]) – The logging level. Defaults to the value of the LOGGING_LEVEL_ED environment variable, or “INFO” if that variable is not set.

Returns

A lazy logger instance that initializes on first use.

Return type

LazyLogger
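Lazy initialization can be sketched by deferring the `logging.getLogger` call until the first attribute lookup. `LazyLoggerSketch` is a hypothetical stand-in for LazyLogger; the level resolution order (explicit argument, then LOGGING_LEVEL_ED, then “INFO”) follows the parameter description above.

```python
import logging
import os
from typing import Optional


class LazyLoggerSketch:
    """Hypothetical lazy logger: the underlying logging.Logger is configured on first use."""

    def __init__(self, name: str, level: Optional[int] = None):
        self._name = name
        self._level = level
        self._logger: Optional[logging.Logger] = None

    def _ensure_logger(self) -> logging.Logger:
        if self._logger is None:
            # Explicit level first, then the LOGGING_LEVEL_ED variable, then "INFO".
            level = self._level if self._level is not None else os.environ.get("LOGGING_LEVEL_ED", "INFO").upper()
            logger = logging.getLogger(self._name)
            logger.setLevel(level)  # setLevel accepts ints or level names such as "INFO"
            self._logger = logger
        return self._logger

    def __getattr__(self, attr):
        # Only reached for attributes not found normally, e.g. info, warning, debug.
        return getattr(self._ensure_logger(), attr)


log = LazyLoggerSketch("easydel.demo")
initialized_before = log._logger is not None
log.info("first use configures the underlying logger")
initialized_after = log._logger is not None
```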

easydel.utils.__init__.is_package_available(package_name: str) → bool

Checks if a package is available in the current Python environment.

Parameters

package_name – The name of the package to check (e.g., “numpy”).

Returns

True if the package is available, False otherwise.
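A common way to perform this check without importing the package is `importlib.util.find_spec`. This is an assumption about the approach, not necessarily how EasyDeL implements it; `is_package_available_sketch` is a hypothetical name.

```python
import importlib.util


def is_package_available_sketch(package_name: str) -> bool:
    # find_spec locates a top-level package on the import path without importing it.
    return importlib.util.find_spec(package_name) is not None


json_available = is_package_available_sketch("json")
missing_available = is_package_available_sketch("surely_not_a_real_package_xyz")
```

A check like this is typically used to guard optional imports, so that features depending on an extra package are enabled only when it is installed.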

easydel.utils.__init__.load_compiled_fn(path: Union[str, PathLike], prefix: Optional[str] = None)

easydel.utils.__init__.save_compiled_fn(path: Union[str, PathLike], fn: Any, prefix: Optional[str] = None)