easydel.infra.mixins.protocol

Protocol definitions for EasyDeL base modules.

This module defines the protocol (interface) that all EasyDeL models must implement. It provides the BaseModuleProtocol abstract base class which specifies the required methods and properties for model implementations, along with utility functions for module representation and formatting.

The protocol ensures consistency across different model implementations and provides type hints for better IDE support and type checking. It combines interfaces from the base module, bridge functionality (EasyBridgeMixin), and generation capabilities (EasyGenerationMixin).

Classes:

BaseModuleProtocol: Abstract base class defining the interface for EasyDeL modules

Functions:

return_type_adjuster: Decorator to adjust return types for type checking
get_module_repr: Get string representation of module parameters
prettify_nnx: Format module structure for display
printify_nnx: Create printable representation of NNX modules

Type Aliases:

PartitionLike: Type for partition rule specifications
Self: Type variable for self-referencing types

The protocol includes methods for:
  • Model forward passes and loss computation (overloaded for different model types)
  • Parameter management (sharding, gathering, quantization, LoRA)
  • Model I/O (saving, loading, Hugging Face Hub integration)
  • Text generation (greedy search, sampling, beam search)
  • Cache management (standard and paged attention)
  • Framework conversion (PyTorch ↔ JAX/Flax)

Supported model types:
  • Causal Language Models
  • Sequence Classification
  • Mixture of Experts (MoE)
  • Vision models (CLIP)
  • Multi-modal models

Example

>>> from easydel.infra.mixins.protocol import BaseModuleProtocol
>>>
>>> class MyModel(BaseModuleProtocol):
...     # Implement the required abstract methods, for example:
...     def __call__(self, input_ids, **kwargs):
...         # Forward pass implementation
...         pass
...
...     def compute_loss(self, input_ids, labels, **kwargs):
...         # Loss computation
...         pass
...
...     def generate(self, input_ids, **kwargs):
...         # Generation implementation
...         pass
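Because BaseModuleProtocol is an abstract base class, a subclass can only be instantiated once every abstract method has been overridden; defining the class succeeds either way. The stdlib-only sketch below illustrates that contract with a hypothetical two-method protocol (ToyProtocol, IncompleteModel, and ToyModel are illustrative names, not EasyDeL classes).

```python
from abc import ABC, abstractmethod


class ToyProtocol(ABC):
    """Hypothetical stand-in for a protocol with abstract methods."""

    @classmethod
    @abstractmethod
    def can_generate(cls) -> bool:
        """Report whether .generate() is supported."""

    @abstractmethod
    def compute_loss(self, input_ids, labels, **kwargs):
        """Return (outputs, loss)."""


class IncompleteModel(ToyProtocol):
    # Only one of the two abstract methods is implemented.
    @classmethod
    def can_generate(cls) -> bool:
        return False


class ToyModel(ToyProtocol):
    @classmethod
    def can_generate(cls) -> bool:
        return True

    def compute_loss(self, input_ids, labels, **kwargs):
        # Toy "loss": fraction of positions where input and label differ.
        mismatches = sum(a != b for a, b in zip(input_ids, labels))
        return {"input_ids": input_ids}, mismatches / len(labels)


def can_instantiate(cls) -> bool:
    """Return True if cls() succeeds, False if abstract methods remain."""
    try:
        cls()
    except TypeError:
        return False
    return True
```

Here can_instantiate(ToyModel) is True while can_instantiate(IncompleteModel) is False, because compute_loss is still abstract on the latter.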

class easydel.infra.mixins.protocol.BaseModuleProtocol

Bases: object

Protocol defining the common interface for EasyDeL modules.

abstract apply_lora_to_layers(lora_rank: int, lora_pattern: str | None = None, verbose: bool = False, rngs: flax.nnx.rnglib.Rngs | None = None) → Self

Apply LoRA (Low-Rank Adaptation) to specified linear layers within a model.

base_model_prefix: str
abstract classmethod can_generate() → bool

Checks if the model can generate sequences with .generate().

Returns

True if the model can generate, False otherwise.

compute_loss(input_ids: chex.Array | None = None, labels: chex.Array | None = None, inputs_embeds: chex.Array | None = None, attention_mask: chex.Array | None = None, mask_info: MaskInfo | None = None, position_ids: chex.Array | None = None, segment_ids: chex.Array | None = None, past_key_values: TransformerCache | RaggedPagesCache | None = None, cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, loss_config: LossConfig | None = None, loss_kwargs: dict | None = None) → tuple[CausalLMOutput, LossMetrics]
compute_loss(input_ids: chex.Array | None = None, labels: chex.Array | None = None, inputs_embeds: chex.Array | None = None, attention_mask: chex.Array | None = None, mask_info: MaskInfo | None = None, position_ids: chex.Array | None = None, segment_ids: chex.Array | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, loss_config: LossConfig | None = None, loss_kwargs: dict | None = None) → tuple[SequenceClassifierOutput, LossMetrics]
compute_loss(input_ids: chex.Array | None = None, labels: chex.Array | None = None, inputs_embeds: chex.Array | None = None, attention_mask: chex.Array | None = None, mask_info: MaskInfo | None = None, position_ids: chex.Array | None = None, segment_ids: chex.Array | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, output_router_logits: bool | None = None, past_key_values: TransformerCache | RaggedPagesCache | None = None, cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None, loss_config: LossConfig | None = None, loss_kwargs: dict | None = None) → tuple[MoeModelOutput, LossMetrics]
compute_loss(input_ids: chex.Array | None = None, labels: chex.Array | None = None, inputs_embeds: chex.Array | None = None, attention_mask: chex.Array | None = None, mask_info: MaskInfo | None = None, position_ids: chex.Array | None = None, segment_ids: chex.Array | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, output_router_logits: bool | None = None, past_key_values: TransformerCache | RaggedPagesCache | None = None, cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None, loss_config: LossConfig | None = None, loss_kwargs: dict | None = None) → tuple[MoeCausalLMOutput, LossMetrics]
compute_loss(*, labels: chex.Array | None = None, loss_config: LossConfig | None = None, loss_kwargs: dict | None = None, **batch) → tuple[tp.Any, LossMetrics]

Computes the model outputs together with the loss for a batch. The overloads cover causal language models (CausalLMOutput), sequence classification (SequenceClassifierOutput), Mixture of Experts models (MoeModelOutput, MoeCausalLMOutput), and a generic form that accepts arbitrary batch keyword arguments. Every form returns the model outputs paired with LossMetrics.
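The compute_loss overloads are declarations for type checkers only; at runtime a single, undecorated implementation is what actually runs. The stdlib sketch below reproduces that pattern with hypothetical toy output classes (CausalOut, ClassifierOut, and ToyModel are illustrative, not EasyDeL types), and the toy "losses" only stand in for real loss functions.

```python
from dataclasses import dataclass
from typing import Any, overload


@dataclass
class CausalOut:  # hypothetical stand-in for CausalLMOutput
    logits: list


@dataclass
class ClassifierOut:  # hypothetical stand-in for SequenceClassifierOutput
    logits: list


class ToyModel:
    def __init__(self, task: str) -> None:
        self.task = task

    # Declarations for type checkers only: their bodies never run.
    @overload
    def compute_loss(self, *, labels: list, input_ids: list) -> tuple[CausalOut, float]: ...
    @overload
    def compute_loss(self, *, labels: int, input_ids: list) -> tuple[ClassifierOut, float]: ...

    # The single runtime implementation; it replaces the stubs above.
    def compute_loss(self, *, labels: Any, **batch: Any) -> tuple[Any, float]:
        ids = batch["input_ids"]
        if self.task == "causal":
            # Toy per-token "loss": fraction of positions where ids != labels.
            wrong = sum(a != b for a, b in zip(ids, labels))
            return CausalOut(logits=ids), wrong / len(labels)
        # Toy sequence-level "loss": 1.0 unless the last token equals the label.
        return ClassifierOut(logits=ids), float(ids[-1] != labels)
```

The return type a type checker infers depends on which overload the arguments match, while the runtime branch is chosen by the model's own configuration.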

config: EasyDeLBaseConfig
config_class: type[easydel.infra.base_config.EasyDeLBaseConfig]
abstract create_cache_metadata(batch_size: int, max_length: int, pad_token_id: int | None = None)

Creates the metadata required for initializing a standard (non-paged) KV Cache.

Parameters
  • batch_size – The batch size for which the cache is being configured.

  • max_length – The maximum sequence length the cache needs to support.

  • pad_token_id – The ID of the padding token. If None, it’s inferred.

Returns

An initialized metadata object for a standard KV cache.

abstract create_paged_metadata(hbm_utilization: float, page_size: int, max_model_length: int)

Creates the static configuration metadata required for initializing a Paged KV Cache.

Parameters
  • hbm_utilization – Target fraction of HBM (high-bandwidth memory) to utilize.

  • page_size – Number of tokens per page in the paged cache.

  • max_model_length – Maximum sequence length the model will handle.

Returns

An initialized metadata object containing the static configuration for the paged cache.
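To get a feel for how hbm_utilization, page_size, and max_model_length interact, the sketch below estimates page counts under a common layout assumption: each page stores keys and values for page_size tokens across all layers and KV heads. The function name, the formula, and the model dimensions are assumptions for illustration, not EasyDeL's actual allocation logic.

```python
import math


def paged_cache_budget(
    hbm_bytes: int,
    hbm_utilization: float,
    page_size: int,
    max_model_length: int,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    dtype_bytes: int = 2,  # e.g. bfloat16
) -> dict:
    """Estimate how many KV-cache pages fit in a memory budget (assumed layout)."""
    # Bytes for one token's K and V (factor 2) across all layers and KV heads.
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    bytes_per_page = page_size * bytes_per_token
    budget = int(hbm_bytes * hbm_utilization)
    num_pages = budget // bytes_per_page
    # A full-length sequence needs this many pages (the last may be partial).
    pages_per_seq = math.ceil(max_model_length / page_size)
    return {
        "bytes_per_page": bytes_per_page,
        "num_pages": num_pages,
        "pages_per_seq": pages_per_seq,
        "max_full_length_seqs": num_pages // pages_per_seq,
    }
```

For example, with 16 GiB of HBM, hbm_utilization=0.5, page_size=128, 32 layers, 8 KV heads and head_dim=128 in bfloat16, each page takes 16 MiB, so 512 pages fit; at max_model_length=4096 a sequence needs 32 pages, leaving room for 16 full-length sequences.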

abstract float(change_runtime_dtype: bool = True)

Converts the model parameters to float32.

abstract classmethod from_pretrained(pretrained_model_name_or_path: str | os.PathLike | None, **kwargs)

Loads an EasyDeL model from a pretrained model or path.

Parameters
  • pretrained_model_name_or_path – The name or path of the pretrained model.

  • **kwargs – Additional keyword arguments for loading configuration.

Returns

The loaded EasyDeL model.

abstract gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: jax.sharding.Mesh | None = None)

Gathers the model’s parameters based on the specified partitioning rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules for gathering.

  • mesh (jax.sharding.Mesh, optional) – The mesh to gather from.

Returns

The gathered model.

Return type

nn.Module

abstract generate(input_ids: Union[Array, ndarray, bool, number], generation_config: Any | None = None, prng_key: Optional[Union[Array, ndarray, bool, number]] = None, trace: bool = True, logits_processor: Any | None = None, **kwargs)

Generates sequences of token ids for models with a language modeling head.

Parameters
  • input_ids – The sequence used as a prompt for the generation.

  • generation_config – The generation configuration to be used as base parametrization.

  • prng_key – Random key for sampling-based generation.

  • trace – Whether to trace generation for better performance.

  • logits_processor – Custom logits processors.

  • **kwargs – Additional generation parameters.

Returns

Generated sequences and optionally scores.

abstract get_static_arguments() → tuple

Returns the static arguments to use when compiling with jax.jit / ejit.

abstract classmethod get_torch_loader()

Gets the appropriate PyTorch AutoModel loader for this model type.

Returns

The PyTorch AutoModel class for loading this model type.

abstract property graphdef: GraphDef
abstract property graphother: State[Key, Union[Variable, Array, ndarray, Ref, ArrayRefOutput, NoUpdate]]
abstract property graphstate: State[Key, Union[Variable, Array, ndarray, Ref, ArrayRefOutput, NoUpdate]]
abstract half(change_runtime_dtype: bool = True)

Converts the model parameters to float16.

abstract init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None)

Initializes and returns a standard (non-paged) Key-Value cache.

Parameters
  • batch_size – The batch size for the cache.

  • max_length – The maximum sequence length the cache needs to support.

  • starts – Optional starting positions for the cache sequences.

  • shardings – Optional dictionary specifying sharding configurations.

  • pad_token_id – The ID of the padding token.

Returns

An initialized standard TransformerCache object.
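A standard (non-paged) cache preallocates batch_size × max_length slots and writes each new token at the sequence's current index. The pure-Python sketch below shows that idea for a single layer; ToyKVCache and its layout are hypothetical and do not reflect TransformerCache's internals (starts, shardings, and pad_token_id are not modeled).

```python
class ToyKVCache:
    """Hypothetical single-layer KV cache with a fixed capacity."""

    def __init__(self, batch_size: int, max_length: int) -> None:
        self.max_length = max_length
        # Preallocate every slot up front, as a static-shape cache would.
        self.keys = [[None] * max_length for _ in range(batch_size)]
        self.values = [[None] * max_length for _ in range(batch_size)]
        # Per-sequence write position.
        self.index = [0] * batch_size

    def append(self, batch_idx: int, key, value) -> None:
        """Store one token's key/value at the next free slot."""
        pos = self.index[batch_idx]
        if pos >= self.max_length:
            raise IndexError("cache is full; increase max_length")
        self.keys[batch_idx][pos] = key
        self.values[batch_idx][pos] = value
        self.index[batch_idx] = pos + 1

    def visible(self, batch_idx: int) -> list:
        """Keys written so far for one sequence (what attention would see)."""
        return self.keys[batch_idx][: self.index[batch_idx]]
```

Fixing the buffer size up front keeps shapes static, which suits compiled JAX code, at the cost of reserving max_length slots even for short sequences.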

abstract init_ragged_pages(metadata: Any | None = None, page_size: int | None = None, hbm_utilization: float | None = None, max_model_length: int | None = None)

Initializes and returns the actual Paged Attention KV Cache tensors.

Parameters
  • metadata – An optional pre-configured metadata object.

  • page_size – Number of tokens per page. Required if metadata is None.

  • hbm_utilization – Target HBM usage. Required if metadata is None.

  • max_model_length – Maximum model sequence length. Required if metadata is None.

Returns

An initialized RaggedPagesCache object containing the allocated cache tensors.
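Paged attention replaces each sequence's contiguous buffer with a shared pool of fixed-size pages plus a per-sequence page table. The sketch below shows that bookkeeping in plain Python; PagePool is a hypothetical name, and RaggedPagesCache's real data layout and allocation policy may differ.

```python
class PagePool:
    """Hypothetical page allocator with per-sequence page tables."""

    def __init__(self, num_pages: int, page_size: int) -> None:
        self.page_size = page_size
        self.free = list(range(num_pages))  # indices of unused pages
        self.page_table: dict[int, list[int]] = {}  # seq_id -> page indices
        self.lengths: dict[int, int] = {}  # seq_id -> tokens stored

    def append_token(self, seq_id: int) -> tuple[int, int]:
        """Reserve a slot for one new token; return (page, offset)."""
        length = self.lengths.get(seq_id, 0)
        pages = self.page_table.setdefault(seq_id, [])
        if length % self.page_size == 0:  # current page full (or none yet)
            if not self.free:
                raise MemoryError("no free pages left")
            pages.append(self.free.pop(0))
        self.lengths[seq_id] = length + 1
        return pages[-1], length % self.page_size

    def release(self, seq_id: int) -> None:
        """Return a finished sequence's pages to the pool."""
        self.free.extend(self.page_table.pop(seq_id, []))
        self.lengths.pop(seq_id, None)
```

Sequences only hold pages they have actually filled, so short and long sequences can share one pool without each reserving max_model_length slots.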

abstract classmethod lazy_init(*args, **kwargs) → Self

Initializes the model lazily with nnx.eval_shape, so that parameter shapes and dtypes are determined without allocating the arrays.

abstract merge_lora_params(pytree: dict) → Self

Merge LoRA (Low-Rank Adaptation) parameters into the base model parameters.
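Merging folds the low-rank update back into the base weight. Assuming the Flax-style y = x @ W convention with W of shape (in, out), a LoRA pair A (in, rank) and B (rank, out) merges as W + s·(A @ B), where s is a scaling factor (often alpha / rank in the LoRA literature). The pure-Python sketch below performs that merge on small nested lists; the shape convention and the scale parameter are assumptions, since this page does not state how EasyDeL scales the update.

```python
def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    """Plain matrix product of nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(kernel, lora_a, lora_b, scale: float = 1.0):
    """Return kernel + scale * (lora_a @ lora_b).

    Assumed shapes (Flax-style x @ W convention):
      kernel: (in_features, out_features)
      lora_a: (in_features, rank)
      lora_b: (rank, out_features)
    """
    delta = matmul(lora_a, lora_b)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(kernel, delta)
    ]
```

After merging, x @ merged equals x @ W + s·((x @ A) @ B), so inference needs no extra LoRA matmuls.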

abstract merge_params(tree)

Merges the given state tree into the current model.

abstract merge_params_dict(params_dict: dict)

Merges the model parameters from a dictionary into the current model.

abstract property params_sharding: dict

Returns the sharding of the model parameters.

abstract prepare_inputs_for_call(**kwargs)

Updates the inputs before they are passed to the model call.

abstract prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pad_token_id: int, starts: int | None = None, shardings: int | None = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None, mask_info: Any | None = None) → dict[str, Any]

Sets up the initial inputs required for starting autoregressive generation.

Parameters
  • input_ids – The initial sequence of token IDs.

  • max_length – The maximum sequence length that the KV cache should support.

  • pad_token_id – The ID used for padding tokens.

  • starts – Optional pre-calculated starting positions.

  • shardings – Optional sharding configuration passed to init_cache.

  • attention_mask – An optional mask indicating which tokens should be attended to.

  • token_type_ids – Optional segment IDs for models that use them.

  • mask_info – Optional pre-constructed MaskInfo object.

Returns

A dictionary containing the prepared inputs for generation.

abstract property pure_transform_fn: Callable

Generates a pure transform function for converting a PyTorch module to an EasyDeL module.

abstract push_to_hub(repo_id: str, use_temp_dir: bool | None = None, commit_message: str | None = None, private: bool | None = None, token: bool | str | None = None, create_pr: bool = False, gather_fns: dict[Callable] | None = None, float_dtype: numpy.dtype | None = None, verbose: bool = True, mismatch_allowed: bool = True, revision: str | None = None, commit_description: str | None = None) → str

Pushes the model to the Hugging Face Hub.

Parameters
  • repo_id – The repository ID on Hugging Face Hub.

  • use_temp_dir – If True, uses a temporary directory.

  • commit_message – The commit message for the push.

  • private – If True, creates a private repository.

  • token – The Hugging Face Hub token.

  • create_pr – If True, creates a pull request.

  • gather_fns – Custom gather functions for checkpoint saving.

  • float_dtype – Data type for saving weights.

  • verbose – Whether to print verbose messages.

  • mismatch_allowed – If True, allows mismatch in parameters while loading.

  • revision – The revision to push to.

  • commit_description – The commit description for the push.

Returns

The URL of the created repository.

abstract quantize(quantization_config: easydel.layers.quantization.quantizers.EasyDeLQuantizationConfig | None = None, quantize_tensors: bool = True, verbose: bool | None = None)

Quantizes the model’s linear layers.

Parameters
  • quantization_config – Quantization configuration. Pass None to use default INT8.

  • quantize_tensors – Whether to quantize tensors directly.

  • verbose – Whether to print verbose output.

Returns

The quantized model.
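The default quantization_config is INT8. As a generic illustration of what 8-bit weight quantization involves, the sketch below applies symmetric absmax quantization to one row of weights. EasyDeL's quantizers may use a different scheme (per-channel or blockwise scales, asymmetric ranges), so treat the function names and the scheme as assumptions.

```python
def quantize_int8(weights: list[float]) -> tuple[list[int], float]:
    """Symmetric absmax quantization to the int8 range [-127, 127]."""
    absmax = max(abs(w) for w in weights)
    # One scale for the whole row; guard against an all-zero row.
    scale = absmax / 127 if absmax > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q: list[int], scale: float) -> list[float]:
    """Map int8 codes back to approximate float values."""
    return [v * scale for v in q]
```

Each weight is stored in one byte plus a shared scale, and the reconstruction error per weight is at most half a quantization step (scale / 2).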

abstract save_pretrained(save_directory: str | os.PathLike, push_to_hub: bool = False, token: str | bool | None = None, gather_fns: dict[Callable] | None = None, float_dtype: numpy.dtype | None = None, step: int | None = None, **kwargs)

Saves the model, its configuration, and optionally pushes it to the Hugging Face Hub.

Parameters
  • save_directory – Directory where to save the model.

  • push_to_hub – If True, pushes the model to the Hugging Face Hub.

  • token – The Hugging Face Hub token.

  • gather_fns – Custom gather functions for checkpoint saving.

  • float_dtype – Data type for saving weights.

  • step – Optional step number for checkpoint naming.

  • **kwargs – Additional keyword arguments for Hugging Face Hub.

abstract shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: jax.sharding.Mesh | None = None)

Shards the model’s parameters using the specified partitioning rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules for sharding.

  • mesh (jax.sharding.Mesh, optional) – The mesh to shard across.

Returns

The sharded model.

Return type

nn.Module
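Partition rules for sharding and gathering are commonly expressed as regular expressions matched against each parameter's path, with the first matching rule deciding its layout. The stdlib sketch below shows first-match-wins lookup; the rule container (an ordered sequence of (regex, spec) pairs), the parameter paths, and the mesh axis names "tp" and "fsdp" are assumptions, and plain tuples stand in for jax.sharding.PartitionSpec.

```python
import re
from collections.abc import Sequence


def match_partition_rules(
    param_paths: Sequence[str],
    rules: Sequence[tuple[str, tuple]],
) -> dict[str, tuple]:
    """Assign a spec to each parameter path; the first matching regex wins."""
    specs = {}
    for path in param_paths:
        for pattern, spec in rules:
            if re.search(pattern, path):
                specs[path] = spec
                break
        else:
            # Fail loudly rather than silently leaving a tensor unsharded.
            raise ValueError(f"no partition rule matches {path!r}")
    return specs
```

Because matching stops at the first hit, a catch-all pattern such as r".*" belongs at the end of the rule list.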

abstract split_lora_params() → dict

Split LoRA (Low-Rank Adaptation) parameters from the base model parameters.

abstract split_params()

Splits the model parameters.

abstract split_params_dict(params_dict: dict) → dict

Splits the model parameters from a dictionary into separate state components.

abstract to_dtype(dtype) → Self

Converts the model parameters to the given dtype.

abstract to_state() → EasyDeLState

Converts the current model to an EasyDeLState.

abstract to_torch() → PreTrainedModel

Converts the current model to a Hugging Face PyTorch model.

abstract property transform_fn: Callable

Generates a transform function for converting a PyTorch module to an EasyDeL module.

abstract unwrap_lora_to_layers(verbose: bool = False)

Unwraps LoRA (Low-Rank Adaptation) from the linear layers within the model.

abstract update_inputs_for_generation(model_outputs: Any, model_kwargs: dict[str, Any]) → dict[str, Any]

Updates the keyword arguments for the next generation step.

Parameters
  • model_outputs – The output object from the model’s forward pass in the previous step.

  • model_kwargs – The dictionary of keyword arguments used for the model call.

Returns

The updated model_kwargs dictionary ready for the next generation step.
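prepare_inputs_for_generation and update_inputs_for_generation bracket each decoding step: the first builds the initial keyword arguments, the second carries state forward between steps. The pure-Python sketch below shows how a greedy loop could be wired around them; ToyLM is a hypothetical stand-in that only mimics the call pattern (no real KV cache, a fixed next-token rule) and is not EasyDeL's generate implementation.

```python
from typing import Any


class ToyLM:
    """Hypothetical model whose logits always favour (last token + 1) % vocab."""

    vocab_size = 10

    def prepare_inputs_for_generation(self, input_ids: list[int], max_length: int) -> dict[str, Any]:
        # A real model would also build the KV cache and attention masks here.
        return {"input_ids": input_ids, "max_length": max_length, "step": 0}

    def __call__(self, input_ids: list[int], **kwargs: Any) -> dict[str, Any]:
        nxt = (input_ids[-1] + 1) % self.vocab_size
        return {"logits": [1.0 if t == nxt else 0.0 for t in range(self.vocab_size)]}

    def update_inputs_for_generation(self, model_outputs: dict, model_kwargs: dict) -> dict:
        # A real model would carry the updated cache over from model_outputs.
        return {**model_kwargs, "step": model_kwargs["step"] + 1}


def greedy_generate(model: ToyLM, prompt: list[int], max_length: int, eos_id: int) -> list[int]:
    sequence = list(prompt)
    kwargs = model.prepare_inputs_for_generation(sequence, max_length)
    while len(sequence) < max_length:
        outputs = model(**kwargs)
        logits = outputs["logits"]
        next_token = max(range(len(logits)), key=logits.__getitem__)  # argmax
        sequence.append(next_token)
        if next_token == eos_id:
            break
        kwargs = model.update_inputs_for_generation(outputs, kwargs)
        # Feed only the new token, as a model with a KV cache would.
        kwargs["input_ids"] = [next_token]
    return sequence
```

For instance, greedy_generate(ToyLM(), [3], max_length=6, eos_id=5) stops at the end-of-sequence token and returns [3, 4, 5].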

easydel.infra.mixins.protocol.get_module_repr(module: Module) → str

Get a string representation of module parameters.

easydel.infra.mixins.protocol.prettify_nnx(module: Module, indent: str = '', depth: int = 0, max_depth: int | None = None, module_param=None) → str

Format the structure of a Flax NNX module for display.

Recursively creates a human-readable representation of a module’s structure, similar to PyTorch’s module printing.

Parameters
  • module – The module to format.

  • indent – Current indentation string.

  • depth – Current recursion depth.

  • max_depth – Maximum depth to recurse.

  • module_param – Optional parameter dictionary.

Returns

Formatted string representation of the module hierarchy.

Example

>>> print(prettify_nnx(my_model, max_depth=2))
MyModel(
  (encoder): Encoder(
    (layers): ModuleList(...)
  )
  (decoder): Decoder(...)
)
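The depth-limited recursion that prettify_nnx describes can be sketched without Flax by using nested dictionaries as a stand-in module tree. prettify_tree and the tree encoding below are hypothetical; the real function walks NNX module attributes and may choose differently where to collapse branches.

```python
from typing import Optional


def prettify_tree(name: str, children: dict, indent: str = "", depth: int = 0,
                  max_depth: Optional[int] = None) -> str:
    """Render a module tree in a PyTorch-like nested format.

    Keys of `children` are attribute names; each value is a
    (class_name, grandchildren) pair standing in for a submodule.
    """
    if not children:
        return f"{name}()"
    if max_depth is not None and depth >= max_depth:
        return f"{name}(...)"  # collapse everything below max_depth
    inner = indent + "  "
    lines = [f"{name}("]
    for attr, (cls_name, grandchildren) in children.items():
        body = prettify_tree(cls_name, grandchildren, inner, depth + 1, max_depth)
        lines.append(f"{inner}({attr}): {body}")
    lines.append(f"{indent})")
    return "\n".join(lines)
```

Raising max_depth expands more levels; leaves always print as Name() and deeper subtrees collapse to Name(...).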

easydel.infra.mixins.protocol.printify_nnx(model)
easydel.infra.mixins.protocol.return_type_adjuster(original_return_type: type[_T]) → Callable[[Callable[..., Module]], Callable[..., _T]]
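Judging from its signature, return_type_adjuster takes a target type and returns a decorator that re-annotates a function returning a Module so that it appears to return that type. A plausible minimal version uses typing.cast, which only changes what a static type checker sees and leaves the runtime value untouched. This is an assumed sketch based on the signature, not the function's actual source.

```python
import functools
from typing import Callable, TypeVar, cast

_T = TypeVar("_T")


def return_type_adjuster(
    original_return_type: type[_T],
) -> Callable[[Callable[..., object]], Callable[..., _T]]:
    """Re-annotate a function's return type for static type checkers."""

    def decorator(func: Callable[..., object]) -> Callable[..., _T]:
        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs) -> _T:
            # cast() is a no-op at runtime: the value is returned unchanged.
            return cast(original_return_type, func(*args, **kwargs))

        return wrapper

    return decorator
```

Since cast performs no conversion, the decorated function still returns exactly what it did before; only the declared type changes.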