easydel.infra.utils

class easydel.infra.utils.CompilationTracker

Bases: object

property online_flops
trace_compilation()
class easydel.infra.utils.FunctionTracer

Bases: object

class easydel.infra.utils.ModuleCaches(value: Union[A, VariableMetadata[A]], **metadata: Any)

Bases: Cache

raw_value: A
class easydel.infra.utils.OverWriteWithGradient(value: Union[A, VariableMetadata[A]], **metadata: Any)

Bases: Param

raw_value: A
class easydel.infra.utils.TraceResult(executable)

Bases: object

property cost_analysis
property flops
easydel.infra.utils.add_start_docstrings(*docstr)

add_start_docstrings is a decorator factory that prepends docstrings to a function. It takes any number of strings and returns a decorator. The returned decorator takes one argument, fn, which is assumed to be a function, and sets the docstring of fn to the concatenation of all the strings passed to add_start_docstrings, followed by the original docstring of fn if it has one.

Parameters

*docstr – Any number of strings to prepend to the decorated function's docstring.

Returns

A decorator that prepends the strings to the decorated function's docstring.
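The behaviour described above fits in a few lines. The sketch below is a minimal re-implementation written from that description, not EasyDeL's source; the name add_start_docstrings_sketch and the example functions are hypothetical.

```python
def add_start_docstrings_sketch(*docstr):
    """Return a decorator that prepends the given strings to a function's docstring."""
    def decorator(fn):
        # Concatenate the strings, then append the original docstring if there is one.
        fn.__doc__ = "".join(docstr) + (fn.__doc__ or "")
        return fn
    return decorator


@add_start_docstrings_sketch("Shared intro. ", "Shared details. ")
def forward(x):
    """Function-specific notes."""
    return x


@add_start_docstrings_sketch("Only the shared text.")
def helper(x):
    return x  # no docstring of its own, so only the prefix remains
```

Here forward.__doc__ becomes "Shared intro. Shared details. Function-specific notes.", and helper.__doc__ becomes "Only the shared text.".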

easydel.infra.utils.apply_lora_to_layers(model: Module, /, *, lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = True, rngs: Optional[Rngs] = None) → Module

Applies LoRA (Low-Rank Adaptation) to specified linear layers within a model.

Parameters
  • model – The EasyDeL model to modify.

  • lora_rank – The rank of the LoRA adapters.

  • lora_pattern – A regular expression pattern to match the names of modules to which LoRA should be applied. Defaults to “.*” (all linear layers).

  • verbose – Whether to display a progress bar.

  • rngs – A flax.nnx.Rngs instance for random number generation. If None, initializes with a seed of 0.

Returns

The modified model with LoRA applied to the specified layers.
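For intuition, the sketch below shows the standard LoRA computation on a single dense layer in NumPy: the frozen kernel W is kept as is, and a trainable low-rank update A @ B of rank lora_rank is added to its output. It illustrates the general technique only; how EasyDeL wraps its layers, and the names lora_forward and lora_param_count, are assumptions of this example.

```python
import numpy as np


def lora_forward(x, w, a, b):
    """Frozen base projection plus a low-rank update: x @ W + (x @ A) @ B."""
    return x @ w + (x @ a) @ b


def lora_param_count(d_in, d_out, rank):
    """Trainable parameters in A (d_in x rank) and B (rank x d_out)."""
    return d_in * rank + rank * d_out


rng = np.random.default_rng(0)
d_in, d_out, rank = 16, 32, 4
x = rng.normal(size=(2, d_in))
w = rng.normal(size=(d_in, d_out))        # frozen pretrained kernel
a = rng.normal(size=(d_in, rank)) * 0.01  # small random init for A
b = np.zeros((rank, d_out))               # zero init for B, so the update starts at 0

y = lora_forward(x, w, a, b)
```

With B initialized to zeros, the adapted layer initially reproduces the base layer exactly, and only d_in * rank + rank * d_out values (192 here, versus 512 in W) are trained. Many LoRA implementations also scale the update by a factor such as alpha / rank; that factor is omitted here because the signature above exposes only lora_rank.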

easydel.infra.utils.apply_sparsity_to_params(params: Union[Dict[str, Any], Any], sparsify_module: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', verbose: bool = True) → Union[Dict[str, Any], Any]
easydel.infra.utils.auto_remat(*modules: Type[M], policy: Union[EasyDeLGradientCheckPointers, str] = EasyDeLGradientCheckPointers.NONE, prevent_cse: bool = True) → Tuple[Type[M], ...]
easydel.infra.utils.block_wise_ffn(remat_ffn, inputs, chunk_size: int)
easydel.infra.utils.canonicalize_dtype(*args, dtype: Optional[dtype] = None, inexact: bool = True) → dtype

Canonicalize an optional dtype to the definitive dtype.

If dtype is None, this function infers the dtype from the input arguments using jnp.result_type. If it is not None, it is returned unmodified, or an exception is raised if the dtype is invalid.

Parameters
  • *args – JAX array-compatible values. None values are ignored.

  • dtype – Optional dtype override. If specified, the arguments are cast to this dtype and dtype inference is disabled.

  • inexact – When True, the output dtype must be a subdtype of jnp.inexact (inexact dtypes are real or complex floating-point types). This is useful for operations that don't work directly on integers, such as taking a mean.

Returns

The dtype that *args should be cast to.
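The inference rule can be sketched with NumPy standing in for jax.numpy. Promoting an inferred integer dtype to a floating-point type when inexact=True is an assumption of this sketch, modeled on Flax's flax.linen.dtypes.canonicalize_dtype; note that NumPy promotes int32 with float32 to float64, whereas jax.numpy promotes the same pair to float32.

```python
import numpy as np


def canonicalize_dtype_sketch(*args, dtype=None, inexact=True):
    """Infer a dtype from args (ignoring None) unless an explicit dtype is given."""
    if dtype is None:
        arrays = [np.asarray(a) for a in args if a is not None]
        dtype = np.result_type(*arrays)
        if inexact and not np.issubdtype(dtype, np.inexact):
            # Assumption: an inferred integer/bool dtype is promoted to a float type.
            dtype = np.promote_types(np.float32, dtype)
    dtype = np.dtype(dtype)
    if inexact and not np.issubdtype(dtype, np.inexact):
        # An explicit non-floating dtype is rejected when inexact=True.
        raise ValueError(f"dtype must be inexact, got {dtype}")
    return dtype
```

For example, canonicalize_dtype_sketch(None, np.ones(2, np.float32)) ignores the None and returns float32, while an explicit dtype=np.int8 with inexact=True raises ValueError.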

easydel.infra.utils.control_mlp_sharding(x: Array, partition_axis: PartitionAxis)

Applies MLP sharding to x according to partition_axis.

easydel.infra.utils.count_flop_jaxpr(jaxpr) → int

Count flops in a Jaxpr.

easydel.infra.utils.extract_static_parameters(module)

Extract static_argnums for specified parameters across functions in a module.

Parameters

module (types.ModuleType) – The module to inspect

Returns

A dictionary mapping function names to their static parameter indices

Return type

dict
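As a sketch of the idea, the helper below uses inspect to map each function in a module to the positional indices of parameters whose names appear in a given set, which is the form jax.jit expects for static_argnums. The parameter names treated as static here (dtype, deterministic) and the helper name find_static_argnums are hypothetical; EasyDeL's own selection may differ.

```python
import inspect
import types

# Hypothetical choice of parameter names to treat as static.
STATIC_NAMES = frozenset({"dtype", "deterministic"})


def find_static_argnums(module: types.ModuleType, static_names=STATIC_NAMES) -> dict:
    """Map function name -> tuple of positional indices of parameters in static_names."""
    result = {}
    for name, fn in inspect.getmembers(module, inspect.isfunction):
        params = list(inspect.signature(fn).parameters)
        indices = tuple(i for i, p in enumerate(params) if p in static_names)
        if indices:  # skip functions with no static parameters
            result[name] = indices
    return result


# Build a small throwaway module to inspect.
demo = types.ModuleType("demo")


def attention(q, k, v, dtype=None, deterministic=True):
    return q


def mlp(x):
    return x


demo.attention = attention
demo.mlp = mlp
static = find_static_argnums(demo)
```

For demo the result is {'attention': (3, 4)}; a tuple like this could then be passed as jax.jit(demo.attention, static_argnums=static['attention']).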

easydel.infra.utils.get_dot_general_by_bits(bits: Optional[int] = None, mode: Literal['train', 'serve', 'convert'] = 'train') → dict

get_dot_general_by_bits is a helper function that builds a q_flax.QDotGeneral with the specified number of bits for the forward and backward passes and returns it as dot_general_cls in a dict. If bits is None, no QDotGeneral is created.

Parameters
  • bits – Optional number of bits to use for quantization.

  • mode – The use case to initialize the QDot method for: 'train', 'serve', or 'convert'.

Returns

A dict containing dot_general_cls.

easydel.infra.utils.get_gradient_checkpoint_policy(name)

Returns the gradient checkpointing policy identified by name.

easydel.infra.utils.is_flatten(pytree: dict)

Checks whether pytree is flattened, that is, whether its keys are tuples (key paths).

The check inspects the first key in the dictionary: in a flattened pytree it is a tuple; otherwise it is not.

Parameters

pytree – The dictionary to check.

Returns

True if the pytree is flattened, False otherwise.
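The sketch below flattens a nested dict into tuple-keyed form (similar to what flax.traverse_util.flatten_dict produces) and applies the first-key check described above. Both helpers are hypothetical stand-ins, and the check assumes a non-empty dict.

```python
def flatten_dict_sketch(tree, prefix=()):
    """Turn {'a': {'b': 1}} into {('a', 'b'): 1}."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten_dict_sketch(value, path))  # recurse into sub-dicts
        else:
            flat[path] = value
    return flat


def is_flatten_sketch(pytree: dict) -> bool:
    """A pytree counts as flattened if its first key is a tuple (a key path)."""
    first_key = next(iter(pytree))
    return isinstance(first_key, tuple)


nested = {"layer_0": {"kernel": 1, "bias": 2}}
flat = flatten_dict_sketch(nested)
```

Here flat is {('layer_0', 'kernel'): 1, ('layer_0', 'bias'): 2}, so the check returns False for nested and True for flat. Because only the first key is inspected, a dict that mixes tuple and non-tuple keys would be classified by whichever key happens to come first.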

easydel.infra.utils.merge_lora_params(model: Module, lora_tree: Dict) → Module

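Merges the LoRA (Low-Rank Adaptation) parameters in lora_tree into the layers of a model.

Parameters
  • model – The EasyDeL model.

  • lora_tree – The LoRA parameters to merge into model.

Returns

The model with the LoRA parameters from lora_tree merged in.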

easydel.infra.utils.quantize_linear_layers(model: Module, /, *, method: Optional[EasyDeLQuantizationMethods] = None, block_size: int = 256, quantization_pattern: Optional[str] = None, verbose: bool = True) → Module

Quantizes the parameters of a model's linear layers to the requested precision. Only layers whose names match quantization_pattern are quantized.

Parameters
  • model – The model to quantize.

  • method (EasyDeLQuantizationMethods) – Quantization method for the parameters.

  • block_size (int) – Quantization block size. Defaults to 256.

  • quantization_pattern (str) – Regular expression matching the names of the layers to quantize.

  • verbose (bool) – Whether to display a tqdm progress bar.

Returns

The model with its matching linear layers quantized.
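To illustrate what a block size controls, the sketch below implements a generic block-wise absmax int8 scheme in NumPy: the flattened weights are split into blocks of block_size values, and each block is stored as int8 with one float32 scale. This is an assumed, simplified scheme; the methods in EasyDeLQuantizationMethods may quantize differently, and the helper names are hypothetical.

```python
import numpy as np


def quantize_blockwise_int8(w, block_size=256):
    """Absmax int8 quantization with one float32 scale per block of block_size values."""
    flat = w.astype(np.float32).reshape(-1)
    pad = (-flat.size) % block_size  # zero-pad so the length divides block_size
    blocks = np.pad(flat, (0, pad)).reshape(-1, block_size)
    scale = np.abs(blocks).max(axis=1, keepdims=True) / 127.0
    # Avoid dividing by zero for all-zero blocks.
    scale = np.where(scale == 0, 1.0, scale).astype(np.float32)
    q = np.clip(np.round(blocks / scale), -127, 127).astype(np.int8)
    return q, scale, pad


def dequantize_blockwise_int8(q, scale, pad, shape):
    """Rescale each block, drop the padding, and restore the original shape."""
    flat = (q.astype(np.float32) * scale).reshape(-1)
    return flat[: flat.size - pad].reshape(shape)


rng = np.random.default_rng(0)
w = rng.normal(size=(10, 30)).astype(np.float32)  # 300 values -> 2 blocks of 256
q, scale, pad = quantize_blockwise_int8(w, block_size=256)
restored = dequantize_blockwise_int8(q, scale, pad, w.shape)
```

The per-element error is at most half of its block's scale. Smaller blocks track local magnitudes more closely and reduce that error, at the cost of storing more scales.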

easydel.infra.utils.quick_gelu(x)
easydel.infra.utils.split_lora_params(model: Module) → Module

Extracts the LoRA (Low-Rank Adaptation) weights from the layers of a model.

Parameters

model – The EasyDeL model.

Returns

The LoRA layer weights.
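Conceptually, splitting and merging LoRA parameters is a partition of the parameter tree and its inverse. The sketch below does this on a flat dict with tuple paths, assuming (hypothetically) that LoRA leaves can be recognized by "lora" in their path; it is not how EasyDeL identifies LoRA modules, and the helper names are illustrative only.

```python
def split_lora_sketch(flat_params, marker="lora"):
    """Partition a flat {path_tuple: value} dict into (base, lora) by path name."""
    base, lora = {}, {}
    for path, value in flat_params.items():
        target = lora if any(marker in str(part) for part in path) else base
        target[path] = value
    return base, lora


def merge_lora_sketch(base, lora):
    """Recombine the two partitions into a single flat dict."""
    merged = dict(base)
    merged.update(lora)
    return merged


params = {
    ("q_proj", "kernel"): "W_q",
    ("q_proj", "lora_a"): "A_q",
    ("q_proj", "lora_b"): "B_q",
    ("mlp", "kernel"): "W_mlp",
}
base, lora = split_lora_sketch(params)
```

Splitting lets the small LoRA subtree be saved or optimized on its own, and merging the two partitions restores the original parameters.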

easydel.infra.utils.trace_functions()
easydel.infra.utils.unwrap_lora_to_layers(model: Module, /, *, verbose: bool = True) → Module

Removes LoRA (Low-Rank Adaptation) wrapping from the linear layers within a model.