easydel.utils.__init__

class easydel.utils.__init__.DataClassArgumentParser(dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs: Any)

Bases: ArgumentParser

A subclass of argparse.ArgumentParser that automatically generates arguments based on dataclass type hints.

It supports additional argparse features (like sub-groups) and can also load configuration from dictionaries, JSON files, or YAML files.
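As a rough illustration of the idea (not EasyDeL's implementation), the standard-library sketch below turns each dataclass field into a --<field> option typed by its annotation, then routes the parsed values back into dataclass instances. TrainArgs, build_parser, and parse_into are hypothetical names, and the sketch assumes plain int, float, and str fields.

```python
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Any, Sequence, Tuple


@dataclass
class TrainArgs:
    # Hypothetical example fields, for illustration only.
    learning_rate: float = 3e-4
    epochs: int = 1
    model_name: str = "tiny"


def build_parser(dataclass_types: Sequence[type]) -> argparse.ArgumentParser:
    """Add one --<field> option per dataclass field, typed by its annotation."""
    parser = argparse.ArgumentParser()
    for dtype in dataclass_types:
        for f in dataclasses.fields(dtype):
            if f.default is not dataclasses.MISSING:
                parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
            else:
                # Fields without a default become required options.
                parser.add_argument(f"--{f.name}", type=f.type, required=True)
    return parser


def parse_into(dataclass_types: Sequence[type], argv: Sequence[str]) -> Tuple[Any, ...]:
    """Parse argv, then route each value to the dataclass that declares it."""
    values = vars(build_parser(dataclass_types).parse_args(list(argv)))
    instances = []
    for dtype in dataclass_types:
        names = {f.name for f in dataclasses.fields(dtype)}
        instances.append(dtype(**{k: v for k, v in values.items() if k in names}))
    return tuple(instances)


(train_args,) = parse_into([TrainArgs], ["--epochs", "3"])
```

Boolean fields would need extra handling in a sketch like this, because type=bool converts any non-empty string, including "False", to True.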

dataclass_types: Iterable[DataClassType]
parse_args_into_dataclasses(args: Optional[List[str]] = None, return_remaining_strings: bool = False, look_for_args_file: bool = True, args_filename: Optional[str] = None, args_file_flag: Optional[str] = None) → Tuple[Any, ...]

Parse command-line arguments into instances of the specified dataclass types.

Optionally, this method can also look for an external “.args” file, or a command-line flag that points to one, and prepend its contents to the command-line arguments.

Raises

ValueError – If unknown arguments remain after parsing and return_remaining_strings is False.
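The .args behaviour described above can be approximated as follows. expand_args_file is a hypothetical helper, not part of EasyDeL; it assumes the file holds shell-style arguments and that the flag is immediately followed by the file path. Because the file contents are prepended, a flag repeated on the real command line overrides the file's value under argparse's default store action.

```python
import shlex
import tempfile
from pathlib import Path
from typing import List, Optional, Sequence


def expand_args_file(argv: Sequence[str], args_file_flag: Optional[str] = None,
                     default_args_file: Optional[str] = None) -> List[str]:
    """Hypothetical helper: prepend arguments read from .args files to argv."""
    files: List[Path] = []
    if default_args_file is not None and Path(default_args_file).exists():
        files.append(Path(default_args_file))
    remaining = list(argv)
    if args_file_flag is not None and args_file_flag in remaining:
        # Assumption: the flag is immediately followed by the file path.
        idx = remaining.index(args_file_flag)
        files.append(Path(remaining[idx + 1]))
        del remaining[idx:idx + 2]
    file_args: List[str] = []
    for path in files:
        file_args.extend(shlex.split(path.read_text()))
    # File arguments go first, so explicit command-line flags parsed later win.
    return file_args + remaining


with tempfile.TemporaryDirectory() as tmp:
    args_path = Path(tmp) / "train.args"
    args_path.write_text("--epochs 5 --model_name big")
    argv = expand_args_file(["--args_file", str(args_path), "--epochs", "7"],
                            args_file_flag="--args_file")
```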

parse_dict(args: Dict[str, Any], allow_extra_keys: bool = False) → Tuple[Any, ...]

Parse a dictionary of configuration values into dataclass instances.

Parameters
  • args – Dictionary containing configuration values.

  • allow_extra_keys – If False (the default), an exception is raised when args contains keys that do not correspond to any dataclass field.
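A minimal sketch of dictionary parsing with an extra-key check follows. It assumes each key maps to a field of exactly one dataclass and that the unspecified exception is a ValueError; ModelArgs, OptimArgs, and dict_to_dataclasses are illustrative names only, not EasyDeL API.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, Sequence, Tuple


@dataclass
class ModelArgs:
    # Hypothetical example dataclass.
    hidden_size: int = 64


@dataclass
class OptimArgs:
    # Hypothetical example dataclass.
    lr: float = 1e-3


def dict_to_dataclasses(args: Dict[str, Any], dataclass_types: Sequence[type],
                        allow_extra_keys: bool = False) -> Tuple[Any, ...]:
    """Build one instance per dataclass from the keys that match its fields."""
    unused = set(args)
    outputs = []
    for dtype in dataclass_types:
        names = {f.name for f in dataclasses.fields(dtype) if f.init}
        outputs.append(dtype(**{k: v for k, v in args.items() if k in names}))
        unused -= names
    if unused and not allow_extra_keys:
        # Assumption: a ValueError stands in for the documented "exception".
        raise ValueError(f"Keys not used by any dataclass: {sorted(unused)}")
    return tuple(outputs)


model_args, optim_args = dict_to_dataclasses({"hidden_size": 128, "lr": 0.1},
                                             [ModelArgs, OptimArgs])
```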

parse_json_file(json_file: Union[str, PathLike], allow_extra_keys: bool = False) → Tuple[Any, ...]

Load a JSON file and parse it into dataclass instances.

parse_yaml_file(yaml_file: Union[str, PathLike], allow_extra_keys: bool = False) → Tuple[Any, ...]

Load a YAML file and parse it into dataclass instances.

class easydel.utils.__init__.EasyQuantizer(quantization_method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NF4, quantization_platform: Optional[EasyDeLPlatforms] = EasyDeLPlatforms.JAX, quantization_pattern: Optional[str] = None, block_size: int = 256, **kwargs)

Bases: object

quantize_linears(model: Module, /, *, quantization_pattern: Optional[str] = None, verbose: bool = True) → Module

Quantize the model's linear-layer parameters to the requested precision, leaving layers that do not match quantization_pattern unchanged.

Parameters
  • model – The model to quantize.

  • quantization_pattern (str) – Regular expression matching the names of the layers to quantize.

  • verbose (bool) – Whether to display a tqdm progress bar while quantizing.

Returns

The model with its selected parameters quantized, in the same structure as the input.
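To make quantization_pattern and block_size concrete, the sketch below selects parameters by regular expression and quantizes them block by block. It deliberately swaps in simple absmax int8 quantization over plain Python lists instead of EasyDeL's default NF4, and the flat name-to-values dictionary is a stand-in for a real model; all names are hypothetical.

```python
import re
from typing import Any, Dict, List, Tuple


def absmax_int8_blocks(values: List[float], block_size: int) -> Tuple[List[int], List[float]]:
    """Quantize values to the int8 range with one absmax scale per block."""
    quantized: List[int] = []
    scales: List[float] = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        # An all-zero block gets scale 1.0 to avoid dividing by zero.
        scale = max(abs(v) for v in block) / 127 or 1.0
        scales.append(scale)
        quantized.extend(round(v / scale) for v in block)
    return quantized, scales


def quantize_matching(params: Dict[str, List[float]], pattern: str,
                      block_size: int = 256) -> Dict[str, Any]:
    """Hypothetical helper: quantize only parameters whose names match pattern."""
    regex = re.compile(pattern)
    return {
        name: absmax_int8_blocks(values, block_size) if regex.search(name) else values
        for name, values in params.items()
    }


params = {"layers.0.q_proj.kernel": [0.5, -1.0, 0.25, 2.0], "embed.embedding": [0.1, 0.2]}
quantized = quantize_matching(params, r"q_proj", block_size=2)
```

Each value can be approximately recovered as its integer code times the scale of its block; a smaller block_size tracks local magnitudes more closely at the cost of storing more scales.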

class easydel.utils.__init__.GenerateRNG(seed: int = 0)

Bases: object

An infinite generator of JAX PRNGKeys, useful for iterating over seeds.

property rng: PRNGKey

Provides access to the next PRNGKey without advancing the generator.

Returns

The next PRNGKey in the sequence.
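The peek-without-advancing behaviour described above can be sketched with jax.random.split. KeyStream is a hypothetical name, and the assumption that rng returns exactly the key the next iteration yields is taken from the docstring, not from EasyDeL's source.

```python
import jax


class KeyStream:
    """Hypothetical stand-in for GenerateRNG: an endless stream of PRNG keys."""

    def __init__(self, seed: int = 0):
        self._key = jax.random.PRNGKey(seed)

    @property
    def rng(self) -> jax.Array:
        # Compute the key that next() would return, without storing new state.
        return jax.random.split(self._key)[1]

    def __iter__(self) -> "KeyStream":
        return self

    def __next__(self) -> jax.Array:
        # Advance the internal state and hand out a fresh subkey.
        self._key, subkey = jax.random.split(self._key)
        return subkey


stream = KeyStream(0)
peeked = stream.rng
first = next(stream)
second = next(stream)
```

Because the stream never ends, iterate it alongside something finite, for example zip(range(num_steps), stream).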

class easydel.utils.__init__.JaxRNG(rng: PRNGKey)

Bases: object

A wrapper around JAX’s PRNGKey that simplifies key splitting.

classmethod from_seed(seed: int) → JaxRNG

Creates a JaxRNG instance from a seed.

Parameters

seed – The seed to use for the random number generator.

Returns

A JaxRNG instance.
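A sketch of the kind of key splitting such a wrapper simplifies follows. SplitRNG and its __call__(num) method are hypothetical; only the from_seed classmethod mirrors the documented API. Each call splits the stored key, keeps one part as the new state, and returns the rest, so no key is handed out twice.

```python
from typing import Tuple, Union

import jax


class SplitRNG:
    """Hypothetical stand-in for JaxRNG: wraps a key and hands out fresh subkeys."""

    def __init__(self, rng: jax.Array):
        self.rng = rng

    @classmethod
    def from_seed(cls, seed: int) -> "SplitRNG":
        return cls(jax.random.PRNGKey(seed))

    def __call__(self, num: int = 1) -> Union[jax.Array, Tuple[jax.Array, ...]]:
        # Split into num + 1 keys: keep the first as new state, return the rest.
        keys = jax.random.split(self.rng, num + 1)
        self.rng = keys[0]
        return keys[1] if num == 1 else tuple(keys[1:])


rng = SplitRNG.from_seed(42)
dropout_key = rng()
init_key, noise_key = rng(2)
```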

class easydel.utils.__init__.Timer(name)

Bases: object

elapsed_time(reset=True)
reset()
start()
stop()
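Timer's methods are not documented here; the sketch below shows one plausible reading, assuming elapsed_time returns accumulated wall-clock seconds and, with reset=True, zeroes the counter after reading. SimpleTimer is a hypothetical name, not EasyDeL's class.

```python
import time
from typing import Optional


class SimpleTimer:
    """Hypothetical stand-in for Timer; method names mirror the entries above."""

    def __init__(self, name: str):
        self.name = name
        self._elapsed = 0.0
        self._started_at: Optional[float] = None

    def start(self) -> None:
        if self._started_at is not None:
            raise RuntimeError(f"timer {self.name!r} already started")
        self._started_at = time.perf_counter()

    def stop(self) -> None:
        if self._started_at is None:
            raise RuntimeError(f"timer {self.name!r} is not running")
        self._elapsed += time.perf_counter() - self._started_at
        self._started_at = None

    def reset(self) -> None:
        self._elapsed = 0.0
        self._started_at = None

    def elapsed_time(self, reset: bool = True) -> float:
        # Pause a running timer to read it, then resume it afterwards.
        running = self._started_at is not None
        if running:
            self.stop()
        elapsed = self._elapsed
        if reset:
            self.reset()
        if running:
            self.start()
        return elapsed
```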
class easydel.utils.__init__.Timers(use_wandb, tensorboard_writer: Any)

Bases: object

log(names, normalizer=1.0, reset=True)
timed(name, log=True, reset=True)
write(names, iteration, normalizer=1.0, reset=False)
easydel.utils.__init__.cache_compiles(tag: Optional[str] = None, static_argnames: Optional[List[str]] = None)
easydel.utils.__init__.cjit(fn: Callable, static_argnums: Optional[Tuple[int]] = None, static_argnames: Optional[Tuple[str]] = None)
easydel.utils.__init__.compile_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)

Compiles a JAX function with optional sharding and mesh configuration.

Parameters
  • func – The JAX function to compile.

  • func_input_args – Positional input arguments for the function.

  • func_input_kwargs – Keyword input arguments for the function.

  • mesh – Optional JAX mesh for distributed execution.

  • in_shardings – Optional input sharding specifications.

  • out_shardings – Optional output sharding specifications.

  • static_argnums – Indices of static arguments.

  • donate_argnums – Indices of arguments whose buffers may be donated to the computation.

Returns

Compiled JAX function.
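Ahead-of-time compilation of this kind can be expressed with JAX's public jit/lower/compile API. The sketch below is an assumption about what compile_function does, not its source; compile_ahead_of_time is a hypothetical name, and shardings are forwarded to jax.jit only when given, so JAX infers them otherwise.

```python
import contextlib

import jax
import jax.numpy as jnp


def compile_ahead_of_time(func, args, kwargs, mesh=None, in_shardings=None,
                          out_shardings=None, static_argnums=None, donate_argnums=None):
    """Hypothetical sketch: trace, lower, and compile func for the given inputs."""
    jit_kwargs = {"static_argnums": static_argnums, "donate_argnums": donate_argnums}
    # Only forward shardings that were given; otherwise JAX infers them.
    if in_shardings is not None:
        jit_kwargs["in_shardings"] = in_shardings
    if out_shardings is not None:
        jit_kwargs["out_shardings"] = out_shardings
    # Enter the mesh context when one is supplied.
    ctx = mesh if mesh is not None else contextlib.nullcontext()
    with ctx:
        lowered = jax.jit(func, **jit_kwargs).lower(*args, **kwargs)
        return lowered.compile()


x = jnp.ones(3)
y = jnp.full(3, 2.0)
compiled = compile_ahead_of_time(lambda a, b: a * b + 1.0, (x, y), {})
result = compiled(x, y)  # runs the precompiled executable without re-tracing
```

The compiled object is specialized to the shapes and dtypes of the example inputs, so calling it with differently shaped arrays raises an error instead of recompiling.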

easydel.utils.__init__.get_logger(name: str, level: Optional[int] = None) → LazyLogger

Create a lazy logger that is only initialized when first used.

Parameters
  • name (str) – The name of the logger.

  • level (Optional[int]) – The logging level. Defaults to the value of the LOGGING_LEVEL_ED environment variable, or “INFO” if it is unset.

Returns

A lazy logger instance that initializes on first use.

Return type

LazyLogger

easydel.utils.__init__.load_compiled_fn(path: Union[str, PathLike], prefix: Optional[str] = None)
easydel.utils.__init__.save_compiled_fn(path: Union[str, PathLike], fn: Any, prefix: Optional[str] = None)