easydel.infra.mixins.protocol
Protocol definitions for EasyDeL base modules.
This module defines the protocol (interface) that all EasyDeL models must implement. It provides the BaseModuleProtocol abstract base class which specifies the required methods and properties for model implementations, along with utility functions for module representation and formatting.
The protocol ensures consistency across different model implementations and provides type hints for better IDE support and type checking. It combines interfaces from the base module, bridge functionality (EasyBridgeMixin), and generation capabilities (EasyGenerationMixin).
- Classes:
BaseModuleProtocol: Abstract base class defining the interface for EasyDeL modules
- Functions:
return_type_adjuster: Decorator to adjust return types for type checking
get_module_repr: Get string representation of module parameters
prettify_nnx: Format module structure for display
printify_nnx: Create printable representation of NNX modules
- Type Aliases:
PartitionLike: Type for partition rule specifications
Self: Type variable for self-referencing types
The protocol includes methods for:
- Model forward passes and loss computation (overloaded for different model types)
- Parameter management (sharding, gathering, quantization, LoRA)
- Model I/O (saving, loading, HuggingFace Hub integration)
- Text generation (greedy search, sampling, beam search)
- Cache management (standard and paged attention)
- Framework conversion (PyTorch ↔ JAX/Flax)
Supported model types:
- Causal Language Models
- Sequence Classification
- Mixture of Experts (MoE)
- Vision models (CLIP)
- Multi-modal models
Example
>>> from easydel.infra.mixins.protocol import BaseModuleProtocol
>>>
>>> class MyModel(BaseModuleProtocol):
...     # Implement the required methods
...     def __call__(self, input_ids, **kwargs):
...         # Forward pass implementation
...         pass
...
...     def compute_loss(self, input_ids, labels, **kwargs):
...         # Loss computation
...         pass
...
...     def generate(self, input_ids, **kwargs):
...         # Generation implementation
...         pass
- class easydel.infra.mixins.protocol.BaseModuleProtocol
Bases: object
Protocol defining the common interface for EasyDeL modules.
- abstract apply_lora_to_layers(lora_rank: int, lora_pattern: str | None = None, verbose: bool = False, rngs: flax.nnx.rnglib.Rngs | None = None) → Self
Apply LoRA (Low-Rank Adaptation) to specified linear layers within a model.
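The idea behind LoRA can be sketched in pure Python: the frozen weight W is left untouched, and a trainable low-rank product A·B, scaled by alpha/rank, is added to its output. The helpers `matvec` and `lora_linear` are hypothetical, and the alpha/rank scaling with B initialised to zero follows the original LoRA formulation; this is not EasyDeL's implementation.

```python
# Illustrative pure-Python LoRA linear layer (not EasyDeL's implementation).
# Shapes: x is [in], W is [in][out], A is [in][rank], B is [rank][out].


def matvec(x, m):
    """Row vector x times matrix m (given as a list of rows)."""
    return [sum(x[i] * m[i][j] for i in range(len(x))) for j in range(len(m[0]))]


def lora_linear(x, W, A, B, alpha):
    """Frozen projection x @ W plus the scaled low-rank update (x @ A) @ B."""
    rank = len(B)
    scale = alpha / rank
    base = matvec(x, W)              # output of the frozen base weight
    delta = matvec(matvec(x, A), B)  # low-rank correction
    return [b + scale * d for b, d in zip(base, delta)]


W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # in=3, out=2
A = [[0.5], [0.0], [1.0]]                 # in=3, rank=1
B_zero = [[0.0, 0.0]]                     # B starts at zero, so training begins from W
B = [[1.0, 2.0]]                          # B after some hypothetical training
x = [1.0, 2.0, 3.0]

print(lora_linear(x, W, A, B_zero, alpha=1.0))  # [4.0, 5.0], identical to x @ W
print(lora_linear(x, W, A, B, alpha=1.0))       # [7.5, 12.0]
```

Only A and B (rank × (in + out) values) would be trained; W stays frozen.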
- base_model_prefix: str
- abstract classmethod can_generate() → bool
Checks if the model can generate sequences with .generate().
- Returns
True if the model can generate, False otherwise.
- compute_loss(input_ids: chex.Array | None = None, labels: chex.Array | None = None, inputs_embeds: chex.Array | None = None, attention_mask: chex.Array | None = None, mask_info: MaskInfo | None = None, position_ids: chex.Array | None = None, segment_ids: chex.Array | None = None, past_key_values: TransformerCache | RaggedPagesCache | None = None, cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, loss_config: LossConfig | None = None, loss_kwargs: dict | None = None) → tuple[CausalLMOutput, LossMetrics]
- compute_loss(input_ids: chex.Array | None = None, labels: chex.Array | None = None, inputs_embeds: chex.Array | None = None, attention_mask: chex.Array | None = None, mask_info: MaskInfo | None = None, position_ids: chex.Array | None = None, segment_ids: chex.Array | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, loss_config: LossConfig | None = None, loss_kwargs: dict | None = None) → tuple[SequenceClassifierOutput, LossMetrics]
- compute_loss(input_ids: chex.Array | None = None, labels: chex.Array | None = None, inputs_embeds: chex.Array | None = None, attention_mask: chex.Array | None = None, mask_info: MaskInfo | None = None, position_ids: chex.Array | None = None, segment_ids: chex.Array | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, output_router_logits: bool | None = None, past_key_values: TransformerCache | RaggedPagesCache | None = None, cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None, loss_config: LossConfig | None = None, loss_kwargs: dict | None = None) → tuple[MoeModelOutput, LossMetrics]
- compute_loss(input_ids: chex.Array | None = None, labels: chex.Array | None = None, inputs_embeds: chex.Array | None = None, attention_mask: chex.Array | None = None, mask_info: MaskInfo | None = None, position_ids: chex.Array | None = None, segment_ids: chex.Array | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, output_router_logits: bool | None = None, past_key_values: TransformerCache | RaggedPagesCache | None = None, cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None, loss_config: LossConfig | None = None, loss_kwargs: dict | None = None) → tuple[MoeCausalLMOutput, LossMetrics]
- compute_loss(*, labels: chex.Array | None = None, loss_config: LossConfig | None = None, loss_kwargs: dict | None = None, **batch) → tuple[tp.Any, LossMetrics]
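Runs the forward pass and computes the loss. The method is overloaded per model type: causal language models (CausalLMOutput), sequence classification (SequenceClassifierOutput), MoE models (MoeModelOutput), MoE causal language models (MoeCausalLMOutput), and a generic keyword-only form that accepts arbitrary batch inputs. Every variant returns a tuple of the model output and a LossMetrics object.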
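For the causal-LM overload, the loss is typically next-token cross-entropy. A minimal pure-Python sketch, assuming the labels are shifted by one position inside the loss and that -100 marks ignored positions (a common Hugging Face convention; EasyDeL's LossConfig may behave differently). `causal_lm_loss` is a hypothetical helper:

```python
import math

IGNORE_INDEX = -100  # assumed label value for positions excluded from the loss


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy for one sequence.

    logits: [seq_len][vocab]; labels: [seq_len]. Position t predicts labels[t + 1].
    """
    total, count = 0.0, 0
    for t in range(len(labels) - 1):
        target = labels[t + 1]
        if target == IGNORE_INDEX:
            continue
        total -= log_softmax(logits[t])[target]
        count += 1
    return total / max(count, 1)


# Uniform logits over a vocabulary of 4 give a loss of log(4) per predicted token.
logits = [[0.0, 0.0, 0.0, 0.0]] * 3
print(causal_lm_loss(logits, [1, 2, 3]))             # ≈ 1.3863 (= log 4)
print(causal_lm_loss(logits, [1, IGNORE_INDEX, 3]))  # one target ignored, still log 4
```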
- config: EasyDeLBaseConfig
- config_class: type[easydel.infra.base_config.EasyDeLBaseConfig]
- abstract create_cache_metadata(batch_size: int, max_length: int, pad_token_id: int | None = None)
Creates the metadata required for initializing a standard (non-paged) KV Cache.
- Parameters
batch_size – The batch size for which the cache is being configured.
max_length – The maximum sequence length the cache needs to support.
pad_token_id – The ID of the padding token. If None, it’s inferred.
- Returns
An initialized metadata object for a standard KV cache.
- abstract create_paged_metadata(hbm_utilization: float, page_size: int, max_model_length: int)
Creates the static configuration metadata required for initializing a Paged KV Cache.
- Parameters
hbm_utilization – Target HBM memory utilization fraction.
page_size – Number of tokens per page in the paged cache.
max_model_length – Maximum sequence length the model will handle.
- Returns
An initialized metadata object containing the static configuration for the paged cache.
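The three parameters interact through simple arithmetic: the HBM budget fixes how many pages fit, and page_size fixes how many pages one full-length sequence needs. A hypothetical sketch, assuming the whole budget hbm_bytes × hbm_utilization goes to K/V pages and one page stores K and V for page_size tokens in every layer; the real method may subtract model weights or query device memory itself, and the model dimensions below are made up:

```python
import math


def paged_cache_budget(hbm_bytes, hbm_utilization, page_size, max_model_length,
                       num_layers, num_kv_heads, head_dim, dtype_bytes):
    """Illustrative page-count arithmetic for a paged KV cache (hypothetical helper)."""
    # One page holds K and V (factor 2) for `page_size` tokens across all layers.
    bytes_per_page = 2 * num_layers * page_size * num_kv_heads * head_dim * dtype_bytes
    num_pages = int(hbm_bytes * hbm_utilization) // bytes_per_page
    pages_per_sequence = math.ceil(max_model_length / page_size)
    return {
        "bytes_per_page": bytes_per_page,
        "num_pages": num_pages,
        "pages_per_sequence": pages_per_sequence,
        "max_full_length_sequences": num_pages // pages_per_sequence,
    }


budget = paged_cache_budget(hbm_bytes=16 * 1024**3, hbm_utilization=0.5, page_size=128,
                            max_model_length=4096, num_layers=32, num_kv_heads=8,
                            head_dim=128, dtype_bytes=2)
print(budget)
# {'bytes_per_page': 16777216, 'num_pages': 512, 'pages_per_sequence': 32,
#  'max_full_length_sequences': 16}
```

Because pages are handed out on demand, many shorter sequences can share the same 512 pages; 16 is only the worst case where every sequence reaches max_model_length.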
- abstract classmethod from_pretrained(pretrained_model_name_or_path: str | os.PathLike | None, **kwargs)
Loads an EasyDeL model from a pretrained model or path.
- Parameters
pretrained_model_name_or_path – The name or path of the pretrained model.
**kwargs – Additional keyword arguments for loading configuration.
- Returns
The loaded EasyDeL model.
- abstract gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: jax._src.mesh.Mesh | None = None)
Gathers the model’s parameters based on the specified partitioning rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules for gathering.
mesh (jax.sharding.Mesh, optional) – The mesh to gather from.
- Returns
The gathered model.
- Return type
nn.Module
- abstract generate(input_ids: Union[Array, ndarray, bool, number], generation_config: Any | None = None, prng_key: Optional[Union[Array, ndarray, bool, number]] = None, trace: bool = True, logits_processor: Any | None = None, **kwargs)
Generates sequences of token ids for models with a language modeling head.
- Parameters
input_ids – The sequence used as a prompt for the generation.
generation_config – The generation configuration to be used as base parametrization.
prng_key – Random key for sampling-based generation.
trace – Whether to trace generation for better performance.
logits_processor – Custom logits processors.
**kwargs – Additional generation parameters.
- Returns
Generated sequences and optionally scores.
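Conceptually, greedy search (one of the strategies listed for this protocol, alongside sampling and beam search) repeatedly picks the highest-scoring next token and appends it to the prompt. A framework-free sketch in which `toy_model` stands in for the model's forward pass; both helpers are hypothetical, and the real method is driven by generation_config and logits_processor:

```python
def greedy_generate(next_token_logits, input_ids, max_new_tokens, eos_token_id=None):
    """Greedy decoding loop; `next_token_logits` maps a token list to vocab scores."""
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        ids.append(next_id)
        if next_id == eos_token_id:
            break  # stop early once the end-of-sequence token is produced
    return ids


def toy_model(ids):
    """Toy 'model' over a vocabulary of 5 that always prefers (last token + 1) mod 5."""
    logits = [0.0] * 5
    logits[(ids[-1] + 1) % 5] = 1.0
    return logits


print(greedy_generate(toy_model, [0], max_new_tokens=6, eos_token_id=3))  # [0, 1, 2, 3]
```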
- abstract classmethod get_torch_loader()
Gets the appropriate PyTorch AutoModel loader for this model type.
- Returns
The PyTorch AutoModel class for loading this model type.
- abstract property graphother: State[Key, Union[Variable, Array, ndarray, Ref, ArrayRefOutput, NoUpdate]]
- abstract property graphstate: State[Key, Union[Variable, Array, ndarray, Ref, ArrayRefOutput, NoUpdate]]
- abstract init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None)
Initializes and returns a standard (non-paged) Key-Value cache.
- Parameters
batch_size – The batch size for the cache.
max_length – The maximum sequence length the cache needs to support.
starts – Optional starting positions for the cache sequences.
shardings – Optional dictionary specifying sharding configurations.
pad_token_id – The ID of the padding token.
- Returns
An initialized standard TransformerCache object.
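A standard (non-paged) KV cache preallocates max_length slots per sequence and fills them one position at a time. A pure-Python sketch of that idea; `ToyKVCache` is hypothetical and far simpler than TransformerCache, and treating `starts` as the initial write index per sequence is an assumption:

```python
class ToyKVCache:
    """Fixed-size, preallocated KV cache sketch (hypothetical; not TransformerCache)."""

    def __init__(self, batch_size, max_length, starts=None):
        # One slot per position; None marks a position that has not been written yet.
        self.keys = [[None] * max_length for _ in range(batch_size)]
        self.values = [[None] * max_length for _ in range(batch_size)]
        # Next write index per sequence (assumed meaning of `starts`).
        self.index = list(starts) if starts is not None else [0] * batch_size
        self.max_length = max_length

    def update(self, batch_idx, key, value):
        """Write one K/V pair and return everything this sequence can attend to."""
        pos = self.index[batch_idx]
        if pos >= self.max_length:
            raise ValueError("cache is full")
        self.keys[batch_idx][pos] = key
        self.values[batch_idx][pos] = value
        self.index[batch_idx] = pos + 1
        return self.keys[batch_idx][: pos + 1], self.values[batch_idx][: pos + 1]


cache = ToyKVCache(batch_size=2, max_length=4)
cache.update(0, "k0", "v0")
print(cache.update(0, "k1", "v1"))  # (['k0', 'k1'], ['v0', 'v1'])
print(cache.index)                  # [2, 0]
```

The trade-off this illustrates is that memory is reserved for max_length tokens per sequence up front, whether or not they are used.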
- abstract init_ragged_pages(metadata: Any | None = None, page_size: int | None = None, hbm_utilization: float | None = None, max_model_length: int | None = None)
Initializes and returns the actual Paged Attention KV Cache tensors.
- Parameters
metadata – An optional pre-configured metadata object.
page_size – Number of tokens per page. Required if metadata is None.
hbm_utilization – Target HBM usage. Required if metadata is None.
max_model_length – Maximum model sequence length. Required if metadata is None.
- Returns
An initialized RaggedPagesCache object containing the allocated cache tensors.
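The core of a paged cache is bookkeeping: each sequence owns a list of fixed-size pages drawn from a shared pool, so memory grows in steps of page_size tokens rather than being reserved for max_model_length up front. A pure-Python sketch of that page table; the class and its methods are hypothetical, and RaggedPagesCache stores real tensors and may lay pages out differently:

```python
class ToyPagedAllocator:
    """Page-table bookkeeping behind a paged KV cache (hypothetical sketch)."""

    def __init__(self, num_pages, page_size):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))  # physical page ids not in use
        self.page_tables = {}                     # seq_id -> list of physical pages
        self.lengths = {}                         # seq_id -> number of tokens stored

    def append_token(self, seq_id):
        """Reserve a slot for one more token; returns (physical_page, offset)."""
        length = self.lengths.get(seq_id, 0)
        table = self.page_tables.setdefault(seq_id, [])
        if length % self.page_size == 0:          # current page is full (or none yet)
            if not self.free_pages:
                raise MemoryError("no free pages")
            table.append(self.free_pages.pop(0))
        self.lengths[seq_id] = length + 1
        return table[-1], length % self.page_size

    def release(self, seq_id):
        """Return a finished sequence's pages to the free pool."""
        self.free_pages.extend(self.page_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


alloc = ToyPagedAllocator(num_pages=4, page_size=2)
print([alloc.append_token("a") for _ in range(3)])  # [(0, 0), (0, 1), (1, 0)]
print(alloc.append_token("b"))                      # (2, 0)
alloc.release("a")
print(alloc.free_pages)                             # [3, 0, 1]
```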
- abstract classmethod lazy_init(*args, **kwargs) → Self
Initializes the class lazily under nnx.eval_shape, so parameters are created as abstract shape/dtype placeholders instead of allocated arrays.
- abstract merge_lora_params(pytree: dict) → Self
Merge LoRA (Low-Rank Adaptation) parameters into the base model parameters.
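Merging folds the low-rank update back into the dense weight, W' = W + (alpha/rank)·A·B, so inference needs a single matrix multiply. A pure-Python sketch under the same alpha/rank scaling convention used by the original LoRA formulation; `merge_lora` and `matmul` are hypothetical helpers, not EasyDeL's API:

```python
def matmul(a, b):
    """Plain matrix product of two lists-of-rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def merge_lora(W, A, B, alpha):
    """Fold a LoRA update into the base weight: W' = W + (alpha / rank) * A @ B."""
    scale = alpha / len(B)  # len(B) is the LoRA rank
    delta = matmul(A, B)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]


W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # in=3, out=2
A = [[0.5], [0.0], [1.0]]                 # in=3, rank=1
B = [[1.0, 2.0]]                          # rank=1, out=2
merged = merge_lora(W, A, B, alpha=1.0)
print(merged)  # [[1.5, 1.0], [0.0, 1.0], [2.0, 3.0]]

x = [1.0, 2.0, 3.0]
# x @ W' equals x @ W + (x @ A) @ B = [4.0, 5.0] + [3.5, 7.0]
print(matmul([x], merged)[0])  # [7.5, 12.0]
```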
- abstract merge_params_dict(params_dict: dict)
Merges the model parameters from a dictionary into the current model.
- abstract property params_sharding: dict
Returns the sharding of the model parameters.
- abstract prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pad_token_id: int, starts: int | None = None, shardings: int | None = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None, mask_info: Any | None = None) → dict[str, Any]
Sets up the initial inputs required for starting autoregressive generation.
- Parameters
input_ids – The initial sequence of token IDs.
max_length – The maximum sequence length that the KV cache should support.
pad_token_id – The ID used for padding tokens.
starts – Optional pre-calculated starting positions.
shardings – Optional sharding configuration passed to init_cache.
attention_mask – An optional mask indicating which tokens should be attended to.
token_type_ids – Optional segment IDs for models that use them.
mask_info – Optional pre-constructed MaskInfo object.
- Returns
A dictionary containing the prepared inputs for generation.
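Besides building the cache (the shardings argument is passed to init_cache), preparing inputs usually means deriving position ids that skip padding and extending the attention mask to the cache length. A hypothetical sketch of those two steps for left-padded prompts; the function name, the returned keys, and giving padding slots position 0 are assumptions, not EasyDeL's exact behaviour:

```python
def prepare_generation_inputs(attention_mask, max_length):
    """Hypothetical sketch: derive position ids and pad each mask to the cache length."""
    prepared = []
    for mask in attention_mask:
        positions, running = [], 0
        for m in mask:
            # Padding slots get position 0; real tokens are numbered 0, 1, 2, ...
            positions.append(running if m else 0)
            running += m
        # Slots reserved for future tokens start masked out.
        padded_mask = list(mask) + [0] * (max_length - len(mask))
        prepared.append({"position_ids": positions, "attention_mask": padded_mask})
    return prepared


# Batch of two prompts; the first is left-padded by one token.
inputs = prepare_generation_inputs(attention_mask=[[0, 1, 1], [1, 1, 1]], max_length=5)
print(inputs[0])  # {'position_ids': [0, 0, 1], 'attention_mask': [0, 1, 1, 0, 0]}
print(inputs[1])  # {'position_ids': [0, 1, 2], 'attention_mask': [1, 1, 1, 0, 0]}
```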
- abstract property pure_transform_fn: Callable
Generates a pure transform function for converting PyTorch weights into an EasyDeL module.
- abstract push_to_hub(repo_id: str, use_temp_dir: bool | None = None, commit_message: str | None = None, private: bool | None = None, token: bool | str | None = None, create_pr: bool = False, gather_fns: dict[Callable] | None = None, float_dtype: numpy.dtype | None = None, verbose: bool = True, mismatch_allowed: bool = True, revision: str | None = None, commit_description: str | None = None) → str
Pushes the model to the Hugging Face Hub.
- Parameters
repo_id – The repository ID on Hugging Face Hub.
use_temp_dir – If True, uses a temporary directory.
commit_message – The commit message for the push.
private – If True, creates a private repository.
token – The Hugging Face Hub token.
create_pr – If True, creates a pull request.
gather_fns – Custom gather functions for checkpoint saving.
float_dtype – Data type for saving weights.
verbose – Whether to print verbose messages.
mismatch_allowed – If True, allows mismatch in parameters while loading.
revision – The revision to push to.
commit_description – The commit description for the push.
- Returns
The URL of the created repository.
- abstract quantize(quantization_config: easydel.layers.quantization.quantizers.EasyDeLQuantizationConfig | None = None, quantize_tensors: bool = True, verbose: bool | None = None)
Quantizes the model’s linear layers.
- Parameters
quantization_config – Quantization configuration. Pass None to use default INT8.
quantize_tensors – Whether to quantize tensors directly.
verbose – Whether to print verbose output.
- Returns
The quantized model.
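Passing None falls back to INT8. As a rough illustration of what 8-bit weight quantization does, here is symmetric per-output-channel quantization in pure Python. The scheme (one scale per row, max|w|/127) is an assumption chosen for clarity; EasyDeLQuantizationConfig may use other methods, block sizes or dtypes:

```python
def quantize_int8_per_channel(weight):
    """Symmetric per-row INT8 quantization (illustrative, not EasyDeL's quantizer).

    weight: [out_features][in_features]. Returns (int rows in [-127, 127], one scale per row).
    """
    q_rows, scales = [], []
    for row in weight:
        max_abs = max(abs(v) for v in row)
        scale = max_abs / 127 if max_abs > 0 else 1.0
        q_rows.append([max(-127, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q_rows, scales


def dequantize(q_rows, scales):
    """Map integers back to approximate float weights."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


w = [[0.6, -1.0, 0.25], [2.0, 0.0, -2.0]]
q, s = quantize_int8_per_channel(w)
print(q)  # [[76, -127, 32], [127, 0, -127]]
restored = dequantize(q, s)
# Rounding error is at most half a quantization step (scale / 2) per element.
print(max(abs(a - b) for ra, rb in zip(w, restored) for a, b in zip(ra, rb)))
```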
- abstract save_pretrained(save_directory: str | os.PathLike, push_to_hub: bool = False, token: str | bool | None = None, gather_fns: dict[Callable] | None = None, float_dtype: numpy.dtype | None = None, step: int | None = None, **kwargs)
Saves the model, its configuration, and optionally pushes it to the Hugging Face Hub.
- Parameters
save_directory – Directory where to save the model.
push_to_hub – If True, pushes the model to the Hugging Face Hub.
token – The Hugging Face Hub token.
gather_fns – Custom gather functions for checkpoint saving.
float_dtype – Data type for saving weights.
step – Optional step number for checkpoint naming.
**kwargs – Additional keyword arguments for Hugging Face Hub.
- abstract shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: jax._src.mesh.Mesh | None = None)
Shards the model’s parameters using the specified partitioning rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules for sharding.
mesh (jax.sharding.Mesh, optional) – The mesh to shard across.
- Returns
The sharded model.
- Return type
nn.Module
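In many JAX codebases, partition rules are regular expressions over parameter paths, and the first matching rule decides how a parameter is laid out on the mesh. A sketch of that matching step only, assuming regex-keyed rules, the axis names "tp" and "fsdp", and the parameter paths shown (all hypothetical); specs are plain tuples here, where real code would use jax.sharding.PartitionSpec, and shard_model would then place the arrays on the given mesh:

```python
import re

# Hypothetical rules: (regex over the parameter path, spec as a tuple of mesh axes).
RULES = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"o_proj/kernel", ("tp", "fsdp")),
    (r".*", ()),  # fallback: empty spec, i.e. fully replicated
)


def match_partition_rules(param_paths, rules=RULES):
    """Assign each parameter the spec of the first rule whose regex matches its path."""
    specs = {}
    for path in param_paths:
        for pattern, spec in rules:
            if re.search(pattern, path):
                specs[path] = spec
                break
    return specs


paths = [
    "model/embed_tokens/embedding",
    "model/layers/0/self_attn/q_proj/kernel",
    "model/layers/0/input_layernorm/scale",
]
for path, spec in match_partition_rules(paths).items():
    print(path, spec)
```

Ordering matters: the catch-all rule must come last, or it would shadow every specific rule.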
- abstract split_lora_params() → dict
Split LoRA (Low-Rank Adaptation) parameters from the base model parameters.
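Splitting is typically a partition of the parameter tree by name, so that only the adapter weights are trained or saved. A hypothetical sketch over a flat {path: value} dict, assuming adapter leaves are named with a "lora_" prefix (for example "lora_a" and "lora_b"); EasyDeL's actual naming and tree structure may differ:

```python
def split_by_lora_marker(flat_params, marker="lora_"):
    """Partition a flat {path: value} dict into (lora_params, base_params) by leaf name."""
    lora, base = {}, {}
    for path, value in flat_params.items():
        leaf = path.rsplit("/", 1)[-1]
        target = lora if leaf.startswith(marker) else base
        target[path] = value
    return lora, base


params = {
    "layers/0/q_proj/kernel": [[1.0, 0.0]],
    "layers/0/q_proj/lora_a": [[0.1]],
    "layers/0/q_proj/lora_b": [[0.0, 0.0]],
}
lora, base = split_by_lora_marker(params)
print(sorted(lora))  # ['layers/0/q_proj/lora_a', 'layers/0/q_proj/lora_b']
print(sorted(base))  # ['layers/0/q_proj/kernel']
# The split is lossless: recombining the two parts restores the original dict.
print({**base, **lora} == params)  # True
```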
- abstract split_params_dict(params_dict: dict) → dict
Splits the model parameters from a dictionary into separate state components.
- abstract to_state() → EasyDeLState
Converts the current model to an EasyDeLState.
- abstract property transform_fn: Callable
Generates a transform function for converting PyTorch weights into an EasyDeL module.
- abstract unwrap_lora_to_layers(verbose: bool = False)
Unwraps LoRA (Low-Rank Adaptation) from the linear layers within a model.
- abstract update_inputs_for_generation(model_outputs: Any, model_kwargs: dict[str, Any]) → dict[str, Any]
Updates the keyword arguments for the next generation step.
- Parameters
model_outputs – The output object from the model’s forward pass in the previous step.
model_kwargs – The dictionary of keyword arguments used for the model call.
- Returns
The updated model_kwargs dictionary ready for the next generation step.
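Between decoding steps, the keyword arguments typically advance by one token: only the newly generated token is fed in, its cache slot becomes visible in the attention mask, and the position id moves forward. A hypothetical single-sequence sketch; the keys "cache_index" and "position_ids" and the function name are assumptions, and the real method also takes the updated cache from model_outputs:

```python
def update_generation_kwargs(next_token, model_kwargs):
    """Hypothetical per-step update for a single sequence."""
    write_pos = model_kwargs["cache_index"]  # cache slot the new token will occupy
    mask = list(model_kwargs["attention_mask"])
    mask[write_pos] = 1                      # make the new slot visible to attention
    updated = dict(model_kwargs)
    updated.update(
        input_ids=[next_token],              # only the new token is fed on the next step
        attention_mask=mask,
        position_ids=[model_kwargs["position_ids"][-1] + 1],
        cache_index=write_pos + 1,
    )
    return updated


state = {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1, 0, 0],
         "position_ids": [0, 1, 2], "cache_index": 3}
state = update_generation_kwargs(9, state)
print(state)
# {'input_ids': [9], 'attention_mask': [1, 1, 1, 1, 0], 'position_ids': [3], 'cache_index': 4}
```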
- easydel.infra.mixins.protocol.get_module_repr(module: Module) → str
Get a string representation of module parameters.
- easydel.infra.mixins.protocol.prettify_nnx(module: Module, indent: str = '', depth: int = 0, max_depth: int | None = None, module_param=None) → str
Format the structure of a Flax NNX module for display.
Recursively creates a human-readable representation of a module’s structure, similar to PyTorch’s module printing.
- Parameters
module – The module to format.
indent – Current indentation string.
depth – Current recursion depth.
max_depth – Maximum depth to recurse.
module_param – Optional parameter dictionary.
- Returns
Formatted string representation of the module hierarchy.
Example
>>> print(prettify_nnx(my_model, max_depth=2))
MyModel(
  (encoder): Encoder(
    (layers): ModuleList(...)
  )
  (decoder): Decoder(...)
)
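The recursion behind this kind of printer is short: print the class name, recurse into each named child with deeper indentation, and collapse a subtree to "(...)" once max_depth is reached. A hypothetical stand-in that works on nested (class_name, {attr: child}) tuples rather than real NNX modules, so its exact output format is an assumption:

```python
def prettify(module, indent="", depth=0, max_depth=None):
    """module is (class_name, {attr_name: submodule}); returns an indented tree string."""
    class_name, children = module
    if not children:
        return f"{class_name}()"
    if max_depth is not None and depth >= max_depth:
        return f"{class_name}(...)"  # collapse subtrees below the depth limit
    inner = indent + "  "
    lines = [f"{class_name}("]
    for attr, child in children.items():
        lines.append(f"{inner}({attr}): {prettify(child, inner, depth + 1, max_depth)}")
    lines.append(f"{indent})")
    return "\n".join(lines)


model = ("MyModel", {
    "encoder": ("Encoder", {"layers": ("ModuleList", {"0": ("Block", {})})}),
    "decoder": ("Decoder", {"proj": ("Linear", {})}),
})
print(prettify(model, max_depth=2))
# MyModel(
#   (encoder): Encoder(
#     (layers): ModuleList(...)
#   )
#   (decoder): Decoder(
#     (proj): Linear()
#   )
# )
```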