easydel.infra.mixins.protocol#

class easydel.infra.mixins.protocol.BaseModuleProtocol[source]#

Bases: object

Protocol defining the common interface for EasyDeL modules.

abstract apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) SELF[source]#

Apply LoRA (Low-Rank Adaptation) to specified linear layers within a model.

base_model_prefix: str#
compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None, cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) tp.Tuple[CausalLMOutput, LossMetrics]#
compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) tp.Tuple[SequenceClassifierOutput, LossMetrics]
compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, output_router_logits: tp.Optional[bool] = None, past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None, cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) tp.Tuple[MoeModelOutput, LossMetrics]
compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, output_router_logits: tp.Optional[bool] = None, past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None, cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) tp.Tuple[MoeCausalLMOutput, LossMetrics]
compute_loss(*, labels: tp.Optional[chex.Array] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None, **batch) tp.Tuple[tp.Any, LossMetrics]

Helper for @overload to raise when called.

config: EasyDeLBaseConfig#
config_class: Type[EasyDeLBaseConfig]#
abstract float(change_runtime_dtype: bool = True)[source]#

Converts model parameters to float32.

abstract gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None)[source]#

Gathers the model’s parameters based on the specified partitioning rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules for gathering.

  • mesh (jax.sharding.Mesh, optional) – The mesh to gather from.

Returns

The gathered model.

Return type

nn.Module

abstract get_static_arguments() Tuple[source]#

return static arguments kwargs for jax.jit

abstract property graphdef: Union[NodeDef[Node], NodeRef[Node]]#
abstract property graphother: State[Key, VariableState[Any]]#
abstract property graphstate: State[Key, VariableState[Any]]#
abstract half(change_runtime_dtype: bool = True)[source]#

Converts model parameters to float16.

abstract classmethod lazy_init(*args, **kwargs) SELF[source]#

initialize the base class with nnx.eval_shape carefully

abstract merge_lora_params(pytree: Dict) SELF[source]#

Merge LoRA (Low-Rank Adaptation) parameters into the base model parameters.

abstract merge_params(tree)[source]#

Merge state into the current model.

abstract merge_params_dict(params_dict: Dict)[source]#

Merges the model parameters from a dictionary into the current model.

abstract property params_sharding: Dict#

return the sharding of the model parameters

abstract prepare_inputs_for_call(**kwargs)[source]#

update inputs for calling model

abstract property pure_transform_fn: Callable#

generates a pure transform function for converting torch to easydel module.

abstract quantize(method: EasyDeLQuantizationMethods, skip_modules: list[str] | None = None, verbose: bool = False, **kwargs)[source]#

Quantizes the model’s linear layers.

Parameters
  • method (EasyDeLQuantizationMethods) – The quantization method to use.

  • skip_modules (list[str] | None, optional) – List of module names to skip.

  • verbose (bool, optional) – Whether to print verbose output.

  • **kwargs – Additional keyword arguments.

Returns

The quantized model.

Return type

nn.Module

abstract shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None)[source]#

Shards the model’s parameters using the specified partitioning rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules for sharding.

  • mesh (jax.sharding.Mesh, optional) – The mesh to shard across.

Returns

The sharded model.

Return type

nn.Module

abstract split_lora_params() dict[source]#

Split LoRA (Low-Rank Adaptation) parameters from the base model parameters.

abstract split_params()[source]#

split the model parameters

abstract split_params_dict(params_dict: Dict) Dict[source]#

Splits the model parameters from a dictionary into separate state components.

abstract to_dtype(dtype) SELF[source]#

Converts model parameters to the given dtype.

abstract to_state() Any[source]#

Converts the current model to an EasyDeLState.

abstract to_torch() Any[source]#

Converts the current model to a Hugging Face torch model.

abstract property transform_fn: Callable#

generate transform function for converting torch to easydel module.

abstract unwrap_lora_to_layers(verbose: bool = False)[source]#

Unwrap LoRA (Low-Rank Adaptation) from specified linear layers within a model.

easydel.infra.mixins.protocol.get_module_repr(module: Module) str[source]#

Get a string representation of module parameters.

easydel.infra.mixins.protocol.prettify_nnx(module: Module, indent: str = '', depth: int = 0, max_depth: int = None, module_param=None) str[source]#

Recursively formats the structure of a Flax NNX module, mimicking PyTorch’s module printing.

easydel.infra.mixins.protocol.printify_nnx(model)[source]#
easydel.infra.mixins.protocol.return_type_adjuster(original_return_type: Type[_T]) Callable[[Callable[[...], Module]], Callable[[...], _T]][source]#