easydel.infra.base_module#

class easydel.infra.base_module.EasyDeLBaseModule(*args: Any, **kwargs: Any)[source]#

Bases: Module, BaseModuleProtocol, EasyBridgeMixin, EasyGenerationMixin

Base class for EasyDeL modules, providing common functionalities for model initialization, parameter handling, and integration with the EasyDeL ecosystem.

apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) SELF[source]#

Applies LoRA (Low-Rank Adaptation) to specified linear layers within a model.

property causal_mask: Array#

Returns a causal mask from the config.

compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) Tuple[Any, LossMetrics][source]#

basic compute_loss call

float(change_runtime_dtype: bool = True) SELF[source]#

Converts Model parameters to float32.

property frequencies: Array#

Returns frequency values from the config.

fully_gather() SELF[source]#
fully_shard(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None) SELF[source]#
gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF[source]#

Gathers the model’s parameters based on the specified partitioning rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules for gathering.

  • mesh (jax.sharding.Mesh, optional) – The mesh to gather from.

  • overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]]) – Overlay functions to apply to the model.

Returns

The gathered model.

Return type

EasyDeLBaseModule

get_static_arguments() Tuple[source]#

return static arguments kwargs for jax.jit

property graphdef: Union[NodeDef[Node], NodeRef[Node]]#
property graphother: State[Key, VariableState[Any]]#
property graphstate: State[Key, VariableState[Any]]#
property graphtree_params_shape: Dict#

Evaluates the shape of the model’s parameters and returns a dictionary.

property graphtree_shape: Dict#

Evaluates the shape of the model and returns a dictionary.

half(change_runtime_dtype: bool = True) SELF[source]#

Converts Model parameters to float16.

classmethod lazy_init(*args, **kwargs) SELF[source]#

initialize the base class with nnx.eval_shape carefully

property loss_function#
merge_lora_params(pytree: Dict) SELF[source]#

Merge Given Pytree (LoRA Params) with current LoRA Module.

merge_params(tree)[source]#

merge state to the current model

merge_params_dict(params_dict: Dict) SELF[source]#

Merges the model parameters from a dictionary into the current model.

Parameters

params_dict (tp.Dict) – A dictionary containing the parameters to merge.

Returns

The model with merged parameters.

Return type

EasyDeLBaseModule

property mesh: Mesh#

Returns the mesh from the config.

property model_task: Optional[str]#

Returns the model task.

property model_type: Optional[str]#

Returns the model type.

property module_dtype: dtype#
property parameters: Dict#
property params: Dict#
property params_sharding: Dict#

return the sharding of the model parameters

prepare_inputs_for_call(**kwargs)[source]#

update inputs for calling model

property pure_transform_fn#

generates a pure transform function for converting torch to easydel module.

quantize(method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: Optional[bool] = None) SELF[source]#

Quantizes the model’s linear layers.

Parameters
  • method (EasyDeLQuantizationMethods, optional) – The quantization method to use.

  • block_size (int, optional) – The block size for quantization.

  • quantization_pattern (str, optional) – The quantization pattern to use.

  • quantize_tensors (bool) – Whether to quantize tensors or quantize Linear Layers.

  • verbose (bool, optional) – Verbose quantizing process.

Returns

The quantized model.

Return type

EasyDeLBaseModule

shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF[source]#

Shards the model’s parameters using the specified partitioning rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules for sharding.

  • mesh (jax.sharding.Mesh, optional) – The mesh to shard across.

  • overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]]) – Overlay functions to apply to the model.

Returns

The sharded model.

Return type

EasyDeLBaseModule

split_lora_params() Dict[source]#

split Given Module (LoRA Module) and return LoRA Params.

split_params()[source]#

split the model parameters

split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) Dict[source]#

Splits the model parameters and returns them as a dictionary, removing VariableState from the tree.

Parameters
  • extract_fn (tp.Optional[tp.Callable], optional) – Function to extract values from the parameters.

  • remove_none (bool, optional) – Whether to remove None values from the dictionary.

Returns

The dictionary of split parameters.

Return type

tp.Dict

property static_arguments: Tuple#
to_dtype(dtype: dtype) SELF[source]#

Converts the model’s parameters to the given dtype.

to_state() Any[source]#

converts current model to an EasyDeLState

to_torch(**kwargs)[source]#

converts current model to a huggingface torch model

property transform_fn#

generate transform function for converting torch to easydel module.

unwrap_lora_to_layers(verbose: bool = False) SELF[source]#

UnWrap LoRA (Low-Rank Adaptation) from specified linear layers within a model.