easydel.infra.mixins.protocol
- class easydel.infra.mixins.protocol.BaseModuleProtocol
Bases: object
Protocol defining the common interface for EasyDeL modules.
- abstract apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) → SELF
Apply LoRA (Low-Rank Adaptation) to specified linear layers within a model.
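As a minimal sketch of the selection step, assuming `lora_pattern` is a regular expression matched against layer paths (an assumption; the signature only shows an optional string), the layers chosen and the shapes of the rank-`lora_rank` factors might be computed as follows. The layer names and the helpers `select_layers` and `lora_shapes` are hypothetical and not part of EasyDeL.

```python
import re
from typing import Dict, List, Optional, Tuple

# Hypothetical layer paths mapped to (in_features, out_features).
LAYERS: Dict[str, Tuple[int, int]] = {
    "model/layers/0/attn/q_proj": (64, 64),
    "model/layers/0/attn/v_proj": (64, 64),
    "model/layers/0/mlp/up_proj": (64, 256),
    "lm_head": (64, 1000),
}


def select_layers(names: List[str], lora_pattern: Optional[str] = None) -> List[str]:
    """Return the names matching the pattern; assume None selects every layer."""
    if lora_pattern is None:
        return list(names)
    regex = re.compile(lora_pattern)
    return [name for name in names if regex.search(name)]


def lora_shapes(in_features: int, out_features: int, lora_rank: int):
    """Shapes of the low-rank factors A and B, where the update is A @ B."""
    return (in_features, lora_rank), (lora_rank, out_features)


# Adapt only the query and value projections, with rank 8.
selected = select_layers(list(LAYERS), r"(q_proj|v_proj)$")
shapes = {name: lora_shapes(*LAYERS[name], lora_rank=8) for name in selected}
```

With rank 8, each adapted 64x64 layer gains 2 * 64 * 8 trainable values instead of 64 * 64, which is the point of the low-rank factorization.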
- base_model_prefix: str
- compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None, cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) → tp.Tuple[CausalLMOutput, LossMetrics]
- compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) → tp.Tuple[SequenceClassifierOutput, LossMetrics]
- compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, output_router_logits: tp.Optional[bool] = None, past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None, cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) → tp.Tuple[MoeModelOutput, LossMetrics]
- compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, output_router_logits: tp.Optional[bool] = None, past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None, cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) → tp.Tuple[MoeCausalLMOutput, LossMetrics]
- compute_loss(*, labels: tp.Optional[chex.Array] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None, **batch) → tp.Tuple[tp.Any, LossMetrics]
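These signatures are typing overloads. The protocol itself supplies no runtime implementation of compute_loss, so calling it on the bare protocol raises; concrete modules are expected to provide one. Each overload returns a tuple of the model output and a LossMetrics instance.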
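The `compute_loss` signatures listed here are `typing.overload` stubs. The following sketch shows how such stubs behave at runtime in plain Python: a class that declares only overloads raises `NotImplementedError` when the method is called, while a subclass that supplies a real method works normally. `ToyProtocol`, `ToyModel`, and this simplified `LossMetrics` are hypothetical stand-ins, and the toy loss is not EasyDeL's.

```python
from dataclasses import dataclass
from typing import Any, List, Tuple, overload


@dataclass
class LossMetrics:
    """Hypothetical stand-in for EasyDeL's LossMetrics."""

    loss: float


class ToyProtocol:
    """Declares overloads only, with no runtime implementation."""

    @overload
    def compute_loss(self, input_ids: List[int], labels: List[int]) -> Tuple[List[int], LossMetrics]: ...

    @overload
    def compute_loss(self, *, labels: List[int], **batch: Any) -> Tuple[Any, LossMetrics]: ...


class ToyModel(ToyProtocol):
    def compute_loss(self, input_ids=None, labels=None, **batch):
        # Toy "model": the prediction is the input itself, and the loss is
        # the fraction of positions where it disagrees with the labels.
        if input_ids is None or labels is None:
            raise ValueError("input_ids and labels are required in this sketch")
        mismatches = sum(p != t for p, t in zip(input_ids, labels))
        return list(input_ids), LossMetrics(loss=mismatches / len(labels))


# Calling the overload-only method raises NotImplementedError.
try:
    ToyProtocol().compute_loss([1, 2], [1, 2])
    stub_raised = False
except NotImplementedError:
    stub_raised = True

# A concrete implementation returns the (outputs, metrics) pair.
outputs, metrics = ToyModel().compute_loss(input_ids=[1, 2, 3, 4], labels=[1, 2, 0, 4])
```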
- config: EasyDeLBaseConfig
- config_class: Type[EasyDeLBaseConfig]
- abstract gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None)
Gathers the model’s parameters based on the specified partitioning rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules for gathering.
mesh (jax.sharding.Mesh, optional) – The mesh to gather from.
- Returns
The gathered model.
- Return type
nn.Module
- abstract property graphdef: Union[NodeDef[Node], NodeRef[Node]]
- abstract property graphother: State[Key, VariableState[Any]]
- abstract property graphstate: State[Key, VariableState[Any]]
- abstract classmethod lazy_init(*args, **kwargs) → SELF
Lazily initialize the module with nnx.eval_shape.
- abstract merge_lora_params(pytree: Dict) → SELF
Merge LoRA (Low-Rank Adaptation) parameters into the base model parameters.
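For intuition, merging in the standard LoRA formulation adds the low-rank product to the frozen weight, W' = W + scale * (A @ B). The sketch below does this with plain Python lists. It illustrates the arithmetic only: `merge_lora`, the `scale` argument, and the tiny matrices are assumptions for illustration, and this page does not say whether EasyDeL applies a scaling factor.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product of a (m x k) and b (k x n)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight: Matrix, lora_a: Matrix, lora_b: Matrix, scale: float = 1.0) -> Matrix:
    """Return weight + scale * (lora_a @ lora_b), leaving the inputs untouched."""
    delta = matmul(lora_a, lora_b)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


W = [[1.0, 0.0], [0.0, 1.0]]  # base weight, shape (2, 2)
A = [[1.0], [2.0]]            # LoRA factor, shape (2, rank=1)
B = [[0.5, -1.0]]             # LoRA factor, shape (rank=1, 2)

merged = merge_lora(W, A, B)
```

After merging, the adapted layer needs no extra matmul at inference time, since the update is folded into a single weight matrix.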
- abstract merge_params_dict(params_dict: Dict)
Merges the model parameters from a dictionary into the current model.
- abstract property params_sharding: Dict
Return the sharding of the model parameters.
- abstract property pure_transform_fn: Callable
Generate a pure transform function for converting a torch module to an EasyDeL module.
- abstract quantize(method: EasyDeLQuantizationMethods, skip_modules: list[str] | None = None, verbose: bool = False, **kwargs)
Quantizes the model’s linear layers.
- Parameters
method (EasyDeLQuantizationMethods) – The quantization method to use.
skip_modules (list[str] | None, optional) – List of module names to skip.
verbose (bool, optional) – Whether to print verbose output.
**kwargs – Additional keyword arguments.
- Returns
The quantized model.
- Return type
nn.Module
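To make the role of `skip_modules` concrete, here is a minimal sketch of symmetric per-tensor int8 quantization over a dictionary of weights. It is not EasyDeL's quantization method. The helper names, the exact-name matching for skipped modules, and the sample weights are all assumptions for illustration.

```python
from typing import Dict, List, Optional, Tuple, Union

Quantized = Tuple[List[int], float]  # (int8 codes, scale)


def quantize_int8(values: List[float]) -> Quantized:
    """Map floats to integers in [-127, 127] using a single scale."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale


def dequantize(codes: List[int], scale: float) -> List[float]:
    """Recover approximate floats from the integer codes."""
    return [c * scale for c in codes]


def quantize_modules(
    weights: Dict[str, List[float]],
    skip_modules: Optional[List[str]] = None,
) -> Dict[str, Union[List[float], Quantized]]:
    """Quantize every entry except those named in skip_modules."""
    skip = set(skip_modules or [])
    return {
        name: values if name in skip else quantize_int8(values)
        for name, values in weights.items()
    }


# Hypothetical flat weights; the output head is left in full precision.
weights = {"q_proj": [0.5, -1.27, 0.25], "lm_head": [0.1, 0.2]}
result = quantize_modules(weights, skip_modules=["lm_head"])
codes, scale = result["q_proj"]
restored = dequantize(codes, scale)
```

Output layers are a common candidate for skipping because quantization error there directly perturbs the logits.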
- abstract shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None)
Shards the model’s parameters using the specified partitioning rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules for sharding.
mesh (jax.sharding.Mesh, optional) – The mesh to shard across.
- Returns
The sharded model.
- Return type
nn.Module
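One way to picture partition rules is as an ordered list of (pattern, spec) pairs where the first pattern matching a parameter path decides how that parameter is split. The sketch below assumes that shape; the exact form of `PartitionLike` is not shown on this page, and the rules, axis names (`"fsdp"`, `"tp"`), and parameter paths are hypothetical. Specs are written as plain tuples; in JAX they would typically be `jax.sharding.PartitionSpec` objects applied over a `jax.sharding.Mesh`.

```python
import re
from typing import Dict, Optional, Sequence, Tuple

# One mesh-axis name per array dimension; None means replicated on that dimension.
Spec = Tuple[Optional[str], ...]

# Hypothetical rules: (regex over the parameter path, spec). First match wins.
RULES: Sequence[Tuple[str, Spec]] = (
    (r"embed_tokens/embedding$", ("tp", "fsdp")),
    (r"attn/(q|k|v)_proj/kernel$", ("fsdp", "tp")),
    (r"attn/o_proj/kernel$", ("tp", "fsdp")),
    (r".*", (None,)),  # fallback: replicate everything else
)


def match_spec(path: str, rules: Sequence[Tuple[str, Spec]]) -> Spec:
    """Return the spec of the first rule whose pattern matches the path."""
    for pattern, spec in rules:
        if re.search(pattern, path):
            return spec
    raise ValueError(f"no partition rule matches {path!r}")


PARAM_PATHS = [
    "model/embed_tokens/embedding",
    "model/layers/0/attn/q_proj/kernel",
    "model/layers/0/attn/o_proj/kernel",
    "model/norm/scale",
]
specs: Dict[str, Spec] = {path: match_spec(path, RULES) for path in PARAM_PATHS}
```

Ordering the rules from most to least specific, with a catch-all at the end, keeps every parameter covered without listing each one.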
- abstract split_lora_params() → dict
Split LoRA (Low-Rank Adaptation) parameters from the base model parameters.
- abstract split_params_dict(params_dict: Dict) → Dict
Splits the model parameters from a dictionary into separate state components.
- abstract property transform_fn: Callable
Generate a transform function for converting a torch module to an EasyDeL module.
- easydel.infra.mixins.protocol.get_module_repr(module: Module) → str
Get a string representation of module parameters.