easydel.infra.base_module
- class easydel.infra.base_module.EasyDeLBaseModule(*args: Any, **kwargs: Any)
Bases: Module, BaseModuleProtocol, EasyBridgeMixin, EasyGenerationMixin
Base class for EasyDeL modules, providing common functionality for model initialization, parameter handling, and integration with the EasyDeL ecosystem.
- apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) → SELF
Applies Low-Rank Adaptation (LoRA) layers to the specified linear layers within the module.
Replaces targeted flax.linen.Dense layers with easydel.layers.lora.LoraLinear layers, initializing the LoRA matrices (A and B).
- Parameters
lora_rank (int) – The rank of the LoRA decomposition.
lora_pattern (tp.Optional[str], optional) – A regular expression to match the names of the Dense layers to apply LoRA to. If None, applies to common attention and MLP layers. Defaults to None.
verbose (bool, optional) – If True, prints information about which layers are being modified. Defaults to False.
rngs (tp.Optional[nn.Rngs], optional) – JAX random number generators for initializing LoRA matrices. If None, default RNGs might be used. Defaults to None.
- Returns
The module instance with LoRA layers applied.
- Return type
SELF
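A minimal NumPy sketch of what a LoRA-augmented linear layer computes, assuming the common convention y = x·W + (x·A)·B with A randomly initialized, B zero-initialized, and no alpha/rank scaling. The names W, A, B and lora_linear are illustrative only, not EasyDeL's internal attributes.

```python
import numpy as np

# Assumed convention: y = x @ W + (x @ A) @ B, with B initialized to zeros.
rng = np.random.default_rng(0)
in_features, out_features, lora_rank = 16, 8, 4

W = rng.normal(size=(in_features, out_features))           # frozen base kernel
A = rng.normal(scale=0.01, size=(in_features, lora_rank))  # trainable, random init
B = np.zeros((lora_rank, out_features))                    # trainable, zero init

def lora_linear(x, W, A, B):
    """Base projection plus the low-rank update routed through A then B."""
    return x @ W + (x @ A) @ B

x = rng.normal(size=(2, in_features))
y = lora_linear(x, W, A, B)
# With B == 0 the LoRA layer initially reproduces the base layer exactly.
print(np.allclose(y, x @ W))                              # True
print(A.size + B.size, "trainable vs", W.size, "frozen")  # 96 trainable vs 128 frozen
```

The trainable parameter count grows with lora_rank · (in_features + out_features) rather than in_features · out_features, which is why small ranks are cheap to fine-tune.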
- property causal_mask: Array
Retrieves or computes the basic causal attention mask from the configuration.
Uses self.config.get_basic_causal_mask() and caches the result.
- Returns
The causal attention mask, potentially cached.
- Return type
jnp.ndarray
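A sketch of a basic causal mask and of compute-once caching, assuming a lower-triangular boolean mask with leading (batch, heads) axes; the exact shape EasyDeL's get_basic_causal_mask() returns is not stated here, and ToyModule is a hypothetical stand-in.

```python
import numpy as np
from functools import cached_property

def basic_causal_mask(max_len: int) -> np.ndarray:
    """Lower-triangular boolean mask: query position i may attend to keys 0..i."""
    mask = np.tril(np.ones((max_len, max_len), dtype=bool))
    # Add leading (batch, heads) axes so the mask broadcasts; this layout is an assumption.
    return mask[None, None, :, :]

class ToyModule:
    """Hypothetical stand-in showing compute-once-then-cache behaviour."""

    def __init__(self, max_len: int):
        self.max_len = max_len

    @cached_property
    def causal_mask(self) -> np.ndarray:
        return basic_causal_mask(self.max_len)

m = ToyModule(4)
print(m.causal_mask.shape)             # (1, 1, 4, 4)
print(m.causal_mask is m.causal_mask)  # True: computed once, then reused
```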
- compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) → Tuple[Any, LossMetrics]
Computes the loss for the model given a batch of inputs and labels.
This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.
- Parameters
labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.
loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.
loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.
**batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).
- Returns
- A tuple containing:
The model’s output (a pytree typically including logits, hidden states, etc.)
A LossMetrics object containing the calculated loss and potentially other metrics.
- Return type
tp.Tuple[tp.Any, LossMetrics]
- Raises
AssertionError – If labels are required for the loss function but are not provided or inferred.
AssertionError – If sequence classification loss is used without num_labels in the config.
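A NumPy sketch of the standard next-token cross-entropy that a causal LM loss typically computes, including the fallback of using input_ids as labels. The shift-by-one layout and the ignore_index=-100 convention are assumptions borrowed from common practice; EasyDeL's ForCausalLMLoss may differ in details such as normalization.

```python
import numpy as np

def causal_lm_loss(logits, labels, ignore_index=-100):
    """Next-token cross-entropy: logits at position t are scored against labels[t + 1]."""
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    # Numerically stable log-softmax over the vocabulary axis.
    z = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    valid = shift_labels != ignore_index
    safe_labels = np.where(valid, shift_labels, 0)
    token_nll = -np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    # Average only over positions that are not ignored.
    return (token_nll * valid).sum() / max(valid.sum(), 1)

batch, seq_len, vocab = 2, 5, 10
input_ids = np.random.default_rng(0).integers(0, vocab, size=(batch, seq_len))
logits = np.zeros((batch, seq_len, vocab))  # uniform predictions over the vocabulary
# For causal LM, labels can default to input_ids, as described above.
loss = causal_lm_loss(logits, labels=input_ids)
print(round(float(loss), 4), round(float(np.log(vocab)), 4))  # 2.3026 2.3026
```

Uniform logits give a loss of exactly log(vocab), a handy sanity check for a freshly initialized model.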
- float(change_runtime_dtype: bool = True) → SELF
Converts the module’s parameters to single-precision (float32).
Optionally also changes the runtime computation dtype (self.dtype) to float32.
- Parameters
change_runtime_dtype (bool) – If True, also sets self.dtype to jnp.float32. Defaults to True.
- Returns
The module instance with parameters (and potentially runtime dtype) set to float32.
- Return type
SELF
- property frequencies: Array
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
- fully_gather() → SELF
Applies JAX sharding constraints that fully replicate all parameters, so every device in the mesh holds a complete copy.
This function marks all parameters with an empty partition spec (PartitionSpec()), meaning no array axis is sharded. It uses jax.jit with out_shardings to enforce these constraints.
- Returns
The model instance with gathering constraints applied.
- Return type
SELF
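A JAX sketch of the mechanism described above: an empty PartitionSpec() yields a fully replicated NamedSharding, and passing it as jax.jit's out_shardings applies it to every leaf of the parameter tree. The params dict is a hypothetical stand-in, and this is not EasyDeL's own code; a single CPU device is enough to run it.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

params = {"dense": {"kernel": jnp.ones((8, 4)), "bias": jnp.zeros((4,))}}

# PartitionSpec() shards no axis, i.e. every device holds a full copy.
replicated = NamedSharding(mesh, PartitionSpec())

# A single sharding acts as a pytree prefix and applies to every output leaf.
gathered = jax.jit(lambda tree: tree, out_shardings=replicated)(params)

print(all(leaf.sharding.is_fully_replicated for leaf in jax.tree_util.tree_leaves(gathered)))
```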
- fully_shard(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None) → SELF
Applies JAX sharding constraints to all parameters based on the partition rules.
This function ensures that parameters are explicitly marked with their intended sharding, which can be useful for performance and correctness checks. It uses jax.jit with out_shardings to enforce the constraints.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses config rules. Defaults to None.
- Returns
The model instance with sharding constraints applied.
- Return type
SELF
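A sketch of rule-based sharding, assuming partition rules are (regex, PartitionSpec) pairs matched against slash-joined parameter paths with the first match winning. The rules, parameter names, and the spec_for and path_str helpers are hypothetical; the mesh axis name "tp" is also an assumption.

```python
import re
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("tp",))

# Hypothetical regex -> PartitionSpec rules; the first match wins.
partition_rules = (
    (r"attn/.*/kernel", PartitionSpec(None, "tp")),
    (r"embed/embedding", PartitionSpec("tp", None)),
)

def spec_for(path: str) -> PartitionSpec:
    for pattern, spec in partition_rules:
        if re.search(pattern, path):
            return spec
    return PartitionSpec()  # fallback: replicate anything unmatched

def path_str(path) -> str:
    # tree_map_with_path yields DictKey entries for nested dicts.
    return "/".join(str(k.key) for k in path)

params = {
    "embed": {"embedding": jnp.ones((64, 16))},
    "attn": {"q_proj": {"kernel": jnp.ones((16, 32))}},
    "norm": {"scale": jnp.ones((16,))},
}

shardings = jax.tree_util.tree_map_with_path(
    lambda path, _: NamedSharding(mesh, spec_for(path_str(path))), params
)
# jit's out_shardings enforces the per-leaf shardings on the returned tree.
sharded = jax.jit(lambda tree: tree, out_shardings=shardings)(params)
```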
- gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) → SELF
Gathers the model’s parameters from potentially distributed devices to the host or a single device.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules used to determine how parameters were originally sharded. If None, uses config rules. Defaults to None.
mesh (tp.Optional[Mesh], optional) – JAX device mesh from which to gather. If None, uses config mesh. Defaults to None.
overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Additional functions to apply, potentially overriding default gathering for specific parameters. Defaults to None.
- Returns
The model instance with gathered parameters.
- Return type
SELF
- get_static_arguments() → Tuple
Returns a tuple of static arguments required by the module’s __call__ method.
Static arguments are those that don’t change across calls and can potentially be cached or handled differently by JIT compilation. This base implementation returns an empty tuple. Subclasses should override this if they have static arguments.
- Returns
A tuple containing static arguments.
- Return type
tp.Tuple
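To illustrate why static arguments matter under JIT, here is a plain JAX sketch (not EasyDeL code): a value that controls array shapes must be a concrete Python value at trace time, so it is marked static and each distinct value gets its own compiled specialization. The forward function and num_heads are hypothetical.

```python
import jax
import jax.numpy as jnp

def forward(x, num_heads):
    # num_heads determines a reshape, so it must be a Python value at trace time.
    batch, dim = x.shape
    return x.reshape(batch, num_heads, dim // num_heads).sum(axis=-1)

# Marking num_heads as static compiles one specialization per distinct value.
forward_jit = jax.jit(forward, static_argnums=(1,))
out = forward_jit(jnp.ones((2, 8)), 4)
print(out.shape)  # (2, 4)
```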
- property graphdef: Union[NodeDef[Node], NodeRef[Node]]
Returns the graph definition (structure without parameters) of the module.
Uses flax.nnx.split to separate the graph definition from the state (parameters).
- Returns
The graph definition of the module.
- Return type
nn.GraphDef
- property graphother: State[Key, VariableState[Any]]
Returns any other state variables in the module (non-parameters).
Uses flax.nnx.split to separate non-parameter state variables.
- Returns
The graph state containing non-parameter variables.
- Return type
nn.GraphState
- property graphstate: State[Key, VariableState[Any]]
Returns the graph state (parameters) of the module.
Uses flax.nnx.split to separate the state (parameters) from the graph definition.
- Returns
The graph state containing the module’s parameters.
- Return type
nn.GraphState
- property graphtree_params_shape: Dict
Computes and returns the shapes of the module’s parameters as a nested dictionary.
It uses nnx.eval_shape to determine the shapes without actual computation, then extracts the shape information from the resulting graph state.
- Returns
A nested dictionary mirroring the parameter structure, containing their shapes.
- Return type
tp.Dict
- property graphtree_shape: Dict
Computes and returns the shapes of all state variables (including non-parameters) in the module.
Uses nnx.eval_shape on the entire module state (parameters and others) and extracts the shape information.
- Returns
A nested dictionary mirroring the module’s state structure, containing the shapes.
- Return type
tp.Dict
- half(change_runtime_dtype: bool = True) → SELF
Converts the module’s parameters to half-precision (float16).
Optionally also changes the runtime computation dtype (self.dtype) to float16.
- Parameters
change_runtime_dtype (bool) – If True, also sets self.dtype to jnp.float16. Defaults to True.
- Returns
The module instance with parameters (and potentially runtime dtype) set to float16.
- Return type
SELF
- property inv_frequencies: Array
Retrieves or computes the inverse-frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_inv_frequencies() and caches the result.
- Returns
The inverse-frequency components, potentially cached.
- Return type
jnp.ndarray
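For context on the frequencies and inv_frequencies properties, this is the standard rotary position embedding (RoPE) computation, sketched in NumPy. The base of 10000 and the outer-product layout are assumptions taken from the usual RoPE formulation; the config may use a different rope_theta or return cos/sin tables instead.

```python
import numpy as np

def rope_inv_frequencies(head_dim: int, base: float = 10000.0) -> np.ndarray:
    # One inverse frequency per pair of channels: base ** (-2i / head_dim).
    return 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))

def rope_frequencies(max_len: int, head_dim: int, base: float = 10000.0) -> np.ndarray:
    # Rotation angle for every (position, channel pair): position * inv_freq.
    positions = np.arange(max_len)
    return np.outer(positions, rope_inv_frequencies(head_dim, base))

inv_freq = rope_inv_frequencies(8)
freqs = rope_frequencies(max_len=16, head_dim=8)
print(freqs.shape)  # (16, 4)
```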
- classmethod lazy_init(*args, **kwargs) → SELF
Performs a “lazy” initialization using nnx.eval_shape.
This initializes the module structure and determines parameter shapes without actually allocating memory for the parameters. Useful for inspecting the model structure or preparing for sharding.
- Parameters
*args – Positional arguments passed to the class constructor.
**kwargs – Keyword arguments passed to the class constructor.
- Returns
A module instance with initialized structure but potentially abstract parameters.
- Return type
SELF
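The source describes lazy initialization via nnx.eval_shape; to keep this sketch self-contained it uses the underlying jax.eval_shape on a hypothetical init_params function instead. It returns abstract ShapeDtypeStruct leaves without allocating device memory, and mapping over them to read .shape is analogous to what graphtree_params_shape reports.

```python
import jax
import jax.numpy as jnp

def init_params(key, vocab_size=32000, hidden=4096):
    # Allocating these two float32 matrices for real would take about 1 GB.
    k1, k2 = jax.random.split(key)
    return {
        "embed": jax.random.normal(k1, (vocab_size, hidden), jnp.float32),
        "lm_head": jax.random.normal(k2, (hidden, vocab_size), jnp.float32),
    }

# eval_shape traces init_params abstractly: no parameter memory is allocated.
abstract = jax.eval_shape(init_params, jax.random.PRNGKey(0))
shapes = jax.tree_util.tree_map(lambda s: s.shape, abstract)
print(shapes)  # {'embed': (32000, 4096), 'lm_head': (4096, 32000)}
```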
- property loss_function
Determines and returns the appropriate loss function based on the configuration or model type.
It prioritizes config.loss_type, then self.loss_type, and finally tries to infer the loss type from the class name. If no suitable loss function is found, it defaults to ForCausalLMLoss and issues a warning.
- Returns
The selected loss function (e.g., ForCausalLMLoss, ForSequenceClassificationLoss).
- Return type
tp.Callable
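A pure-Python sketch of the documented resolution order (config.loss_type, then self.loss_type, then inference from the class name, then a warned fallback to causal LM). The registry, function names, and suffix-matching rule are hypothetical stand-ins for EasyDeL's internals.

```python
import warnings

# Hypothetical registry; the real loss functions live inside EasyDeL.
def for_causal_lm_loss(*args, **kwargs): ...
def for_sequence_classification_loss(*args, **kwargs): ...

LOSS_MAPPING = {
    "ForCausalLM": for_causal_lm_loss,
    "ForSequenceClassification": for_sequence_classification_loss,
}

def resolve_loss_function(config_loss_type, module_loss_type, class_name):
    # 1) config.loss_type, 2) self.loss_type, 3) infer from the class name suffix.
    loss_type = config_loss_type or module_loss_type
    if loss_type is None:
        loss_type = next((k for k in LOSS_MAPPING if class_name.endswith(k)), None)
    if loss_type not in LOSS_MAPPING:
        warnings.warn(f"Unknown loss type {loss_type!r}; falling back to ForCausalLM.")
        loss_type = "ForCausalLM"
    return LOSS_MAPPING[loss_type]

fn = resolve_loss_function(None, None, "LlamaForSequenceClassification")
print(fn is for_sequence_classification_loss)  # True
```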
- merge_lora_params(pytree: Dict) → SELF
Merges LoRA parameters from a pytree into the base model’s parameters.
- Parameters
pytree (tp.Dict) – A dictionary (pytree) containing the LoRA parameters (A and B matrices) structured similarly to the base model’s parameters.
- Returns
The module instance with LoRA parameters merged into the base weights.
- Return type
SELF
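The arithmetic behind merging LoRA into base weights, sketched in NumPy under the assumption of an unscaled update W' = W + A·B (many implementations also multiply by alpha/rank). It shows that one matmul with the merged kernel matches the two-path LoRA forward pass, and that subtracting the same update recovers the base kernel, which is the inverse operation split_lora_params performs.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(16, 8))  # base kernel (in_features, out_features)
A = rng.normal(size=(16, 4))  # LoRA down-projection, rank 4
B = rng.normal(size=(4, 8))   # LoRA up-projection

# Merging folds the low-rank update into a single dense kernel.
W_merged = W + A @ B

x = rng.normal(size=(3, 16))
# One matmul with the merged kernel matches the two-path LoRA forward pass.
print(np.allclose(x @ W_merged, x @ W + (x @ A) @ B))  # True
# Subtracting the same update recovers the original base kernel.
print(np.allclose(W_merged - A @ B, W))  # True
```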
- static merge_module(graphdef: Union[NodeDef[Node], NodeRef[Node]], graphstate: State[Key, VariableState[Any]], graphother: State[Key, VariableState[Any]])
- merge_params(tree)
Merges a given parameter state tree back into the module.
Reconstructs the module using its existing graph definition and ‘other’ state, but replaces the parameter state with the provided tree.
- Parameters
tree – A pytree (likely a nn.GraphState) containing the parameters to merge.
- Returns
The module instance with the new parameters merged in.
- Return type
EasyDeLBaseModule
- merge_params_dict(params_dict: Dict) → SELF
Merges parameters from a dictionary back into the module’s state.
Updates the module’s current parameter state with values from the provided dictionary.
- Parameters
params_dict (tp.Dict) – A nested dictionary containing the parameters to merge. The structure should match the module’s parameter structure.
- Returns
The module instance with the parameters from the dictionary merged in.
- Return type
SELF
- Raises
KeyError – If a key from params_dict is not found in the module’s current state.
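A pure-Python sketch of the documented merge semantics on plain nested dicts: leaves from the update overwrite existing ones, and any key absent from the current state raises KeyError. The merge_params_dict function below is a hypothetical stand-in; EasyDeL itself operates on the module's nnx state rather than plain dicts.

```python
def merge_params_dict(current: dict, updates: dict, path=()) -> dict:
    """Recursively overwrite leaves of `current` with `updates`, rejecting unknown keys."""
    merged = dict(current)
    for key, value in updates.items():
        if key not in current:
            # Report the full slash-joined path of the offending key.
            raise KeyError("/".join(map(str, path + (key,))))
        if isinstance(value, dict) and isinstance(current[key], dict):
            merged[key] = merge_params_dict(current[key], value, path + (key,))
        else:
            merged[key] = value
    return merged

state = {"layer_0": {"kernel": [[0.0]], "bias": [0.0]}}
print(merge_params_dict(state, {"layer_0": {"bias": [1.0]}}))
# {'layer_0': {'kernel': [[0.0]], 'bias': [1.0]}}
```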
- property mesh: Mesh
Retrieves the JAX device mesh from the module’s configuration.
- Returns
The device mesh defined in self.config.mesh.
- Return type
Mesh
- property model_task: Optional[str]
Returns the specific task associated with this model instance (e.g., ‘causal-language-model’).
- Returns
The model task identifier, or None if not set.
- Return type
tp.Optional[str]
- property model_type: Optional[str]
Returns the specific type of this model instance (e.g., ‘llama’, ‘mistral’).
- Returns
The model type identifier, or None if not set.
- Return type
tp.Optional[str]
- property module_dtype: dtype
Determines the data type of the module’s parameters.
It inspects the flattened parameter state to find the dtype of the first parameter encountered.
- Returns
The data type of the module’s parameters.
- Return type
jnp.dtype
- property parameters: Dict
Retrieves the parameters of the module as a dictionary.
This property iterates through the module and its submodules, extracting variables marked as nn.Param and returning them in a flat dictionary where keys represent the parameter path.
- Returns
A dictionary containing the module’s parameters.
- Return type
tp.Dict
- property params: Dict
Returns the parameters and other state variables of the module as a dictionary.
Uses flax.nnx.split to get the combined state (parameters and others).
- Returns
A dictionary containing all state variables of the module.
- Return type
tp.Dict
- property params_sharding: Dict
Retrieves the sharding annotation for each parameter in the module.
- Returns
- A nested dictionary mirroring the parameter structure, containing the sharding information (e.g., NamedSharding, PartitionSpec) for each parameter, or None if unsharded.
- Return type
tp.Dict
- prepare_inputs_for_call(**kwargs)
Prepares keyword arguments before passing them to the module’s __call__ method.
This base implementation simply returns the kwargs as is. Subclasses can override this to modify or add arguments as needed (e.g., for generation).
- Parameters
**kwargs – The keyword arguments intended for __call__.
- Returns
The prepared keyword arguments.
- Return type
dict
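A sketch of the override pattern the base implementation invites. Both classes are hypothetical: the base passes kwargs through unchanged, as documented, and the subclass fills in position_ids from the attention mask when the caller omits them, one plausible use during generation with left padding.

```python
import numpy as np

class BaseModuleSketch:
    def prepare_inputs_for_call(self, **kwargs):
        # Base behaviour described above: return the kwargs unchanged.
        return kwargs

class CausalLMSketch(BaseModuleSketch):
    def prepare_inputs_for_call(self, **kwargs):
        kwargs = super().prepare_inputs_for_call(**kwargs)
        # Hypothetical override: derive position_ids from the attention mask
        # when the caller did not supply them (useful with left padding).
        if kwargs.get("position_ids") is None and "attention_mask" in kwargs:
            mask = np.asarray(kwargs["attention_mask"])
            kwargs["position_ids"] = np.clip(mask.cumsum(axis=-1) - 1, 0, None)
        return kwargs

prepared = CausalLMSketch().prepare_inputs_for_call(attention_mask=[[0, 1, 1, 1]])
print(prepared["position_ids"])  # [[0 0 1 2]]
```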
- property pure_transform_fn
Returns a pure transformation function for PyTorch state dicts to EasyDeL parameters.
Similar to transform_fn, but this version does not include sharding functions. It identifies embedding and LayerNorm layers and returns a partial function (torch_dict_to_easydel_params) configured only with layer names and dtype.
- Returns
A partial function for converting a PyTorch state dict without applying sharding.
- Return type
tp.Callable
- quantize(method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: Optional[bool] = None) → SELF
Applies quantization to the module’s linear layers or tensors.
- Parameters
method (EasyDeLQuantizationMethods, optional) – The quantization algorithm to use (e.g., A8BIT, NF4). Defaults to EasyDeLQuantizationMethods.A8BIT.
block_size (int, optional) – The block size for quantization methods that support it. Defaults to 128.
quantization_pattern (tp.Optional[str], optional) – A regular expression to match parameter names that should be quantized. If None, uses a default pattern. Defaults to None.
quantize_tensors (bool, optional) – If True, quantizes the tensor values directly; if False, replaces Linear layers with their quantized equivalents. Defaults to True, although the current implementation follows the False behavior (layer replacement) by default.
verbose (tp.Optional[bool], optional) – If True, logs information during the quantization process. Defaults to None, in which case logging is enabled only on process index 0.
- Returns
The quantized model instance.
- Return type
SELF
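To make the block_size parameter concrete, here is a generic symmetric absmax int8 block-wise scheme in NumPy: one float scale per block of block_size values. This is an illustrative assumption, not necessarily how EasyDeL's A8BIT or NF4 methods work internally; the helper names are hypothetical.

```python
import numpy as np

def quantize_blockwise_int8(w: np.ndarray, block_size: int = 128):
    """Symmetric absmax int8 quantization with one scale per block."""
    flat = w.reshape(-1)
    pad = (-flat.size) % block_size  # pad so the size divides evenly into blocks
    blocks = np.pad(flat, (0, pad)).reshape(-1, block_size)
    scales = np.abs(blocks).max(axis=1, keepdims=True) / 127.0
    scales = np.where(scales == 0, 1.0, scales)  # avoid division by zero
    q = np.clip(np.round(blocks / scales), -127, 127).astype(np.int8)
    return q, scales.astype(np.float32), w.shape, pad

def dequantize_blockwise_int8(q, scales, shape, pad):
    flat = (q.astype(np.float32) * scales).reshape(-1)
    return flat[: flat.size - pad].reshape(shape)

w = np.random.default_rng(0).normal(size=(300, 64)).astype(np.float32)
q, scales, shape, pad = quantize_blockwise_int8(w, block_size=128)
w_hat = dequantize_blockwise_int8(q, scales, shape, pad)
print(q.dtype, scales.shape)  # int8 (150, 1)
# Rounding error per element is at most half of its block's scale.
print(float(np.abs(w - w_hat).max()) <= float(scales.max()) / 2 + 1e-6)  # True
```

Smaller blocks track local magnitude more closely (lower error) at the cost of storing more scales.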
- shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) → SELF
Shards the model’s parameters according to the specified rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses config rules. Defaults to None.
mesh (tp.Optional[Mesh], optional) – JAX device mesh. If None, uses config mesh. Defaults to None.
overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Additional functions to apply, potentially overriding default sharding for specific parameters. Defaults to None.
- Returns
The sharded model instance.
- Return type
SELF
- split_lora_params() → Dict
Splits merged LoRA parameters back out from the base model’s parameters.
This function assumes LoRA parameters were previously merged using merge_lora_params or a similar process that stored the original base weights and LoRA weights appropriately.
- Returns
- A pytree containing the extracted LoRA parameters (A and B matrices).
The base model parameters are restored to their original (pre-merge) state.
- Return type
tp.Dict
- split_params()
Splits the module and returns the parameter state.
Uses nnx.split to extract the GraphState containing the parameters.
- Returns
The parameter state of the module.
- Return type
nn.GraphState
- split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) → Dict
Splits the module parameters and returns them as a nested dictionary.
Extracts the parameter state, converts it to a plain dictionary (removing VariableState wrappers), and optionally removes entries with None values.
- Parameters
extract_fn (tp.Optional[tp.Callable], optional) – A function to apply to each parameter during extraction. Defaults to None.
remove_none (bool, optional) – If True, removes key-value pairs where the value is None. Defaults to True.
- Returns
A nested dictionary containing the module’s parameters.
- Return type
tp.Dict
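A pure-Python sketch of how extract_fn and remove_none interact when converting a nested state to a plain dictionary. Plain floats stand in for the VariableState-wrapped arrays EasyDeL actually unwraps, and to_plain_dict is a hypothetical helper.

```python
def to_plain_dict(tree, extract_fn=None, remove_none=True):
    """Recursively copy a nested dict, optionally transforming leaves and dropping Nones."""
    out = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            value = to_plain_dict(value, extract_fn, remove_none)
        elif extract_fn is not None:
            value = extract_fn(value)
        if remove_none and value is None:
            continue
        out[key] = value
    return out

state = {"dense": {"kernel": 2.0, "bias": None}, "norm": {"scale": 1.0}}
print(to_plain_dict(state))
# {'dense': {'kernel': 2.0}, 'norm': {'scale': 1.0}}
print(to_plain_dict(state, extract_fn=lambda v: None if v is None else v * 10))
# {'dense': {'kernel': 20.0}, 'norm': {'scale': 10.0}}
```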
- property static_arguments: Tuple
Retrieves or computes static arguments needed for the module’s __call__ method.
Uses self.get_static_arguments() and caches the result. Static arguments are typically those that don’t change during execution and can be pre-computed.
- Returns
A tuple of static arguments.
- Return type
tp.Tuple
- to_dtype(dtype: dtype) → SELF
Converts the module’s parameters to the specified data type.
It iterates through the module’s parameters (excluding quantization-related ones) and casts them to the target dtype. It also updates the param_dtype attribute of the module and its submodules if they exist.
- Parameters
dtype (jnp.dtype) – The target data type for the parameters.
- Returns
The module instance with parameters converted to the specified dtype.
- Return type
SELF
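A NumPy sketch of dtype conversion over a nested parameter tree. As an assumption, "quantization-related" parameters are identified here simply as integer-typed leaves (e.g. int8 quantized kernels) and left untouched; EasyDeL may use a different rule. float() and half() correspond to target dtypes float32 and float16.

```python
import numpy as np

def to_dtype(params: dict, dtype) -> dict:
    """Cast floating-point leaves to `dtype`; leave integer leaves untouched."""
    out = {}
    for name, value in params.items():
        if isinstance(value, dict):
            out[name] = to_dtype(value, dtype)
        elif np.issubdtype(value.dtype, np.floating):
            out[name] = value.astype(dtype)
        else:
            # Assumption: integer leaves (e.g. int8 quantized weights) are skipped.
            out[name] = value
    return out

params = {
    "dense": {"kernel": np.ones((4, 4), np.float32)},
    "q_dense": {"kernel": np.ones((4, 4), np.int8)},
}
half = to_dtype(params, np.float16)
print(half["dense"]["kernel"].dtype, half["q_dense"]["kernel"].dtype)  # float16 int8
```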
- to_state() → Any
Converts the current module instance into an EasyDeLState object.
This is useful for saving and managing the model’s state, including parameters and potentially optimizer state (though optimizer state is typically added later).
- Returns
An EasyDeLState object representing the current model state.
- Return type
EasyDeLState
- to_torch(**kwargs)
Converts the EasyDeL module to its equivalent Hugging Face PyTorch model.
Requires the corresponding PyTorch model class to be available and registered. Uses utility functions to transfer parameters from JAX to PyTorch format.
- Parameters
**kwargs – Additional keyword arguments passed to the parameter transformation function.
- Returns
The equivalent Hugging Face PyTorch model with loaded weights.
- Return type
torch.nn.Module
- property transform_fn
Returns a partial function for transforming PyTorch state dicts to EasyDeL parameters.
This function identifies embedding and LayerNorm layers within the module and creates a transformation function (torch_dict_to_easydel_params) pre-configured with these layer names, the target parameter dtype, and the module’s sharding functions.
- Returns
A partial function ready to convert a PyTorch state dict.
- Return type
tp.Callable
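A sketch of the partial-function pattern described above, using a hypothetical converter (torch_dict_to_params, not EasyDeL's torch_dict_to_easydel_params) that works on NumPy arrays in place of torch tensors. It relies on real layout conventions: PyTorch Linear weights are (out_features, in_features) while Flax kernels are (in_features, out_features), embedding tables keep their (vocab, dim) layout, and LayerNorm "weight" becomes "scale". Sharding functions are omitted, as in pure_transform_fn.

```python
from functools import partial
import numpy as np

def torch_dict_to_params(state_dict, *, embedding_layer_names, layernorm_names, dtype):
    """Hypothetical converter: flat PyTorch-style state dict -> nested param dict."""
    params = {}
    for key, tensor in state_dict.items():
        *module_path, leaf = key.split(".")
        array = np.asarray(tensor, dtype=dtype)
        module = ".".join(module_path)
        if module in embedding_layer_names:
            leaf = "embedding"  # embedding tables keep their (vocab, dim) layout
        elif module in layernorm_names:
            leaf = "scale" if leaf == "weight" else leaf
        elif leaf == "weight" and array.ndim == 2:
            leaf, array = "kernel", array.T  # torch (out, in) -> flax (in, out)
        node = params
        for part in module_path:
            node = node.setdefault(part, {})
        node[leaf] = array
    return params

# Pre-configure layer names and dtype once; callers then pass only the state dict.
transform = partial(
    torch_dict_to_params,
    embedding_layer_names={"embed_tokens"},
    layernorm_names={"norm"},
    dtype=np.float32,
)
state_dict = {
    "embed_tokens.weight": np.zeros((10, 4)),
    "norm.weight": np.ones(4),
    "proj.weight": np.zeros((8, 4)),  # torch layout: (out_features=8, in_features=4)
}
params = transform(state_dict)
print(params["proj"]["kernel"].shape)  # (4, 8)
```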
- unwrap_lora_to_layers(verbose: bool = False) → SELF
Reverts the application of LoRA layers, restoring the original linear layers.
Replaces easydel.layers.lora.LoraLinear layers with their original flax.linen.Dense counterparts, discarding the LoRA matrices.
- Parameters
verbose (bool, optional) – If True, prints information about which layers are being reverted. Defaults to False.
- Returns
The module instance with LoRA layers removed and original layers restored.
- Return type
SELF