easydel.infra.utils
- class easydel.infra.utils.ModuleCaches(value: Union[A, VariableMetadata[A]], **metadata: Any)
Bases: Cache
- raw_value: A
- class easydel.infra.utils.OverWriteWithGradient(value: Union[A, VariableMetadata[A]], **metadata: Any)
Bases: Param
- raw_value: A
- class easydel.infra.utils.TraceResult(executable)
Bases: object
- property cost_analysis
- property flops
- easydel.infra.utils.add_start_docstrings(*docstr)
A decorator factory that prepends docstrings to a function. add_start_docstrings takes an arbitrary number of strings and returns a decorator. The decorator takes one argument, fn, which is assumed to be a function, and sets fn's docstring to the concatenation of all the strings passed to add_start_docstrings followed by fn's original docstring, if it has one.
- Parameters
*docstr – Any number of strings to prepend to the function's docstring.
- Returns
A decorator that adds the docstrings to the function
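The behavior described above can be illustrated with a minimal sketch. This is a stand-alone reimplementation written from the description, not the library's source; the name add_start_docstrings_sketch and the example function forward are hypothetical.

```python
def add_start_docstrings_sketch(*docstr):
    """Return a decorator that prepends the given strings to a docstring."""

    def decorator(fn):
        # Keep the original docstring if the function has one.
        original = fn.__doc__ if fn.__doc__ is not None else ""
        fn.__doc__ = "".join(docstr) + original
        return fn

    return decorator


@add_start_docstrings_sketch("Shared intro. ", "Shared note. ")
def forward(x):
    """Runs the forward pass."""
    return x


@add_start_docstrings_sketch("Only the shared text.")
def no_doc(x):
    return x


print(forward.__doc__)  # Shared intro. Shared note. Runs the forward pass.
print(no_doc.__doc__)   # Only the shared text.
```

The decorator returns the same function object, so the wrapped function's behavior is unchanged; only its `__doc__` attribute is rewritten.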
- easydel.infra.utils.apply_lora_to_layers(model: Module, /, *, lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = True, rngs: Optional[Rngs] = None) → Module
Applies LoRA (Low-Rank Adaptation) to specified linear layers within a model.
- Parameters
model – The EasyDeL model to modify.
lora_rank – The rank of the LoRA adapters.
lora_pattern – A regular expression pattern to match the names of modules to which LoRA should be applied. Defaults to “.*” (all linear layers).
verbose – Whether to display a progress bar.
rngs – A flax.nnx.Rngs instance for random number generation. If None, initializes with a seed of 0.
- Returns
The modified model with LoRA applied to the specified layers.
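To show how a lora_pattern regular expression might select layers, here is a small sketch that filters module paths with Python's re module. The module paths, the helper select_paths, and the use of re.search are all assumptions for illustration; the exact path format and matching call used by apply_lora_to_layers are not stated in this reference.

```python
import re

# Hypothetical module paths; real paths depend on the model architecture.
module_paths = [
    "model/layers/0/self_attn/q_proj",
    "model/layers/0/self_attn/k_proj",
    "model/layers/0/mlp/gate_proj",
    "lm_head",
]


def select_paths(paths, pattern=None):
    """Return the paths matched by `pattern`; None falls back to '.*'."""
    if pattern is None:
        pattern = ".*"  # documented default: match every linear layer
    regex = re.compile(pattern)
    # Assumption: a path is selected when the pattern matches anywhere in it.
    return [path for path in paths if regex.search(path)]


print(select_paths(module_paths))
print(select_paths(module_paths, r".*(q_proj|k_proj).*"))
```

A pattern written with a leading and trailing `.*`, as in the second call, selects the same paths whether the matching is anchored or not, which makes it a safe choice when the matching mode is unknown.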
- easydel.infra.utils.apply_sparsity_to_params(params: Union[Dict[str, Any], Any], sparsify_module: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', verbose: bool = True) → Union[Dict[str, Any], Any]
- easydel.infra.utils.auto_remat(*modules: Type[M], policy: Union[EasyDeLGradientCheckPointers, str] = EasyDeLGradientCheckPointers.NONE, prevent_cse: bool = True) → Tuple[Type[M], ...]
- easydel.infra.utils.canonicalize_dtype(*args, dtype: Optional[dtype] = None, inexact: bool = True) → dtype
Canonicalize an optional dtype to the definitive dtype.
If the dtype is None, this function infers the dtype from the input arguments using jnp.result_type. If it is not None, it is returned unmodified, or an exception is raised if the dtype is invalid.
- Parameters
*args – JAX array compatible values. None values are ignored.
dtype – Optional dtype override. If specified, the arguments are cast to the specified dtype instead and dtype inference is disabled.
inexact – When True, the output dtype must be a subdtype of jnp.inexact. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that don't work directly on integers, like taking a mean for example.
- Returns
The dtype that *args should be cast to.
- easydel.infra.utils.control_mlp_sharding(x: Array, partition_axis: PartitionAxis)
Handles MLP sharding.
- easydel.infra.utils.extract_static_parameters(module)
Extract static_argnums for specified parameters across functions in a module.
- Parameters
module (types.ModuleType) – The module to inspect
- Returns
A dictionary mapping function names to their static parameter indices
- Return type
dict
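As a rough illustration of the kind of mapping described above, the sketch below walks the functions of a module and records the positional indices of parameters whose names appear in a chosen set. The set STATIC_NAMES, the helper sketch_extract_static_parameters, and the demo module are hypothetical; which parameters the library actually treats as static is not stated in this reference.

```python
import inspect
import types

# Hypothetical parameter names to treat as static (an assumption).
STATIC_NAMES = ("causal", "block_size")


def sketch_extract_static_parameters(module, static_names=STATIC_NAMES):
    """Map function names to the indices of their static parameters."""
    result = {}
    for name, fn in inspect.getmembers(module, inspect.isfunction):
        params = list(inspect.signature(fn).parameters)
        indices = tuple(i for i, p in enumerate(params) if p in static_names)
        if indices:  # skip functions without any static parameter
            result[name] = indices
    return result


# Build a throwaway module so the example is self-contained.
demo = types.ModuleType("demo")


def attention(q, k, v, causal, block_size):
    return q


def mlp(x):
    return x


demo.attention = attention
demo.mlp = mlp

print(sketch_extract_static_parameters(demo))  # {'attention': (3, 4)}
```

Index tuples of this shape are what jax.jit accepts for its static_argnums argument.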
- easydel.infra.utils.get_dot_general_by_bits(bits: Optional[int] = None, mode: Literal['train', 'serve', 'convert'] = 'train') → dict
The get_dot_general_by_bits function is a helper function that returns a q_flax.QDotGeneral object with the specified number of bits for forward and backward passes. If no bits are specified, the function returns None.
- Parameters
bits – The number of bits to use for quantization.
mode – The mode in which the model is used, which determines how the quantized dot-general method is initialized (e.g. train, serve, …).
- Returns
A dict that contains dot_general_cls
- easydel.infra.utils.get_gradient_checkpoint_policy(name)
A helper function that returns the gradient checkpoint policy specified by the name parameter.
- easydel.infra.utils.is_flatten(pytree: dict)
The is_flatten function checks whether the pytree is flattened.
If it is, the first key in the dictionary is a tuple of (mpl, mpl_id). Otherwise, it is an integer representing mpl_id.
- Parameters
pytree – The pytree to check.
- Returns
True if the pytree is a flattened tree, and False otherwise
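The check described above amounts to inspecting the type of the first key. The following is a minimal sketch written from that description, not the library's source; the name is_flatten_sketch and the example dictionaries are hypothetical.

```python
def is_flatten_sketch(pytree: dict) -> bool:
    """Return True when the first key of `pytree` is a tuple."""
    # Assumption: only the first key is inspected, as the description says.
    first_key = next(iter(pytree))
    return isinstance(first_key, tuple)


# A nested tree uses plain keys at the top level.
nested = {"layer": {"kernel": 1.0, "bias": 0.0}}

# A flattened tree uses tuple paths as keys.
flat = {("layer", "kernel"): 1.0, ("layer", "bias"): 0.0}

print(is_flatten_sketch(nested))  # False
print(is_flatten_sketch(flat))    # True
```

Because only the first key is examined, this sketch raises StopIteration on an empty dictionary and does not detect trees that mix tuple and non-tuple keys.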
- easydel.infra.utils.merge_lora_params(model: Module, lora_tree: Dict) → Module
Gets LoRA (Low-Rank Adaptation) from layers within a model.
- Parameters
model – The EasyDeL model.
- Returns
LoRA Layer Weights.
- easydel.infra.utils.quantize_linear_layers(model: Module, /, *, method: Optional[EasyDeLQuantizationMethods] = None, block_size: int = 256, quantization_pattern: Optional[str] = None, verbose: bool = True) → Module
Quantize parameters to requested precision, excluding specified layers.
- Parameters
model – The model to quantize.
method (EasyDeLQuantizationMethods) – The quantization method for the parameters.
quantization_pattern (str) – A regular expression pattern matching the layers to be quantized.
verbose (bool) – Whether to use tqdm to display progress.
- Returns
Quantized parameters in the same structure as the input.