easydel.infra.base_module


class easydel.infra.base_module.EasyDeLBaseModule(*args: Any, **kwargs: Any)[source]#

Bases: Module, BaseModuleProtocol, EasyBridgeMixin, EasyGenerationMixin

Base class for EasyDeL modules, providing common functionalities for model initialization, parameter handling, and integration with the EasyDeL ecosystem.

apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) SELF[source]#

Applies Low-Rank Adaptation (LoRA) layers to the specified linear layers within the module.

Replaces the targeted linear layers with easydel.layers.lora.LoraLinear layers, initializing the LoRA matrices (A and B).

Parameters
  • lora_rank (int) – The rank of the LoRA decomposition.

  • lora_pattern (tp.Optional[str], optional) – A regular expression to match the names of the Dense layers to apply LoRA to. If None, applies to common attention and MLP layers. Defaults to None.

  • verbose (bool, optional) – If True, prints information about which layers are being modified. Defaults to False.

  • rngs (tp.Optional[nn.Rngs], optional) – JAX random number generators for initializing LoRA matrices. If None, default RNGs might be used. Defaults to None.

Returns

The module instance with LoRA layers applied.

Return type

SELF
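To get a feel for how a lora_pattern might select layers, the sketch below matches a regular expression against a few layer paths. The layer names and the use of re.search are assumptions made for illustration; the real naming scheme depends on the model, and the matching rule EasyDeL applies may differ.

```python
import re
from typing import List

# Hypothetical layer paths; real names depend on the model architecture.
layer_paths = [
    "model/layers/0/self_attn/q_proj",
    "model/layers/0/self_attn/k_proj",
    "model/layers/0/self_attn/v_proj",
    "model/layers/0/mlp/down_proj",
    "model/embed_tokens",
]


def select_layers(pattern: str, paths: List[str]) -> List[str]:
    """Return the paths that match the regular expression (assumed rule)."""
    compiled = re.compile(pattern)
    return [path for path in paths if compiled.search(path)]


# Target only the query and value projections of every attention block.
selected = select_layers(r".*(q_proj|v_proj).*", layer_paths)
print(selected)
```

A pattern written this way could then be passed as lora_pattern; checking it against the actual parameter names first (for example with verbose=True) helps avoid silently matching nothing.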

apply_out_shardings(out_shardings)[source]#

property causal_mask: Array#

Retrieves or computes the basic causal attention mask from the configuration.

Uses self.config.get_basic_causal_mask() and caches the result.

Returns

The causal attention mask, potentially cached.

Return type

jnp.ndarray
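As an illustration of what a basic causal mask encodes, the following sketch builds one with plain Python lists standing in for the array. The function name and the two-dimensional layout are assumptions; the mask returned by the configuration may carry extra batch or head dimensions.

```python
def basic_causal_mask(seq_len):
    """Lower-triangular boolean mask: position i may attend to j only if j <= i."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


mask = basic_causal_mask(4)
for row in mask:
    # Print each row as 1 (may attend) or 0 (masked out).
    print("".join("1" if allowed else "0" for allowed in row))
```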

compute_complex_rotary(position_ids: Array) Array[source]#

compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) Tuple[Any, LossMetrics][source]#

Computes the loss for the model given a batch of inputs and labels.

This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.

Parameters
  • labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.

  • loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.

  • loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.

  • **batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).

Returns

A tuple containing:
  • The model’s output (a pytree, typically including logits, hidden states, etc.).

  • A LossMetrics object containing the calculated loss and potentially other metrics.

Return type

tp.Tuple[tp.Any, LossMetrics]

Raises
  • AssertionError – If labels are required for the loss function but are not provided or inferred.

  • AssertionError – If sequence classification loss is used without num_labels in the config.
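The label inference described above can be sketched for the causal language modeling case as follows. This is a minimal pure-Python illustration, not EasyDeL's loss code: the function names are hypothetical, and the shift by one position is the usual causal LM convention, assumed here rather than taken from the source.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_norm for v in logits]


def causal_lm_loss(logits, input_ids, labels=None):
    """Mean next-token cross-entropy for a single sequence.

    logits holds one list of vocabulary scores per position.
    labels falls back to input_ids when it is not given.
    """
    if labels is None:
        labels = input_ids  # label inference for causal language modeling
    # Position t predicts token t + 1: drop the last logit and the first label.
    shifted_logits = logits[:-1]
    shifted_labels = labels[1:]
    losses = [
        -log_softmax(scores)[target]
        for scores, target in zip(shifted_logits, shifted_labels)
    ]
    return sum(losses) / len(losses)


# Toy example: vocabulary of 3 tokens, sequence of 3 positions.
input_ids = [0, 2, 1]
logits = [
    [0.0, 0.0, 5.0],  # strongly predicts token 2, the correct next token
    [0.0, 5.0, 0.0],  # strongly predicts token 1, the correct next token
    [1.0, 1.0, 1.0],  # the last position has no target and is ignored
]
loss = causal_lm_loss(logits, input_ids)
print(round(loss, 4))
```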

float(change_runtime_dtype: bool = True) SELF[source]#

Converts the module’s parameters to single-precision (float32).

Optionally also changes the runtime computation dtype (self.dtype) to float32.

Parameters

change_runtime_dtype (bool) – If True, also sets self.dtype to jnp.float32. Defaults to True.

Returns

The module instance with parameters (and potentially runtime dtype) set to float32.

Return type

SELF

property frequencies: Array#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray

fully_gather() SELF[source]#

Applies JAX sharding constraints that gather every parameter into a whole, unsharded array.

This function assigns every parameter an empty PartitionSpec(), meaning that no dimension is split across mesh axes. It uses jax.jit with out_shardings to enforce these gathering constraints.

Returns

The model instance with gathering constraints applied.

Return type

SELF

fully_shard(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None) SELF[source]#

Applies JAX sharding constraints to all parameters based on the partition rules.

This function ensures that parameters are explicitly marked with their intended sharding, which can be useful for performance and correctness checks. It uses jax.jit with out_shardings to enforce the constraints.

Parameters

partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses config rules. Defaults to None.

Returns

The model instance with sharding constraints applied.

Return type

SELF
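One common way to express partition rules is an ordered list of (regex, spec) pairs where the first matching rule wins. The sketch below illustrates that idea only: the rule format, the parameter paths, and the axis names "fsdp" and "tp" are all assumptions, and plain tuples stand in for jax.sharding.PartitionSpec so the example runs without JAX.

```python
import re

# Ordered (regex, spec) pairs. Tuples of mesh-axis names stand in for
# jax.sharding.PartitionSpec; every name here is hypothetical.
partition_rules = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/(q|k|v)_proj/kernel", ("fsdp", "tp")),
    (r"self_attn/o_proj/kernel", ("tp", "fsdp")),
    (r".*", ()),  # fallback: empty spec, not sharded on any axis
)


def match_partition_rule(rules, param_path):
    """Return the spec of the first rule whose regex matches the path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


for path in (
    "model/layers/0/self_attn/q_proj/kernel",
    "model/layers/0/self_attn/o_proj/kernel",
    "model/norm/kernel",
):
    print(path, "->", match_partition_rule(partition_rules, path))
```

Because the first match wins in this sketch, specific rules go before the catch-all fallback.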

gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF[source]#

Gathers the model’s parameters from potentially distributed devices to the host or a single device.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules used to determine how parameters were originally sharded. If None, uses config rules. Defaults to None.

  • mesh (tp.Optional[Mesh], optional) – JAX device mesh from which to gather. If None, uses config mesh. Defaults to None.

  • overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Additional functions to apply, potentially overriding default gathering for specific parameters. Defaults to None.

Returns

The model instance with gathered parameters.

Return type

SELF

get_static_arguments() Tuple[source]#

Returns a tuple of static arguments required by the module’s __call__ method.

Static arguments are those that do not change across calls and can potentially be cached or handled differently by JIT compilation. This base implementation returns an empty tuple; subclasses that have static arguments should override it.

Returns

A tuple containing static arguments.

Return type

tp.Tuple

property graphdef: Union[NodeDef[Node], NodeRef[Node]]#

Returns the graph definition (structure without parameters) of the module.

Uses flax.nnx.split to separate the graph definition from the state (parameters).

Returns

The graph definition of the module.

Return type

nn.GraphDef

property graphother: State[Key, VariableState[Any]]#

Returns any other state variables in the module (non-parameters).

Uses flax.nnx.split to separate non-parameter state variables.

Returns

The graph state containing non-parameter variables.

Return type

nn.GraphState

property graphstate: State[Key, VariableState[Any]]#

Returns the graph state (parameters) of the module.

Uses flax.nnx.split to separate the state (parameters) from the graph definition.

Returns

The graph state containing the module’s parameters.

Return type

nn.GraphState

property graphtree_params_shape: Dict#

Computes and returns the shapes of the module’s parameters as a nested dictionary.

It uses nnx.eval_shape to determine the shapes without actual computation, then extracts the shape information from the resulting graph state.

Returns

A nested dictionary mirroring the parameter structure, containing their shapes.

Return type

tp.Dict
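A shape tree like this is handy for estimating model size before any memory is allocated. The sketch below assumes the leaves are plain shape tuples and uses made-up layer names and sizes; the actual leaf type returned by the property may differ.

```python
import math

# Hypothetical shape tree in the spirit of graphtree_params_shape.
params_shape = {
    "embed_tokens": {"embedding": (32000, 512)},
    "layers": {
        "0": {
            "q_proj": {"kernel": (512, 512)},
            "mlp": {"kernel": (512, 2048), "bias": (2048,)},
        }
    },
}


def count_parameters(tree):
    """Sum the element counts of every shape tuple in a nested dict."""
    if isinstance(tree, tuple):
        return math.prod(tree)
    return sum(count_parameters(child) for child in tree.values())


total = count_parameters(params_shape)
bytes_fp16 = total * 2  # two bytes per element in half precision
print(f"{total:,} parameters, about {bytes_fp16 / 2**20:.1f} MiB in float16")
```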

property graphtree_shape: Dict#

Computes and returns the shapes of all state variables (including non-parameters) in the module.

Uses nnx.eval_shape on the entire module state (parameters and others) and extracts the shape information.

Returns

A nested dictionary mirroring the module’s state structure, containing the shapes.

Return type

tp.Dict

half(change_runtime_dtype: bool = True) SELF[source]#

Converts the module’s parameters to half-precision (float16).

Optionally also changes the runtime computation dtype (self.dtype) to float16.

Parameters

change_runtime_dtype (bool) – If True, also sets self.dtype to jnp.float16. Defaults to True.

Returns

The module instance with parameters (and potentially runtime dtype) set to float16.

Return type

SELF
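Half precision halves the memory per parameter at the cost of rounding error. The sketch below uses only the standard library struct module to round-trip a value through IEEE 754 float16 and float32, so the precision difference between half() and float() can be seen directly. The helper names are illustrative and unrelated to EasyDeL.

```python
import struct


def to_float16_and_back(value):
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


def to_float32_and_back(value):
    """Round-trip a Python float through IEEE 754 single precision."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


weight = 0.1
error_fp32 = abs(to_float32_and_back(weight) - weight)
error_fp16 = abs(to_float16_and_back(weight) - weight)
print("float32 rounding error:", error_fp32)
print("float16 rounding error:", error_fp16)
```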

property inv_frequencies: Array#

Retrieves or computes the inv-frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_inv_frequencies() and caches the result.

Returns

The inv-frequency components, potentially cached.

Return type

jnp.ndarray
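For orientation, the conventional rotary position embedding (RoPE) inverse frequencies can be computed as below. This is a sketch of the standard formula, not EasyDeL's implementation: the function names are hypothetical, and the base of 10000 and the absence of any frequency scaling are assumptions that a given model configuration may override.

```python
import math


def rope_inv_frequencies(head_dim, base=10000.0):
    """inv_freq[i] = base ** (-2i / head_dim) for i in 0 .. head_dim // 2 - 1."""
    return [base ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]


def rope_angles(position, inv_freq):
    """Rotation angle of each frequency pair at a given token position."""
    return [position * f for f in inv_freq]


inv_freq = rope_inv_frequencies(head_dim=8)
angles = rope_angles(position=3, inv_freq=inv_freq)
print(inv_freq)
# Each pair of channels is rotated by (cos(angle), sin(angle)).
print([(round(math.cos(a), 4), round(math.sin(a), 4)) for a in angles])
```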

classmethod lazy_init(*args, **kwargs) SELF[source]#

Performs a “lazy” initialization using nnx.eval_shape.

This initializes the module structure and determines parameter shapes without actually allocating memory for the parameters. Useful for inspecting the model structure or preparing for sharding.

Parameters
  • *args – Positional arguments passed to the class constructor.

  • **kwargs – Keyword arguments passed to the class constructor.

Returns

A module instance with initialized structure but potentially abstract parameters.

Return type

SELF

property loss_function#

Determines and returns the appropriate loss function based on the configuration or model type.

It prioritizes config.loss_type, then self.loss_type, and finally tries to infer the loss type from the class name. If no suitable loss function is found, it defaults to ForCausalLMLoss and issues a warning.

Returns

The selected loss function (e.g., ForCausalLMLoss, ForSequenceClassificationLoss).

Return type

tp.Callable
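The priority order described above can be sketched as a small lookup. The registry, its keys, and the function name are hypothetical; strings stand in for the real loss callables, and the class-name inference shown (substring matching) is only one plausible way to do it.

```python
import warnings

# Hypothetical registry; the values mirror the loss functions named above.
LOSS_REGISTRY = {
    "ForCausalLM": "ForCausalLMLoss",
    "ForSequenceClassification": "ForSequenceClassificationLoss",
}


def resolve_loss_name(config_loss_type, module_loss_type, class_name):
    """Pick a loss by priority: config, then module attribute, then class name."""
    for candidate in (config_loss_type, module_loss_type):
        if candidate in LOSS_REGISTRY:
            return LOSS_REGISTRY[candidate]
    for key, loss_name in LOSS_REGISTRY.items():
        if key in class_name:  # e.g. "LlamaForCausalLM" contains "ForCausalLM"
            return loss_name
    warnings.warn(f"no loss found for {class_name}; using ForCausalLMLoss")
    return "ForCausalLMLoss"


print(resolve_loss_name(None, None, "LlamaForSequenceClassification"))
print(resolve_loss_name("ForCausalLM", "ForSequenceClassification", "Anything"))
```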

merge_lora_params(pytree: Dict) SELF[source]#

Merges LoRA parameters from a pytree into the base model’s parameters.

Parameters

pytree (tp.Dict) – A dictionary (pytree) containing the LoRA parameters (A and B matrices) structured similarly to the base model’s parameters.

Returns

The module instance with LoRA parameters merged into the base weights.

Return type

SELF
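In the standard LoRA formulation, merging folds the low-rank update into the base weight. The notation below follows the convention of the original LoRA paper; the scaling factor s and the matrix orientation are assumptions and may not match what EasyDeL uses internally.

```latex
W_{\text{merged}} = W_0 + s\,B A,
\qquad B \in \mathbb{R}^{d \times r},\quad
A \in \mathbb{R}^{r \times k},\quad
r \ll \min(d, k)
```

Under the same assumptions, undoing the merge amounts to subtracting the same product: W_0 = W_merged - s B A.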

static merge_module(graphdef: Union[NodeDef[Node], NodeRef[Node]], graphstate: State[Key, VariableState[Any]], graphother: State[Key, VariableState[Any]])[source]#

merge_params(tree)[source]#

Merges a given parameter state tree back into the module.

Reconstructs the module using its existing graph definition and ‘other’ state, but replaces the parameter state with the provided tree.

Parameters

tree – A pytree (likely a nn.GraphState) containing the parameters to merge.

Returns

The module instance with the new parameters merged in.

Return type

EasyDeLBaseModule

merge_params_dict(params_dict: Dict) SELF[source]#

Merges parameters from a dictionary back into the module’s state.

Updates the module’s current parameter state with values from the provided dictionary.

Parameters

params_dict (tp.Dict) – A nested dictionary containing the parameters to merge. The structure should match the module’s parameter structure.

Returns

The module instance with the parameters from the dictionary merged in.

Return type

SELF

Raises

KeyError – If a key from params_dict is not found in the module’s current state.
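The merge-and-validate behavior described here can be sketched with plain nested dictionaries. The helper names and the flat path-tuple representation are assumptions for illustration; the real method operates on the module's state rather than on Python dicts.

```python
def flatten(tree, prefix=()):
    """Flatten a nested dict into {path_tuple: leaf}."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def merge_params_dict(current, params_dict):
    """Return a flat copy of current updated with the values in params_dict.

    Raises KeyError when params_dict holds a path that current lacks.
    """
    merged = flatten(current)
    for path, value in flatten(params_dict).items():
        if path not in merged:
            raise KeyError(f"parameter {'/'.join(path)} not found in module state")
        merged[path] = value
    return merged


current = {"dense": {"kernel": [[1.0, 2.0]], "bias": [0.0]}}
merged = merge_params_dict(current, {"dense": {"bias": [0.5]}})
print(merged[("dense", "bias")])

try:
    merge_params_dict(current, {"dense": {"scale": [1.0]}})
except KeyError as error:
    print("rejected:", error)
```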

property mesh: Mesh#

Retrieves the JAX device mesh from the module’s configuration.

Returns

The device mesh defined in self.config.mesh.

Return type

jax.sharding.Mesh

property model_task: Optional[str]#

Returns the specific task associated with this model instance (e.g., ‘causal-language-model’).

Returns

The model task identifier, or None if not set.

Return type

tp.Optional[str]

property model_type: Optional[str]#

Returns the specific type of this model instance (e.g., ‘llama’, ‘mistral’).

Returns

The model type identifier, or None if not set.

Return type

tp.Optional[str]

property module_dtype: dtype#

Determines the data type of the module’s parameters.

It inspects the flattened parameter state to find the dtype of the first parameter encountered.

Returns

The data type of the module’s parameters.

Return type

jnp.dtype

property parameters: Dict#

Retrieves the parameters of the module as a dictionary.

This property iterates through the module and its submodules, extracting variables marked as nn.Param and returning them in a flat dictionary where keys represent the parameter path.

Returns

A dictionary containing the module’s parameters.

Return type

tp.Dict

property params: Dict#

Returns the parameters and other state variables of the module as a dictionary.

Uses flax.nnx.split to get the combined state (parameters and others).

Returns

A dictionary containing all state variables of the module.

Return type

tp.Dict

property params_sharding: Dict#

Retrieves the sharding annotation for each parameter in the module.

Returns

A nested dictionary mirroring the parameter structure, containing the sharding information (e.g., NamedSharding, PartitionSpec) for each parameter, or None if unsharded.

Return type

tp.Dict

prepare_inputs_for_call(**kwargs)[source]#

Prepares keyword arguments before passing them to the module’s __call__ method.

This base implementation simply returns the kwargs as is. Subclasses can override this to modify or add arguments as needed (e.g., for generation).

Parameters

**kwargs – The keyword arguments intended for __call__.

Returns

The prepared keyword arguments.

Return type

dict

property pure_transform_fn#

Returns a pure transformation function for PyTorch state dicts to EasyDeL parameters.

Similar to transform_fn, but this version does not include sharding functions. It identifies embedding and LayerNorm layers and returns a partial function (torch_dict_to_easydel_params) configured only with layer names and dtype.

Returns

A partial function for converting a PyTorch state dict without applying sharding.

Return type

tp.Callable

quantize(method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: Optional[bool] = None) SELF[source]#

Applies quantization to the module’s linear layers or tensors.

Parameters
  • method (EasyDeLQuantizationMethods, optional) – The quantization algorithm to use (e.g., A8BIT, NF4). Defaults to EasyDeLQuantizationMethods.A8BIT.

  • block_size (int, optional) – The block size for quantization methods that support it. Defaults to 128.

  • quantization_pattern (tp.Optional[str], optional) – A regular expression to match parameter names that should be quantized. If None, uses a default pattern. Defaults to None.

  • quantize_tensors (bool, optional) – If True, quantizes the tensor values directly. If False, replaces Linear layers with their quantized equivalents. The signature default is True, although the source notes that the implementation currently follows the False behavior.

  • verbose (tp.Optional[bool], optional) – If True, logs information during the quantization process. If None, logging is enabled only on process index 0. Defaults to None.

Returns

The quantized model instance.

Return type

SELF
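To show what a block size means in block-wise quantization, the sketch below applies a generic absmax 8-bit scheme to a flat list of weights. It is not the algorithm behind any particular EasyDeLQuantizationMethods value; the function names and the absmax scaling rule are assumptions chosen because they are simple to follow.

```python
def quantize_block(block):
    """Absmax-quantize one block of floats to int8-range values plus a scale."""
    absmax = max(abs(v) for v in block)
    scale = absmax / 127.0 if absmax > 0 else 1.0
    return [round(v / scale) for v in block], scale


def quantize_blockwise(values, block_size=128):
    """Split a flat list into blocks and quantize each block independently."""
    return [
        quantize_block(values[start:start + block_size])
        for start in range(0, len(values), block_size)
    ]


def dequantize_blockwise(blocks):
    """Reconstruct approximate floats from (quantized values, scale) pairs."""
    return [q * scale for quantized, scale in blocks for q in quantized]


weights = [0.02, -0.5, 0.31, 0.0, 12.0, -7.5, 3.3, 0.1]
blocks = quantize_blockwise(weights, block_size=4)
restored = dequantize_blockwise(blocks)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(f"{len(blocks)} blocks, max reconstruction error {max_error:.4f}")
```

Each block carries its own scale, so the large values in the second block do not degrade the resolution of the small values in the first. That is the usual motivation for smaller block sizes, traded against the cost of storing more scales.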

shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF[source]#

Shards the model’s parameters according to the specified rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses config rules. Defaults to None.

  • mesh (tp.Optional[Mesh], optional) – JAX device mesh. If None, uses config mesh. Defaults to None.

  • overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Additional functions to apply, potentially overriding default sharding for specific parameters. Defaults to None.

Returns

The sharded model instance.

Return type

SELF

split_lora_params() Dict[source]#

Splits merged LoRA parameters back out from the base model’s parameters.

This function assumes LoRA parameters were previously merged using merge_lora_params or a similar process that stored the original base weights and LoRA weights appropriately.

Returns

A pytree containing the extracted LoRA parameters (A and B matrices). The base model parameters are restored to their original (pre-merge) state.

Return type

tp.Dict

split_module()[source]#

split_params()[source]#

Splits the module and returns the parameter state.

Uses nnx.split to extract the GraphState containing the parameters.

Returns

The parameter state of the module.

Return type

nn.GraphState

split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) Dict[source]#

Splits the module parameters and returns them as a nested dictionary.

Extracts the parameter state, converts it to a plain dictionary (removing VariableState wrappers), and optionally removes entries with None values.

Parameters
  • extract_fn (tp.Optional[tp.Callable], optional) – A function to apply to each parameter during extraction. Defaults to None.

  • remove_none (bool, optional) – If True, removes key-value pairs where the value is None. Defaults to True.

Returns

A nested dictionary containing the module’s parameters.

Return type

tp.Dict
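The roles of extract_fn and remove_none can be illustrated on an ordinary nested dictionary. The helper below is a hypothetical stand-in: it assumes extract_fn is applied to each non-None leaf, which may not match the real method's behavior exactly.

```python
def to_plain_dict(state, extract_fn=None, remove_none=True):
    """Copy a nested dict, optionally transforming leaves and dropping None."""
    result = {}
    for key, value in state.items():
        if isinstance(value, dict):
            result[key] = to_plain_dict(value, extract_fn, remove_none)
        elif value is None and remove_none:
            continue  # drop entries without a value
        elif value is None or extract_fn is None:
            result[key] = value
        else:
            result[key] = extract_fn(value)
    return result


state = {
    "dense": {"kernel": [1.0, 2.0], "bias": None},
    "norm": {"scale": [1.0]},
}
# Use len as a toy extract_fn that maps each parameter to its element count.
plain = to_plain_dict(state, extract_fn=len)
print(plain)
```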

property static_arguments: Tuple#

Retrieves or computes static arguments needed for the module’s __call__ method.

Uses self.get_static_arguments() and caches the result. Static arguments are typically those that don’t change during execution and can be pre-computed.

Returns

A tuple of static arguments.

Return type

tp.Tuple
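The relationship between get_static_arguments() and the cached static_arguments property can be sketched with stand-in classes. The class names and the select_layer option are hypothetical, and functools.cached_property is only one way to implement the caching; the source does not say which mechanism EasyDeL uses.

```python
from functools import cached_property


class BaseModule:
    """Stand-in for the base class: no static arguments by default."""

    def get_static_arguments(self):
        return ()

    @cached_property
    def static_arguments(self):
        # Computed on first access, then stored on the instance.
        return self.get_static_arguments()


class VisionModule(BaseModule):
    """Hypothetical subclass whose __call__ takes one static option."""

    def __init__(self, select_layer):
        self.select_layer = select_layer
        self.calls = 0

    def get_static_arguments(self):
        self.calls += 1  # count how often the tuple is recomputed
        return (self.select_layer,)


module = VisionModule(select_layer=-2)
first = module.static_arguments
second = module.static_arguments
print(first, module.calls)
```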

to_dtype(dtype: dtype) SELF[source]#

Converts the module’s parameters to the specified data type.

It iterates through the module’s parameters (excluding quantization-related ones) and casts them to the target dtype. It also updates the param_dtype attribute of the module and its submodules if they exist.

Parameters

dtype (jnp.dtype) – The target data type for the parameters.

Returns

The module instance with parameters converted to the specified dtype.

Return type

SELF

to_state() Any[source]#

Converts the current module instance into an EasyDeLState object.

This is useful for saving and managing the model’s state, including parameters and potentially optimizer state (though optimizer state is typically added later).

Returns

An EasyDeLState object representing the current model state.

Return type

EasyDeLState

to_torch(**kwargs)[source]#

Converts the EasyDeL module to its equivalent Hugging Face PyTorch model.

Requires the corresponding PyTorch model class to be available and registered. Uses utility functions to transfer parameters from JAX to PyTorch format.

Parameters

**kwargs – Additional keyword arguments passed to the parameter transformation function.

Returns

The equivalent Hugging Face PyTorch model with loaded weights.

Return type

torch.nn.Module

property transform_fn#

Returns a partial function for transforming PyTorch state dicts to EasyDeL parameters.

This function identifies embedding and LayerNorm layers within the module and creates a transformation function (torch_dict_to_easydel_params) pre-configured with these layer names, the target parameter dtype, and the module’s sharding functions.

Returns

A partial function ready to convert a PyTorch state dict.

Return type

tp.Callable
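One reason a converter needs to know which layers are embeddings or LayerNorms is weight layout: PyTorch linear weights are stored as (out_features, in_features), while Flax kernels are (in_features, out_features). The sketch below shows that idea with nested lists in place of tensors. The function names and the renaming scheme are assumptions, not the behavior of torch_dict_to_easydel_params.

```python
def transpose(matrix):
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(column) for column in zip(*matrix)]


def convert_state_dict(torch_state, embedding_names, layernorm_names):
    """Rename and re-lay-out PyTorch-style weights for a Flax-style module.

    Assumed rules: linear weights are transposed and renamed to "kernel";
    embedding and LayerNorm weights keep their layout and are renamed to
    "embedding" and "scale"; everything else is copied unchanged.
    """
    converted = {}
    for name, tensor in torch_state.items():
        module, _, leaf = name.rpartition(".")
        if leaf == "weight" and module in embedding_names:
            converted[module + ".embedding"] = tensor
        elif leaf == "weight" and module in layernorm_names:
            converted[module + ".scale"] = tensor
        elif leaf == "weight":
            converted[module + ".kernel"] = transpose(tensor)
        else:
            converted[name] = tensor  # e.g. biases
    return converted


torch_state = {
    "embed.weight": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    "norm.weight": [1.0, 1.0],
    "proj.weight": [[1, 2], [3, 4], [5, 6]],  # (out=3, in=2)
    "proj.bias": [0, 0, 0],
}
params = convert_state_dict(torch_state, {"embed"}, {"norm"})
print(params["proj.kernel"])
```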

unwrap_lora_to_layers(verbose: bool = False) SELF[source]#

Reverts the application of LoRA layers, restoring the original linear layers.

Replaces easydel.layers.lora.LoraLinear layers with the original linear layers they wrapped, discarding the LoRA matrices.

Parameters

verbose (bool, optional) – If True, prints information about which layers are being reverted. Defaults to False.

Returns

The module instance with LoRA layers removed and original layers restored.

Return type

SELF