easydel.infra.base_module
- class easydel.infra.base_module.EasyDeLBaseModule(*args: Any, **kwargs: Any)
Bases: Module, BaseModuleProtocol, EasyBridgeMixin, EasyGenerationMixin
Base class for EasyDeL modules, providing common functionality for model initialization, parameter handling, and integration with the EasyDeL ecosystem.
- apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) -> SELF
Applies LoRA (Low-Rank Adaptation) to specified linear layers within a model.
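As a rough illustration of what low-rank adaptation does to a linear layer, the sketch below uses plain NumPy rather than EasyDeL. The helpers `select_layers` and `lora_linear`, the layer paths, the `scale` argument, and the assumption that `lora_pattern` behaves like a regular expression over layer paths are all hypothetical and are not taken from the EasyDeL source.

```python
import re
import numpy as np

def select_layers(layer_names, pattern=None):
    """Return the layer names a pattern selects (every name when pattern is None)."""
    if pattern is None:
        return list(layer_names)
    regex = re.compile(pattern)
    return [name for name in layer_names if regex.search(name)]

def lora_linear(x, w, lora_a, lora_b, scale=1.0):
    """Frozen dense projection plus a rank-r correction: x @ w + scale * (x @ A) @ B."""
    return x @ w + scale * (x @ lora_a) @ lora_b

rng = np.random.default_rng(0)
in_dim, out_dim, lora_rank = 16, 8, 4

w = rng.normal(size=(in_dim, out_dim))         # frozen base weight
lora_a = rng.normal(size=(in_dim, lora_rank))  # trainable down-projection
lora_b = np.zeros((lora_rank, out_dim))        # trainable up-projection, zero at start
x = rng.normal(size=(2, in_dim))

# With lora_b initialised to zeros, the adapted layer matches the base layer.
y = lora_linear(x, w, lora_a, lora_b)

# Hypothetical layer paths; only the attention projections match the pattern.
selected = select_layers(
    ["attn/q_proj", "attn/k_proj", "mlp/up_proj"],
    pattern=r".*(q_proj|k_proj).*",
)
```

The adapter adds `lora_rank * (in_dim + out_dim)` trainable values per layer instead of `in_dim * out_dim`, which is why a small `lora_rank` keeps fine-tuning cheap.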
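- compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) -> Tuple[Any, LossMetrics]
Computes the loss for a batch. labels, loss_config, and loss_kwargs are keyword-only; the remaining batch fields are passed as keyword arguments. Returns a tuple whose second element is a LossMetrics instance.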
- fully_shard(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None) -> SELF
- gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) -> SELF
Gathers the model’s parameters based on the specified partitioning rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules for gathering.
mesh (jax.sharding.Mesh, optional) – The mesh to gather from.
overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]]) – Overlay functions to apply to the model.
- Returns
The gathered model.
- Return type
- property graphdef: Union[NodeDef[Node], NodeRef[Node]]
- property graphother: State[Key, VariableState[Any]]
- property graphstate: State[Key, VariableState[Any]]
- property graphtree_params_shape: Dict
Evaluates the shape of the model’s parameters and returns a dictionary.
- property graphtree_shape: Dict
Evaluates the shape of the model and returns a dictionary.
- classmethod lazy_init(*args, **kwargs) -> SELF
Lazily initializes the module under nnx.eval_shape, so that shapes and dtypes are evaluated without allocating the actual parameter arrays.
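The idea behind shape-only initialization can be sketched with `jax.eval_shape`, the JAX function that `nnx.eval_shape` parallels for NNX modules. The `init_params` function and its parameter names below are hypothetical stand-ins for a model constructor; this is not how `lazy_init` itself is implemented.

```python
import math
import jax
import jax.numpy as jnp

def init_params(key):
    """Hypothetical initializer that would normally allocate real arrays."""
    kernel_key, _ = jax.random.split(key)
    return {
        "kernel": jax.random.normal(kernel_key, (1024, 4096)),
        "bias": jnp.zeros((4096,)),
    }

# eval_shape traces the function abstractly: no parameter memory is allocated,
# and every leaf comes back as a jax.ShapeDtypeStruct.
abstract = jax.eval_shape(init_params, jax.random.PRNGKey(0))

# Shapes alone are enough to count parameters or plan a sharding layout.
n_params = sum(math.prod(leaf.shape) for leaf in jax.tree_util.tree_leaves(abstract))
```

Because only shapes and dtypes are produced, this pattern is useful for inspecting or planning the placement of a large model before any weights are materialized.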
- property loss_function
- merge_lora_params(pytree: Dict) -> SELF
Merges the given pytree of LoRA parameters into the current LoRA module.
- merge_params_dict(params_dict: Dict) -> SELF
Merges the model parameters from a dictionary into the current model.
- Parameters
params_dict (tp.Dict) – A dictionary containing the parameters to merge.
- Returns
The model with merged parameters.
- Return type
- property model_task: Optional[str]
Returns the model task.
- property model_type: Optional[str]
Returns the model type.
- property parameters: Dict
- property params: Dict
- property params_sharding: Dict
Returns the sharding of the model parameters.
- property pure_transform_fn
Generates a pure transform function for converting a torch module to an EasyDeL module.
- quantize(method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: Optional[bool] = None) -> SELF
Quantizes the model’s linear layers.
- Parameters
method (EasyDeLQuantizationMethods, optional) – The quantization method to use.
block_size (int, optional) – The block size for quantization.
quantization_pattern (str, optional) – The quantization pattern to use.
quantize_tensors (bool) – Whether to quantize tensors or to quantize Linear layers.
verbose (bool, optional) – Whether to report progress during quantization.
- Returns
The quantized model.
- Return type
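To make the role of `block_size` concrete, here is a minimal sketch of generic block-wise absmax 8-bit quantization in NumPy. It assumes one scale per block of `block_size` values. The scheme EasyDeL actually uses for `A8BIT` is not described in this reference, so the functions `quantize_blockwise` and `dequantize_blockwise` are hypothetical illustrations only.

```python
import numpy as np

def quantize_blockwise(weights, block_size=128):
    """Quantize to int8 with one absmax scale per block of `block_size` values."""
    # Assumes the total number of elements is divisible by block_size.
    flat = weights.reshape(-1, block_size)
    scales = np.abs(flat).max(axis=1, keepdims=True) / 127.0
    scales = np.where(scales == 0, 1.0, scales)  # avoid division by zero
    quantized = np.round(flat / scales).astype(np.int8)
    return quantized, scales

def dequantize_blockwise(quantized, scales, shape):
    """Recover an approximation of the original weights."""
    return (quantized.astype(np.float32) * scales).reshape(shape)

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 256)).astype(np.float32)

q, scales = quantize_blockwise(w, block_size=128)
w_hat = dequantize_blockwise(q, scales, w.shape)

# Rounding error per element is at most half of that block's scale.
max_error = float(np.abs(w - w_hat).max())
```

Smaller blocks track local weight magnitudes more closely at the cost of storing more scales; larger blocks store fewer scales but let a single outlier coarsen the whole block.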
- shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) -> SELF
Shards the model’s parameters using the specified partitioning rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules for sharding.
mesh (jax.sharding.Mesh, optional) – The mesh to shard across.
overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]]) – Overlay functions to apply to the model.
- Returns
The sharded model.
- Return type
- split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) -> Dict
Splits the model parameters and returns them as a dictionary, removing VariableState from the tree.
- Parameters
extract_fn (tp.Optional[tp.Callable], optional) – Function to extract values from the parameters.
remove_none (bool, optional) – Whether to remove None values from the dictionary.
- Returns
The dictionary of split parameters.
- Return type
tp.Dict
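The behaviour described for `extract_fn` and `remove_none` can be pictured with a small pure-Python sketch. The `Leaf` wrapper stands in for a state container such as VariableState, and `split_tree` is a hypothetical helper written for this illustration. It is not the EasyDeL implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class Leaf:
    """Stand-in for a state wrapper that holds a raw value."""
    value: Any

def split_tree(tree, extract_fn: Optional[Callable] = None, remove_none: bool = True):
    """Unwrap every Leaf in a nested dict, optionally dropping None entries."""
    out = {}
    for key, node in tree.items():
        if isinstance(node, dict):
            child = split_tree(node, extract_fn, remove_none)
            # Drop sub-dicts that became empty when None values are removed.
            if child or not remove_none:
                out[key] = child
            continue
        value = node.value if isinstance(node, Leaf) else node
        if value is not None and extract_fn is not None:
            value = extract_fn(value)
        if value is None and remove_none:
            continue
        out[key] = value
    return out

tree = {
    "dense": {"kernel": Leaf([1.0, 2.0]), "bias": Leaf(None)},
    "norm": {"scale": Leaf([3.0])},
}

params = split_tree(tree, extract_fn=tuple)
params_with_none = split_tree(tree, extract_fn=tuple, remove_none=False)
```

Here `extract_fn` is applied only to leaves that hold a value, and `remove_none=False` keeps the `bias` entry so that the output mirrors the full structure of the input tree.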
- property static_arguments: Tuple
- property transform_fn
Generates a transform function for converting a torch module to an EasyDeL module.