easydel.infra.base_module


Base module implementation for EasyDeL models.

This module provides the core foundation for all EasyDeL neural network models, implementing essential functionality for model initialization, parameter management, sharding, quantization, and integration with the broader EasyDeL ecosystem.

The EasyDeLBaseModule class serves as the base class that all EasyDeL models inherit from, providing:
  • Parameter management and state handling

  • Model sharding and gathering for distributed training

  • Quantization and LoRA support

  • Loss computation framework

  • Integration with HuggingFace models

  • Generation capabilities through mixins

Key Classes:
  • EasyDeLBaseModule: The base class for all EasyDeL models, providing common functionality for parameter handling, sharding, and model operations.

  • ParameterTransformRule: Data class defining rules for transforming parameter names and tensors, particularly useful for MoE models.

Example

>>> import jax.numpy as jnp
>>> import flax.nnx as nn
>>> from easydel.infra import EasyDeLBaseModule, EasyDeLBaseConfig
>>>
>>> class MyModel(EasyDeLBaseModule):
...     def __init__(self, config, dtype, param_dtype, precision, rngs):
...         super().__init__(config, dtype, param_dtype, precision, rngs)
...         # Initialize model layers; nnx layers need rngs to create their parameters
...         self.layer = nn.Linear(config.hidden_size, config.hidden_size, rngs=rngs)
...
...     def __call__(self, inputs):
...         return self.layer(inputs)
>>>
>>> # Create and use the model
>>> config = EasyDeLBaseConfig(hidden_size=768)
>>> model = MyModel(
...     config=config,
...     dtype=jnp.float32,
...     param_dtype=jnp.float32,
...     precision='highest',
...     rngs=nn.Rngs(0)
... )

The module integrates with JAX’s sharding system for distributed training, supports various quantization methods, and provides utilities for converting between EasyDeL and HuggingFace model formats.

class easydel.infra.base_module.EasyDeLBaseModule(*args: Any, **kwargs: Any)

Bases: Module, EasyBridgeMixin, EasyGenerationMixin, BaseModuleProtocol

Base class for EasyDeL modules, providing common functionalities for model initialization, parameter handling, and integration with the EasyDeL ecosystem.

apply_lm_head(hidden_states: Union[Array, ndarray, bool, number]) → Union[Array, ndarray, bool, number]

Apply the language model head to transform hidden states into logits.

Parameters

hidden_states – Input hidden states from the transformer model. Shape should be […, hidden_size].

Returns

Output logits over the vocabulary. Shape will be […, vocab_size].

apply_lora_to_layers(lora_rank: int, lora_pattern: str | None = None, verbose: bool = False, rngs: flax.nnx.rnglib.Rngs | None = None) → Self

Applies Low-Rank Adaptation (LoRA) layers to the specified linear layers within the module.

Parameters
  • lora_rank (int) – The rank of the LoRA decomposition.

  • lora_pattern (tp.Optional[str], optional) – A regular expression to match the names of the Linear layers to apply LoRA to. If None, applies to common attention and MLP layers. Defaults to None.

  • verbose (bool, optional) – If True, prints information about which layers are being modified. Defaults to False.

  • rngs (tp.Optional[nn.Rngs], optional) – JAX random number generators for initializing LoRA matrices. If None, default RNGs might be used. Defaults to None.

Returns

The module instance with LoRA layers applied.

Return type

Self
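To illustrate how a lora_pattern regular expression might select layers, the sketch below filters a list of layer paths with Python's re module. The layer paths and the use of re.search are assumptions made for illustration; EasyDeL's actual naming scheme and matching logic may differ.

```python
import re

# Hypothetical layer paths; real names depend on the model architecture.
layer_paths = [
    "model/layers/0/self_attn/q_proj",
    "model/layers/0/self_attn/k_proj",
    "model/layers/0/mlp/gate_proj",
    "model/embed_tokens",
]

# A pattern that targets only the query and key projections.
lora_pattern = r".*(q_proj|k_proj).*"

# Keep every path the pattern matches; these would receive LoRA adapters.
selected = [path for path in layer_paths if re.search(lora_pattern, path)]
print(selected)
```

Narrowing the pattern in this way keeps the number of trainable adapter parameters small, which is usually the point of restricting LoRA to a subset of layers.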

apply_out_shardings(out_shardings)

Applies output sharding specifications to the module state.

Parameters

out_shardings – Sharding specifications to apply.

Returns

Module with sharding constraints applied.

property causal_mask: Array

Retrieves or computes the basic causal attention mask from the configuration.

Uses self.config.get_basic_causal_mask() and caches the result.

Returns

The causal attention mask, potentially cached.

Return type

jnp.ndarray
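For readers unfamiliar with causal masks, here is a minimal pure-Python sketch of the general idea: position q may attend to position k only when k <= q, which gives a lower-triangular boolean matrix. The function name basic_causal_mask is hypothetical, and the shape and dtype returned by self.config.get_basic_causal_mask() are not specified here.

```python
def basic_causal_mask(seq_len: int) -> list[list[bool]]:
    """Return a lower-triangular mask: row q allows keys 0..q."""
    return [[k <= q for k in range(seq_len)] for q in range(seq_len)]

mask = basic_causal_mask(4)
for row in mask:
    # Print each row as 1/0 for readability.
    print("".join("1" if allowed else "0" for allowed in row))
```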

compute_complex_rotary(position_ids: Array) → Array

Computes complex-valued rotary position embeddings.

Parameters

position_ids – Position indices to compute embeddings for.

Returns

Complex exponential of frequencies for rotary embeddings.
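The following sketch shows what "complex exponential of frequencies" usually means for rotary embeddings: each position p and inverse frequency f yields the unit complex number exp(i * p * f). It assumes the common RoPE convention f_j = base^(-2j/d) with base 10000; the helper names and the base value are illustrative assumptions, not taken from EasyDeL.

```python
import cmath

def inv_frequencies(head_dim: int, base: float = 10000.0) -> list[float]:
    """Inverse frequencies base^(-i/head_dim) for even i (standard RoPE convention)."""
    return [1.0 / (base ** (i / head_dim)) for i in range(0, head_dim, 2)]

def complex_rotary(position_ids: list[int], head_dim: int, base: float = 10000.0):
    """Return exp(1j * position * inv_freq) for every position and frequency."""
    inv_freq = inv_frequencies(head_dim, base)
    return [[cmath.exp(1j * pos * f) for f in inv_freq] for pos in position_ids]

rot = complex_rotary([0, 1, 2], head_dim=4)
print(len(rot), len(rot[0]))  # number of positions, number of frequency pairs
```

Every entry has magnitude 1, so multiplying a query or key pair by it rotates the pair without changing its norm.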

compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: easydel.infra.loss_utils.LossConfig | None = None, loss_kwargs: dict | None = None, **batch) → tuple[Any, easydel.infra.loss_utils.LossMetrics]

Computes the loss for the model given a batch of inputs and labels.

This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.

Parameters
  • labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.

  • loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.

  • loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.

  • **batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).

Returns

A tuple containing:
  • The model’s output (typically a pytree including logits, hidden states, etc.)

  • A LossMetrics object containing the calculated loss and potentially other metrics.

Return type

tp.Tuple[tp.Any, LossMetrics]

Raises
  • AssertionError – If labels are required for the loss function but are not provided or inferred.

  • AssertionError – If sequence classification loss is used without num_labels in the config.
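The label-inference behaviour for causal language modelling can be pictured with a small pure-Python sketch. It assumes the conventional next-token setup, where the logits at position t are scored against the token at position t + 1 and labels fall back to input_ids when omitted. The function causal_lm_loss is a hypothetical stand-in and not the ForCausalLMLoss implementation.

```python
import math

def causal_lm_loss(logits, input_ids, labels=None):
    """Mean next-token cross-entropy; labels default to input_ids."""
    if labels is None:
        labels = input_ids  # label inference for causal LM
    # Position t predicts token t + 1: drop the last logit row and the first label.
    shifted_logits = logits[:-1]
    shifted_labels = labels[1:]
    total = 0.0
    for row, target in zip(shifted_logits, shifted_labels):
        log_z = math.log(sum(math.exp(v) for v in row))  # log-partition of the row
        total += log_z - row[target]                     # negative log-likelihood
    return total / len(shifted_labels)

# Three positions, vocabulary of two tokens, uniform logits.
logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
loss = causal_lm_loss(logits, input_ids=[0, 1, 0])
print(loss)
```

With uniform logits over two tokens, every prediction has probability 1/2, so the loss equals ln 2.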

config: BaseConf | None = None

float(change_runtime_dtype: bool = True) → Self

Converts the module’s parameters to single-precision (float32).

Optionally also changes the runtime computation dtype (self.dtype) to float32.

Parameters

change_runtime_dtype (bool) – If True, also sets self.dtype to jnp.float32. Defaults to True.

Returns

The module instance with parameters (and potentially runtime dtype) set to float32.

Return type

Self

flops_per_token(sequence_length: int | None = None, include_loss: bool = True, include_backward: bool = False) → float

Calculates the total FLOPs (Floating Point Operations) for the module per token.

This method should be implemented by subclasses to provide a module-specific FLOPs calculation.

Returns

The total FLOPs for the module.

Return type

float

Raises

NotImplementedError – If the method is not implemented by the subclass.

property frequencies: Array

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray

fully_gather() → Self

Applies JAX sharding constraints to gather all parameters onto the host or a single device.

This function marks all parameters to have no sharding (PartitionSpec()). It uses ejit with out_shardings to enforce these gathering constraints.

Returns

The model instance with gathering constraints applied.

Return type

Self

fully_shard(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None) → Self

Applies JAX sharding constraints to all parameters based on the partition rules.

This function ensures that parameters are explicitly marked with their intended sharding, which can be useful for performance and correctness checks. It uses ejit with out_shardings to enforce the constraints.

Parameters

partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses config rules. Defaults to None.

Returns

The model instance with sharding constraints applied.

Return type

Self

gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: jax._src.mesh.Mesh | None = None, overlay_fns: Optional[Mapping[str, Callable]] = None) → Self

Gathers the model’s parameters from potentially distributed devices to the host or a single device.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules used to determine how parameters were originally sharded. If None, uses config rules. Defaults to None.

  • mesh (tp.Optional[Mesh], optional) – JAX device mesh from which to gather. If None, uses config mesh. Defaults to None.

  • overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Additional functions to apply, potentially overriding default gathering for specific parameters. Defaults to None.

Returns

The model instance with gathered parameters.

Return type

Self

get_decoder() → flax.nnx.module.Module | easydel.infra.base_module.EasyDeLBaseModule

Return the decoder component of the model.

This method should be overridden by encoder-decoder models to return their decoder component. Useful for tasks that need access to the decoder separately from the encoder.

Returns

The decoder module.

Return type

nn.Module | EasyDeLBaseModule

Raises

NotImplementedError – If the model does not implement a decoder.

get_embedding() → flax.nnx.module.Module | flax.nnx.nn.linear.Embed

Return the input embedding layer of the model.

This method should be overridden by models to return their token embedding layer. Useful for weight tying or accessing embeddings directly.

Returns

The embedding layer.

Return type

nn.Module | nn.Embed

Raises

NotImplementedError – If the model does not have an embedding layer.

get_encoder() → flax.nnx.module.Module | easydel.infra.base_module.EasyDeLBaseModule

Return the encoder component of the model.

This method should be overridden by encoder-decoder models to return their encoder component. Useful for tasks that only need the encoder, such as feature extraction or embedding generation.

Returns

The encoder module.

Return type

nn.Module | EasyDeLBaseModule

Raises

NotImplementedError – If the model does not implement an encoder.

get_lm_head() → ParallelLinear

Return the language model head of the model.

This method should be overridden by language models to return their output projection layer that maps hidden states to vocabulary logits.

Returns

The language model head layer.

Return type

ParallelLinear

Raises

NotImplementedError – If the model does not have a language model head.

get_static_arguments() → tuple

Returns a tuple of static arguments required by the module’s __call__ method.

Static arguments are those that don’t change across calls and can be potentially cached or handled differently by JIT compilation. This base implementation returns an empty tuple. Subclasses should override this if they have static arguments.

Returns

A tuple containing static arguments.

Return type

tp.Tuple

property graphdef: GraphDef

Returns the graph definition (structure without parameters) of the module.

Uses flax.nnx.split to separate the graph definition from the state (parameters).

Returns

The graph definition of the module.

Return type

nn.GraphDef

property graphother: State[Key, Union[Variable, Array, ndarray, Ref, ArrayRefOutput, NoUpdate]]

Returns any other state variables in the module (non-parameters).

Uses flax.nnx.split to separate non-parameter state variables.

Returns

The graph state containing non-parameter variables.

Return type

nn.GraphState

property graphstate: State[Key, Union[Variable, Array, ndarray, Ref, ArrayRefOutput, NoUpdate]]

Returns the graph state (parameters) of the module.

Uses flax.nnx.split to separate the state (parameters) from the graph definition.

Returns

The graph state containing the module’s parameters.

Return type

nn.GraphState

property graphstate_type

Determines the parameter type based on whether LoRA is enabled.

Returns

nn.LoRAParam if LoRA is enabled, otherwise nn.Param.

property graphtree_params_shape: dict

Computes and returns the shapes of the module’s parameters as a nested dictionary.

It uses nnx.eval_shape to determine the shapes without actual computation, then extracts the shape information from the resulting graph state.

Returns

A nested dictionary mirroring the parameter structure, containing their shapes.

Return type

tp.Dict

property graphtree_shape: dict

Computes and returns the shapes of all state variables (including non-parameters) in the module.

Uses nnx.eval_shape on the entire module state (parameters and others) and extracts the shape information.

Returns

A nested dictionary mirroring the module’s state structure, containing the shapes.

Return type

tp.Dict

half(change_runtime_dtype: bool = True) → Self

Converts the module’s parameters to half-precision (float16).

Optionally also changes the runtime computation dtype (self.dtype) to float16.

Parameters

change_runtime_dtype (bool) – If True, also sets self.dtype to jnp.float16. Defaults to True.

Returns

The module instance with parameters (and potentially runtime dtype) set to float16.

Return type

Self

property inv_frequencies: Array

Retrieves or computes the inverse-frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_inv_frequencies() and caches the result.

Returns

The inverse-frequency components, potentially cached.

Return type

jnp.ndarray

property is_quantized: bool

Check if the model contains any quantized layers or parameters.

Iterates through the model graph to detect quantized components, including 8-bit linear layers, NF4 linear layers, and quantized arrays.

Returns

True if the model contains any quantized components, False otherwise.

Return type

bool

classmethod lazy_init(**kwargs) → Self

Performs a “lazy” initialization using nnx.eval_shape.

This initializes the module structure and determines parameter shapes without actually allocating memory for the parameters. Useful for inspecting the model structure or preparing for sharding.

Parameters

**kwargs – Keyword arguments passed to the class constructor.

Returns

A module instance with initialized structure but potentially abstract parameters.

Return type

Self

property lora_is_enabled

Checks if LoRA (Low-Rank Adaptation) is enabled for this module.

Returns

True if any LoRA parameters are found in the module, False otherwise.

property loss_function

Determines and returns the appropriate loss function based on the configuration or model type.

It prioritizes config.loss_type, then self.loss_type, and finally tries to infer the loss type from the class name. If no suitable loss function is found, it defaults to ForCausalLMLoss and issues a warning.

Returns

The selected loss function (e.g., ForCausalLMLoss, ForSequenceClassificationLoss).

Return type

tp.Callable

property lossfn_type

Determines the loss function type for this model.

Attempts to determine the loss type from (in order):
  1. config.loss_type attribute

  2. self.loss_type attribute

  3. Class name matching

  4. Defaults to ForCausalLM if not found

Returns

String identifier for the loss function type.
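The resolution order described for lossfn_type can be sketched as a plain function. Everything here is illustrative: resolve_loss_type, the KNOWN_LOSS_TYPES tuple, and the substring-based class-name match are assumptions standing in for whatever EasyDeL does internally.

```python
from types import SimpleNamespace

# Hypothetical subset of known loss-type names, for illustration only.
KNOWN_LOSS_TYPES = ("ForSequenceClassification", "ForCausalLM")

def resolve_loss_type(model) -> str:
    # 1. An explicit value on the config wins.
    config_type = getattr(getattr(model, "config", None), "loss_type", None)
    if config_type is not None:
        return config_type
    # 2. Next, an attribute set on the model itself.
    model_type = getattr(model, "loss_type", None)
    if model_type is not None:
        return model_type
    # 3. Then a match against the class name.
    class_name = type(model).__name__
    for name in KNOWN_LOSS_TYPES:
        if name in class_name:
            return name
    # 4. Finally, the default.
    return "ForCausalLM"

class LlamaForSequenceClassification:
    config = SimpleNamespace(loss_type=None)

class PlainBackbone:
    config = SimpleNamespace(loss_type=None)

print(resolve_loss_type(LlamaForSequenceClassification()))
print(resolve_loss_type(PlainBackbone()))
```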

merge_lora_params(pytree: dict) → Self

Merges LoRA parameters from a pytree into the base model’s parameters.

Parameters

pytree (tp.Dict) – A dictionary (pytree) containing the LoRA parameters (A and B matrices) structured similarly to the base model’s parameters.

Returns

The module instance with LoRA parameters merged into the base weights.

Return type

Self
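Merging LoRA weights amounts to folding the low-rank product into the base weight. The sketch below uses plain nested lists and assumes a kernel of shape (in, out) with factors A of shape (in, rank) and B of shape (rank, out), with no extra scaling factor; these layout choices are assumptions for illustration and may not match EasyDeL's internal representation.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(a, b):
    """Element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

kernel = [[1.0, 0.0], [0.0, 1.0]]  # base weight, shape (in=2, out=2)
lora_a = [[1.0], [2.0]]            # shape (in=2, rank=1)
lora_b = [[0.5, -0.5]]             # shape (rank=1, out=2)

# Merged weight: W' = W + A @ B
merged = add(kernel, matmul(lora_a, lora_b))

# The merged layer reproduces the base output plus the adapter output.
x = [[3.0, 4.0]]
separate = add(matmul(x, kernel), matmul(matmul(x, lora_a), lora_b))
fused = matmul(x, merged)
print(merged, fused, separate)
```

After merging, inference needs a single matrix multiplication per layer instead of a base path plus an adapter path.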

static merge_module(graphdef: GraphDef, graphstate: State[Key, Union[Variable, Array, ndarray, Ref, ArrayRefOutput, NoUpdate]], graphother: State[Key, Union[Variable, Array, ndarray, Ref, ArrayRefOutput, NoUpdate]])

Merges graph components back into a complete module.

Parameters
  • graphdef – The module’s graph definition (structure).

  • graphstate – The module’s parameter state.

  • graphother – The module’s non-parameter state.

Returns

The reconstructed module.

merge_params(tree)

Merges a given parameter state tree back into the module.

Reconstructs the module using its existing graph definition and ‘other’ state, but replaces the parameter state with the provided tree.

Parameters

tree – A pytree (likely a nn.GraphState) containing the parameters to merge.

Returns

The module instance with the new parameters merged in.

Return type

EasyDeLBaseModule

merge_params_dict(params_dict: dict) → Self

Merges parameters from a dictionary back into the module’s state.

Updates the module’s current parameter state with values from the provided dictionary.

Parameters

params_dict (tp.Dict) – A nested dictionary containing the parameters to merge. The structure should match the module’s parameter structure.

Returns

The module instance with the parameters from the dictionary merged in.

Return type

Self

Raises

KeyError – If a key from params_dict is not found in the module’s current state.
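A minimal sketch of the merge semantics described above, using plain nested dictionaries: values from params_dict overwrite matching entries, and an unknown key raises KeyError. The function merge_into_state is a hypothetical stand-in; the real method operates on the module's nnx state rather than on raw dicts.

```python
def merge_into_state(state: dict, params_dict: dict) -> dict:
    """Overwrite entries of `state` with values from `params_dict`, in place."""
    for key, value in params_dict.items():
        if key not in state:
            raise KeyError(f"{key!r} not found in module state")
        if isinstance(value, dict) and isinstance(state[key], dict):
            merge_into_state(state[key], value)  # descend into nested parameters
        else:
            state[key] = value                   # replace the leaf value
    return state

state = {"layer": {"kernel": 1, "bias": 2}}
merge_into_state(state, {"layer": {"kernel": 10}})
print(state)
```

Keys that are absent from params_dict keep their current values, so a partial dictionary updates only the parameters it names.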

property mesh: Mesh

Retrieves the JAX device mesh from the module’s configuration.

Returns

The device mesh defined in self.config.mesh.

Return type

jax.sharding.Mesh

property model_task: str | None

Returns the specific task associated with this model instance (e.g., ‘causal-language-model’).

Returns

The model task identifier, or None if not set.

Return type

tp.Optional[str]

property model_type: str | None

Returns the specific type of this model instance (e.g., ‘llama’, ‘mistral’).

Returns

The model type identifier, or None if not set.

Return type

tp.Optional[str]

property module_dtype: dtype

Determines the data type of the module’s parameters.

It inspects the flattened parameter state to find the dtype of the first parameter encountered.

Returns

The data type of the module’s parameters.

Return type

jnp.dtype

new_graphdef(**kwargs: Unpack[EasyDeLBaseConfigDict])

Create a new module with updated configuration.

Creates a new lazy module with updated configuration while preserving the current parameter state. This is useful for modifying model behavior without reinitializing weights.

Parameters

**kwargs – Configuration parameters to update. These will be applied to a copy of the current configuration.

Returns

A new module instance with updated configuration and the same parameter values.

Return type

EasyDeLBaseModule

property parameters: dict

Retrieves the parameters of the module as a dictionary.

This property iterates through the module and its submodules, extracting variables marked as nn.Param and returning them in a flat dictionary where keys represent the parameter path.

Returns

A dictionary containing the module’s parameters.

Return type

tp.Dict

property params: dict

Returns the parameters and other state variables of the module as a dictionary.

Uses flax.nnx.split to get the combined state (parameters and others).

Returns

A dictionary containing all state variables of the module.

Return type

tp.Dict

property params_sharding: dict

Retrieves the sharding annotation for each parameter in the module.

Returns

A nested dictionary mirroring the parameter structure, containing the sharding information (e.g., NamedSharding, PartitionSpec) for each parameter, or None if unsharded.

Return type

tp.Dict

prepare_inputs_for_call(**kwargs)

Prepares keyword arguments before passing them to the module’s __call__ method.

This base implementation simply returns the kwargs as is. Subclasses can override this to modify or add arguments as needed (e.g., for generation).

Parameters

**kwargs – The keyword arguments intended for __call__.

Returns

The prepared keyword arguments.

Return type

dict

property pure_transform_fn

Returns a pure transformation function for PyTorch state dicts to EasyDeL parameters.

Similar to transform_fn, but this version does not include sharding functions. It identifies embedding and LayerNorm layers and returns a partial function (torch_dict_to_easydel_params) configured only with layer names and dtype.

Returns

A partial function for converting a PyTorch state dict without applying sharding.

Return type

tp.Callable

quantize(quantization_config: EasyDeLQuantizationConfig | None = None, quantize_tensors: bool = True, verbose: bool | None = None) → Self

Applies quantization to the module’s linear layers or tensors.

Parameters
  • quantization_config – The quantization configuration specifying dtype, block_size, and pattern. If None, uses default INT8 quantization.

  • quantize_tensors – If True, quantizes the tensor values directly. If False, replaces Linear layers with their quantized equivalents.

  • verbose – If True, logs information during the quantization process. If None, defaults to True only on process index 0.

Returns

The quantized model instance.

Return type

Self

Example

>>> from easydel.layers.quantization import EasyDeLQuantizationConfig, QuantizationType
>>> config = EasyDeLQuantizationConfig(dtype=QuantizationType.NF4, block_size=64)
>>> quantized_model = model.quantize(quantization_config=config)
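As background on what block-wise 8-bit quantization involves, here is a generic absmax sketch in pure Python: each block of block_size values shares one scale, and values are rounded to integers in [-127, 127]. This is a textbook illustration under stated assumptions and not EasyDeL's quantization kernel; the function names are hypothetical.

```python
def quantize_int8(values, block_size):
    """Quantize a flat list of floats block by block using absmax scaling."""
    blocks = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        # One scale per block; fall back to 1.0 for an all-zero block.
        scale = max(abs(v) for v in block) / 127.0 or 1.0
        blocks.append((scale, [round(v / scale) for v in block]))
    return blocks

def dequantize_int8(blocks):
    """Recover approximate floats from (scale, integers) blocks."""
    return [q * scale for scale, qs in blocks for q in qs]

values = [0.5, -1.0, 0.25, 2.0]
blocks = quantize_int8(values, block_size=2)
restored = dequantize_int8(blocks)
print(restored)
```

Smaller blocks give each group of values a tighter scale, which lowers rounding error at the cost of storing more scales.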
classmethod sequential_init(**kwargs) → Self

Initialize model parameters sequentially with proper sharding.

This method performs lazy initialization followed by sequential parameter initialization with appropriate sharding for distributed training. It’s particularly useful for large models that need memory-efficient initialization.

The method:
  1. Creates a lazy (shape-only) version of the model

  2. Iterates through all modules and initializes their parameters

  3. Applies proper sharding based on partition rules

Parameters

**kwargs – Arguments passed to lazy_init, including:
  • config: Model configuration

  • dtype: Computation dtype

  • param_dtype: Parameter dtype

  • precision: JAX precision setting

  • rngs: Random number generators (required)

Returns

Fully initialized model with properly sharded parameters.

Return type

Self

Example

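>>> import jax.numpy as jnp
>>> import flax.nnx as nn
>>> # LlamaConfig and LlamaModel are assumed to be imported from EasyDeL already
>>> config = LlamaConfig(hidden_size=1024, num_hidden_layers=4)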
>>> model = LlamaModel.sequential_init(
...     config=config,
...     dtype=jnp.float32,
...     param_dtype=jnp.float32,
...     rngs=nn.Rngs(0)
... )
shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: jax._src.mesh.Mesh | None = None, overlay_fns: Optional[Mapping[str, Callable]] = None) → Self

Shards the model’s parameters according to the specified rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses config rules. Defaults to None.

  • mesh (tp.Optional[Mesh], optional) – JAX device mesh. If None, uses config mesh. Defaults to None.

  • overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Additional functions to apply, potentially overriding default sharding for specific parameters. Defaults to None.

Returns

The sharded model instance.

Return type

Self

split_lora_params() → dict

Splits merged LoRA parameters back out from the base model’s parameters.

This function assumes LoRA parameters were previously merged using merge_lora_params or a similar process that stored the original base weights and LoRA weights appropriately.

Returns

A pytree containing the extracted LoRA parameters (A and B matrices). The base model parameters are restored to their original (pre-merge) state.

Return type

tp.Dict

split_module()

Splits the module into graph definition and state components.

Returns

Tuple of (GraphDef, GraphState, GraphState) representing structure, parameters, and other state.

split_params()

Splits the module and returns the parameter state.

Uses nnx.split to extract the GraphState containing the parameters.

Returns

The parameter state of the module.

Return type

nn.GraphState

split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) → dict

Splits the module parameters and returns them as a nested dictionary.

Extracts the parameter state, converts it to a plain dictionary (removing VariableState wrappers), and optionally removes entries with None values.

Parameters
  • extract_fn (tp.Optional[tp.Callable], optional) – A function to apply to each parameter during extraction. Defaults to None.

  • remove_none (bool, optional) – If True, removes key-value pairs where the value is None. Defaults to True.

Returns

A nested dictionary containing the module’s parameters.

Return type

tp.Dict
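The effect of extract_fn and remove_none can be sketched on an ordinary nested dictionary. The helper to_plain_dict is hypothetical, and plain integers stand in for wrapped parameter values.

```python
def to_plain_dict(tree: dict, extract_fn=None, remove_none: bool = True) -> dict:
    """Copy a nested dict, optionally transforming leaves and dropping None entries."""
    out = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            out[key] = to_plain_dict(value, extract_fn, remove_none)
        elif value is None:
            if not remove_none:
                out[key] = None  # keep the placeholder only when asked to
        else:
            out[key] = extract_fn(value) if extract_fn else value
    return out

tree = {"layer": {"kernel": 2, "bias": None}, "scale": 3}
print(to_plain_dict(tree, extract_fn=lambda v: v * 10))
print(to_plain_dict(tree, remove_none=False))
```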

property static_arguments: tuple

Retrieves or computes static arguments needed for the module’s __call__ method.

Uses self.get_static_arguments() and caches the result. Static arguments are typically those that don’t change during execution and can be pre-computed.

Returns

A tuple of static arguments.

Return type

tp.Tuple

static_hash(pop_things: list[str] | None = None)

Computes a deterministic hash of the module’s state and configuration.

This method creates a hash based on the module’s parameters (graphstate), non-parameter state (graphother), and configuration dictionary. It’s useful for caching compiled functions or identifying when model states differ.

Parameters

pop_things – Optional list of configuration keys to exclude from the hash. Useful when you want to ignore certain config fields (e.g., ‘attn_mechanism’) that shouldn’t affect the cache key.

Returns

A signed integer hash value computed from the model’s state and configuration.

Example

>>> # Hash without excluding any config keys
>>> hash1 = model.static_hash()
>>>
>>> # Hash excluding attention mechanism from consideration
>>> hash2 = model.static_hash(["attn_mechanism"])
>>>
>>> # These may be equal if only attn_mechanism differs
>>> hash1 == hash2  # True if configs match except attn_mechanism

Note

The hash is computed using MD5 on the serialized state signature and configuration dictionary, ensuring deterministic results for identical states.
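The role of pop_things can be illustrated with a small stand-alone sketch that hashes a state signature together with a configuration dictionary using MD5. The function sketch_static_hash, the JSON serialization, and the truncation to a signed 64-bit integer are all assumptions for illustration; the actual serialization used by static_hash is not documented here.

```python
import hashlib
import json

def sketch_static_hash(state_signature: str, config: dict, pop_things=None) -> int:
    """Hash a state signature plus config, ignoring any keys in `pop_things`."""
    excluded = set(pop_things or [])
    kept = {k: v for k, v in config.items() if k not in excluded}
    # sort_keys makes the serialization, and therefore the hash, deterministic.
    payload = state_signature + json.dumps(kept, sort_keys=True)
    digest = hashlib.md5(payload.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big", signed=True)

cfg_a = {"hidden_size": 8, "attn_mechanism": "vanilla"}
cfg_b = {"hidden_size": 8, "attn_mechanism": "flash"}

print(sketch_static_hash("sig", cfg_a) == sketch_static_hash("sig", cfg_b))
print(
    sketch_static_hash("sig", cfg_a, ["attn_mechanism"])
    == sketch_static_hash("sig", cfg_b, ["attn_mechanism"])
)
```

Excluding a key lets two models that differ only in that setting share a cache entry.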

to_dtype(dtype: dtype) → Self

Converts the module’s parameters to the specified data type.

It iterates through the module’s parameters (excluding quantization-related ones) and casts them to the target dtype. It also updates the param_dtype attribute of the module and its submodules if they exist.

Parameters

dtype (jnp.dtype) – The target data type for the parameters.

Returns

The module instance with parameters converted to the specified dtype.

Return type

Self

to_state(state_class: type[EasyDeLState] | None = None) → EasyDeLState

Convert the current module instance into an EasyDeLState object.

This is useful for saving and managing the model’s state, including parameters and potentially optimizer state (though optimizer state is typically added later).

Parameters

state_class – Optional custom state class to use. If None, defaults to EasyDeLState. Must be a subclass of EasyDeLState.

Returns

An EasyDeLState object representing the current model state, with step initialized to 0.

Return type

EasyDeLState

to_torch(**kwargs)

Converts the EasyDeL module to its equivalent Hugging Face PyTorch model.

Requires the corresponding PyTorch model class to be available and registered. Uses utility functions to transfer parameters from JAX to PyTorch format.

Parameters

**kwargs – Additional keyword arguments passed to the parameter transformation function.

Returns

The equivalent Hugging Face PyTorch model with loaded weights.

Return type

torch.nn.Module

property transform_fn

Creates a transformation function for converting HuggingFace state dicts to EasyDeL format.

Identifies special layers (embeddings, LayerNorm, MoE) and returns a configured transformation function with sharding rules applied.

Returns

Partial function for state dict conversion with layer information and sharding.

unwrap_lora_to_layers(verbose: bool = False) → Self

Reverts the application of LoRA layers, restoring the original linear layers.

Replaces easydel.layers.lora.LoraLinear layers with their original flax.nnx.Linear counterparts, discarding the LoRA matrices.

Parameters

verbose (bool, optional) – If True, prints information about which layers are being reverted. Defaults to False.

Returns

The module instance with LoRA layers removed and original layers restored.

Return type

Self

update_module(**kwargs: Unpack[EasyDeLBaseConfigDict])

Updates the module configuration and reinitializes the structure.

Creates a new lazy module with updated configuration while preserving the current parameter state.

Parameters

**kwargs – Configuration parameters to update.

Returns

Updated module with new configuration.

class easydel.infra.base_module.ParameterTransformRule(pattern: str | re.Pattern, replacement: str, tensor_transform: collections.abc.Callable | None = None, consolidate_experts: bool = False)

Bases: object

Rule for transforming MoE parameter names and tensors.

This dataclass defines transformation rules that can be applied to parameter names and their associated tensor values during model conversion or loading. It’s particularly useful for handling Mixture of Experts (MoE) models where parameter naming conventions may differ between frameworks.

pattern

Regular expression pattern or string to match parameter names. Can be a compiled Pattern object or a string that will be used for matching.

Type

str | re.Pattern

replacement

String to replace matched patterns in parameter names. Supports regex replacement syntax (e.g., r'\1' for capture groups).

Type

str

tensor_transform

Optional callable to transform the tensor values. Should take a tensor and return a transformed tensor. If None, no transformation is applied to the tensor values.

Type

collections.abc.Callable | None

consolidate_experts

Whether to consolidate multiple expert parameters into a single tensor. Useful for MoE models where experts may be stored separately but need to be combined.

Type

bool

Example

>>> # Rule to rename expert layers
>>> rule = ParameterTransformRule(
...     pattern=r"expert_(\d+)\.weight",
...     replacement=r"experts.\1.weight",
...     tensor_transform=lambda x: x.transpose(),
...     consolidate_experts=True
... )
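To show how a rule of this shape could be applied, the sketch below defines a small stand-in dataclass and a helper that renames a parameter with re.subn and then transforms its tensor. RenameRule and apply_rule are hypothetical names, the tensor is a nested list, and the expert-consolidation step is omitted; how EasyDeL itself applies ParameterTransformRule is not shown in this reference.

```python
import re
from dataclasses import dataclass
from typing import Callable, Optional, Union

@dataclass
class RenameRule:
    """Hypothetical stand-in mirroring the pattern/replacement/tensor_transform fields."""
    pattern: Union[str, re.Pattern]
    replacement: str
    tensor_transform: Optional[Callable] = None

def apply_rule(rule: RenameRule, name: str, tensor):
    """Rename `name` if the pattern matches, then transform the tensor."""
    new_name, count = re.subn(rule.pattern, rule.replacement, name)
    if count == 0:
        return name, tensor  # no match: leave both untouched
    if rule.tensor_transform is not None:
        tensor = rule.tensor_transform(tensor)
    return new_name, tensor

rule = RenameRule(
    pattern=r"expert_(\d+)\.weight",
    replacement=r"experts.\1.weight",  # \1 reinserts the captured expert index
    tensor_transform=lambda rows: [list(col) for col in zip(*rows)],  # transpose
)

name, tensor = apply_rule(rule, "mlp.expert_3.weight", [[1, 2], [3, 4]])
print(name, tensor)
```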

consolidate_experts: bool = False
pattern: str | re.Pattern
replacement: str
tensor_transform: collections.abc.Callable | None = None