easydel.utils.__init__
- class easydel.utils.__init__.DataClassArgumentParser(dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs: Any)
Bases: ArgumentParser
A subclass of argparse.ArgumentParser that automatically generates arguments based on dataclass type hints.
It supports additional argparse features (like sub-groups) and can also load configuration from dictionaries, JSON files, or YAML files.
- dataclass_types: Iterable[DataClassType]
- parse_args_into_dataclasses(args: Optional[List[str]] = None, return_remaining_strings: bool = False, look_for_args_file: bool = True, args_filename: Optional[str] = None, args_file_flag: Optional[str] = None) -> Tuple[Any, ...]
Parse command-line arguments into instances of the specified dataclass types.
Optionally, this method can also look for an external “.args” file or a command-line flag that points to one, and prepend its content to the command-line arguments.
- Raises
ValueError – If there are any unknown arguments (and return_remaining_strings is False).
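The core idea behind parse_args_into_dataclasses can be sketched with the standard library alone: register one --<field_name> flag per dataclass field, parse, then route each value back to the dataclass that declared it. This is a minimal illustration, not EasyDeL's implementation. TrainArgs, ModelArgs, build_parser, and parse_into_dataclasses are hypothetical names, and the sketch assumes plain float/int/str/bool fields with simple defaults, unique field names across dataclasses, and no ".args" file handling.

```python
import argparse
import dataclasses
from typing import List, Sequence, Tuple, Type


@dataclasses.dataclass
class TrainArgs:  # hypothetical example dataclass
    learning_rate: float = 3e-4
    num_epochs: int = 1


@dataclasses.dataclass
class ModelArgs:  # hypothetical example dataclass
    model_name: str = "tiny"
    use_bias: bool = False


def build_parser(types_: Sequence[Type]) -> argparse.ArgumentParser:
    """Add one --<field_name> flag per dataclass field."""
    parser = argparse.ArgumentParser()
    for dc in types_:
        for field in dataclasses.fields(dc):
            if field.type is bool:
                # Booleans become on/off switches rather than typed values.
                parser.add_argument(f"--{field.name}", action="store_true", default=field.default)
            else:
                # The annotation (float, int, str) doubles as the string converter.
                parser.add_argument(f"--{field.name}", type=field.type, default=field.default)
    return parser


def parse_into_dataclasses(types_: Sequence[Type], argv: List[str]) -> Tuple:
    parser = build_parser(types_)
    namespace, remaining = parser.parse_known_args(argv)
    if remaining:
        # Mirrors the documented ValueError for unknown arguments.
        raise ValueError(f"Unknown arguments: {remaining}")
    values = vars(namespace)
    # Route each parsed value back to the dataclass that declared the field.
    return tuple(dc(**{f.name: values[f.name] for f in dataclasses.fields(dc)}) for dc in types_)


train_args, model_args = parse_into_dataclasses(
    (TrainArgs, ModelArgs), ["--learning_rate", "1e-3", "--use_bias"]
)
```

Deriving flags from the dataclass fields keeps the dataclass as the single source of truth for names, types, and defaults.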
- parse_dict(args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[Any, ...]
Parse a dictionary of configuration values into dataclass instances.
- Parameters
args – Dictionary containing configuration values.
allow_extra_keys – If False (the default), raise an exception when args contains keys that do not match any dataclass field.
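A comparable sketch of the parse_dict behavior described above: each dataclass takes the keys that match its init fields, and leftover keys raise an error unless allow_extra_keys is set. dict_to_dataclasses, TrainArgs, and ModelArgs are hypothetical names, not EasyDeL API, and unlike the real method this sketch performs no value validation or type conversion.

```python
import dataclasses
from typing import Any, Dict, Sequence, Tuple, Type


@dataclasses.dataclass
class TrainArgs:  # hypothetical example dataclass
    learning_rate: float = 3e-4
    num_epochs: int = 1


@dataclasses.dataclass
class ModelArgs:  # hypothetical example dataclass
    model_name: str = "tiny"


def dict_to_dataclasses(
    types_: Sequence[Type], config: Dict[str, Any], allow_extra_keys: bool = False
) -> Tuple:
    unused = set(config)
    outputs = []
    for dc in types_:
        # Only fields that appear in the generated __init__ can be passed as kwargs.
        names = {f.name for f in dataclasses.fields(dc) if f.init}
        kwargs = {k: v for k, v in config.items() if k in names}
        unused.difference_update(kwargs)
        outputs.append(dc(**kwargs))
    if unused and not allow_extra_keys:
        raise ValueError(f"Keys not used by any dataclass: {sorted(unused)}")
    return tuple(outputs)


train_args, model_args = dict_to_dataclasses(
    (TrainArgs, ModelArgs), {"learning_rate": 0.01, "model_name": "big"}
)
```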
- class easydel.utils.__init__.GenerateRNG(seed: int = 0)
Bases: object
An infinite generator of JAX PRNGKeys, useful for iterating over seeds.
- property rng: PRNGKey
Provides access to the next PRNGKey without advancing the generator.
- Returns
The next PRNGKey in the sequence.
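An infinite key generator of this kind is usually built on jax.random.split: every step splits the internal key, keeps one half as the new state, and hands out the other. KeyStream below is a hypothetical stand-in, a minimal sketch assuming that split-per-step pattern; it does not model the rng property's non-advancing behavior described above.

```python
import jax


class KeyStream:
    """Hypothetical stand-in for GenerateRNG: an endless iterator of PRNG keys."""

    def __init__(self, seed: int = 0):
        self._key = jax.random.PRNGKey(seed)

    def __iter__(self):
        return self

    def __next__(self):
        # Split the state: one half becomes the new state, the other is handed out.
        self._key, subkey = jax.random.split(self._key)
        return subkey


stream = KeyStream(seed=42)
init_key = next(stream)
dropout_key = next(stream)
noise = jax.random.normal(dropout_key, (3,))
```

Because the stream is seeded, two streams built with the same seed yield the same key sequence, while consecutive keys from one stream differ.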
- class easydel.utils.__init__.JaxRNG(rng: PRNGKey)
Bases: object
A wrapper around JAX’s PRNGKey that simplifies key splitting.
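One way such a splitting wrapper can work is to store a key and, on each call, split it into a fresh state plus the requested subkeys. KeySplitter and its call convention (no argument for one key, an int for a tuple of keys, a list of names for a dict of keys) are hypothetical, chosen for illustration rather than taken from this page.

```python
import jax


class KeySplitter:
    """Hypothetical wrapper in the spirit of JaxRNG: hands out fresh subkeys on each call."""

    def __init__(self, key):
        self.key = key

    @classmethod
    def from_seed(cls, seed: int) -> "KeySplitter":
        return cls(jax.random.PRNGKey(seed))

    def __call__(self, names=None):
        if names is None:
            # Single subkey; the other half of the split replaces the stored key.
            self.key, subkey = jax.random.split(self.key)
            return subkey
        count = names if isinstance(names, int) else len(names)
        keys = jax.random.split(self.key, count + 1)
        self.key = keys[0]  # keep one key as the new internal state
        subkeys = tuple(keys[1:])
        if isinstance(names, int):
            return subkeys  # a tuple of `count` keys
        return dict(zip(names, subkeys))  # one key per requested name


rng = KeySplitter.from_seed(0)
params_key = rng()
keys_by_name = rng(["params", "dropout"])
three_keys = rng(3)
```

Named keys are convenient for Flax-style init/apply calls that expect a dict of RNG streams.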
- class easydel.utils.__init__.LazyModule(name: str, module_file: str, import_structure: Dict[FrozenSet[str], Dict[str, Set[str]]], module_spec: ModuleSpec = None, extra_objects: Dict[str, object] = None)
Bases: module
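Judging by its name, its import_structure parameter, and its module base class, LazyModule appears to follow the common lazy-import pattern: a module object that imports a submodule only when one of its attributes is first accessed. The sketch below is an assumption-based illustration of that pattern only. MiniLazyModule is hypothetical, it takes a simplified {module_name: [attribute, ...]} mapping instead of the FrozenSet-keyed structure in the real signature, and json and math stand in for real submodules.

```python
import importlib
import types
from typing import Dict, List


class MiniLazyModule(types.ModuleType):
    """Hypothetical, simplified lazy module: each attribute's defining module is
    imported only when that attribute is first accessed."""

    def __init__(self, name: str, import_structure: Dict[str, List[str]]):
        super().__init__(name)
        # Invert {module: [attrs]} into {attr: module} for direct lookups.
        self._attr_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }
        self.__all__ = sorted(self._attr_to_module)

    def __getattr__(self, name: str):
        # Called only when normal lookup fails, i.e. the attribute is not loaded yet.
        mapping = self.__dict__.get("_attr_to_module", {})
        if name not in mapping:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        value = getattr(importlib.import_module(mapping[name]), name)
        setattr(self, name, value)  # cache it so later lookups bypass __getattr__
        return value


lazy = MiniLazyModule("lazy_demo", {"json": ["dumps", "loads"], "math": ["sqrt"]})
root = lazy.sqrt(16.0)  # math is imported (or fetched from sys.modules) here
```

Caching the resolved attribute with setattr means __getattr__ runs at most once per name.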
- easydel.utils.__init__.cache_compiles(tag: Optional[str] = None, static_argnames: Optional[List[str]] = None)
- easydel.utils.__init__.capture_time()
A context manager that measures elapsed time.
- Returns
Callable that returns the current elapsed time while the context is active, or the final elapsed time after the context exits.
Example
    with capture_time() as get_time:
        # Do some work
        print(f"Current time: {get_time()}")
    print(f"Final time: {get_time()}")
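The behavior described above (a live reading inside the block, a frozen reading afterwards) can be sketched with contextlib.contextmanager and a closure. elapsed_timer is a hypothetical name, and using time.perf_counter as the clock is an assumption; capture_time may measure time differently.

```python
import time
from contextlib import contextmanager


@contextmanager
def elapsed_timer():
    """Hypothetical sketch of capture_time: yields a callable reporting elapsed seconds."""
    start = time.perf_counter()
    end = None

    def get_elapsed() -> float:
        # Live reading while the block runs; frozen value once it has exited.
        return (end if end is not None else time.perf_counter()) - start

    try:
        yield get_elapsed
    finally:
        end = time.perf_counter()


with elapsed_timer() as get_time:
    sum(range(100_000))  # some work
    partial = get_time()
total = get_time()
```

Setting end in the finally clause freezes the reading even if the block raises.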
- easydel.utils.__init__.cjit(fn: Callable, static_argnums: Optional[Tuple[int]] = None, static_argnames: Optional[Tuple[str]] = None)
- easydel.utils.__init__.compile_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)
Compiles a JAX function with optional sharding and mesh configuration.
- Parameters
func – The JAX function to compile.
func_input_args – Input arguments for the function.
func_input_kwargs – Input keyword arguments for the function.
mesh – Optional JAX mesh for distributed execution.
in_shardings – Optional input sharding specifications.
out_shardings – Optional output sharding specifications.
static_argnums – Indices of static arguments.
donate_argnums – Indices of arguments to donate.
- Returns
Compiled JAX function.
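JAX's public ahead-of-time path is jax.jit(func).lower(*args).compile(), which traces the function for concrete input shapes and dtypes and returns a compiled executable. Whether compile_function uses exactly this path is an assumption. compile_ahead_of_time and affine are hypothetical names, and the mesh and sharding arguments are omitted so the sketch runs on a single CPU device.

```python
import jax
import jax.numpy as jnp


def compile_ahead_of_time(func, args, kwargs=None, static_argnums=None, donate_argnums=None):
    """Hypothetical helper: trace, lower, and compile `func` for concrete example inputs."""
    kwargs = kwargs or {}
    jit_kwargs = {}
    if static_argnums is not None:
        jit_kwargs["static_argnums"] = static_argnums
    if donate_argnums is not None:
        jit_kwargs["donate_argnums"] = donate_argnums
    jitted = jax.jit(func, **jit_kwargs)
    # lower() traces func for these input shapes/dtypes; compile() runs XLA on the result.
    return jitted.lower(*args, **kwargs).compile()


def affine(x, w, b):
    return x @ w + b


x = jnp.ones((4, 3))
w = jnp.ones((3, 2))
b = jnp.zeros((2,))
compiled = compile_ahead_of_time(affine, (x, w, b))
out = compiled(x, w, b)  # inputs must match the shapes/dtypes used at compile time
```

Compiling ahead of time moves the XLA compilation cost out of the first training or inference step.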
- easydel.utils.__init__.get_logger(name: str, level: Optional[int] = None) -> LazyLogger
Create a lazy logger that is initialized only when it is first used.
- Parameters
name (str) – The name of the logger.
level (Optional[int]) – The logging level. Defaults to the value of the LOGGING_LEVEL_ED environment variable, or “INFO” if it is unset.
- Returns
A lazy logger instance that initializes on first use.
- Return type
LazyLogger
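A lazy logger can be sketched as a small proxy that defers logging.getLogger and setLevel until the first attribute access, falling back to LOGGING_LEVEL_ED and then "INFO" as documented above. LazyLoggerSketch is a hypothetical name, not EasyDeL's LazyLogger, and it assumes the environment variable holds a standard level name such as "INFO" or "error".

```python
import logging
import os
from typing import Optional


class LazyLoggerSketch:
    """Hypothetical proxy: builds the real logging.Logger on first use."""

    def __init__(self, name: str, level: Optional[int] = None):
        self._name = name
        self._level = level
        self._logger: Optional[logging.Logger] = None

    def _ensure_logger(self) -> logging.Logger:
        if self._logger is None:
            level = self._level
            if level is None:
                # Fall back to LOGGING_LEVEL_ED, then "INFO"; setLevel accepts level names.
                level = os.environ.get("LOGGING_LEVEL_ED", "INFO").upper()
            logger = logging.getLogger(self._name)
            logger.setLevel(level)
            self._logger = logger
        return self._logger

    def __getattr__(self, attr):
        # Reached only for attributes not set in __init__, e.g. .info or .warning.
        return getattr(self._ensure_logger(), attr)


logger = LazyLoggerSketch("my_app.training")
logger.info("starting")  # the underlying logging.Logger is created here
```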
- easydel.utils.__init__.is_package_available(package_name: str) -> bool
Checks if a package is available in the current Python environment.
- Parameters
package_name – The name of the package to check (e.g., “numpy”).
- Returns
True if the package is available, False otherwise.
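A common way to implement this check is importlib.util.find_spec, which locates a top-level module without importing it. Whether is_package_available uses exactly this call is an assumption, and package_available is a hypothetical name. The check works on import names, which can differ from PyPI distribution names (scikit-learn, for example, is imported as sklearn).

```python
import importlib.util


def package_available(package_name: str) -> bool:
    """Hypothetical sketch: True if a top-level module with this import name can be found."""
    # find_spec returns None when no finder can locate the module; nothing is imported.
    return importlib.util.find_spec(package_name) is not None


has_json = package_available("json")
```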