easydel.infra.mixins.protocol
- class easydel.infra.mixins.protocol.BaseModuleProtocol
Bases: object
Protocol defining the common interface for EasyDeL modules.
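Every member below is declared abstract, so a concrete module can only be instantiated once it implements all of them. Below is a minimal sketch of that contract using Python's abc module. MiniModuleProtocol, IncompleteModule and ToyModule are hypothetical stand-ins with only two members. They are not EasyDeL classes, and the sketch makes no claim about how BaseModuleProtocol itself is declared.

```python
from abc import ABC, abstractmethod


class MiniModuleProtocol(ABC):
    """Hypothetical, trimmed-down stand-in for BaseModuleProtocol."""

    base_model_prefix: str

    @abstractmethod
    def split_params_dict(self) -> dict:
        """Return the parameters as a plain dict."""

    @property
    @abstractmethod
    def params_sharding(self) -> dict:
        """Return the sharding of each parameter."""


class IncompleteModule(MiniModuleProtocol):
    # Implements only one of the two abstract members.
    def split_params_dict(self) -> dict:
        return {}


class ToyModule(MiniModuleProtocol):
    base_model_prefix = "model"

    def split_params_dict(self) -> dict:
        return {"dense": {"kernel": [[1.0, 2.0]]}}

    @property
    def params_sharding(self) -> dict:
        # In this toy, None marks a fully replicated parameter.
        return {"dense": {"kernel": None}}


try:
    IncompleteModule()  # abstract property still missing
except TypeError as err:
    print("cannot instantiate:", err)

print(ToyModule().split_params_dict())
```

Python refuses to build IncompleteModule because params_sharding is still abstract. That is the check which keeps a half-implemented module from being used.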
- abstract apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = True, rngs: Optional[Rngs] = None) → SELF
Applies LoRA (Low-Rank Adaptation) to specified linear layers within a model.
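As background, LoRA freezes a layer's weight W and trains a low-rank update. The layer then computes W x + (alpha / r) B A x, where A has shape (r, d_in) and B has shape (d_out, r). B starts at zero, so the adapted layer initially matches the base layer. The pure-Python sketch below shows that forward rule for a single layer, with r playing the role of lora_rank. LoRALinear, the alpha scaling and the init scale are illustrative assumptions, not EasyDeL's implementation.

```python
import random


def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


class LoRALinear:
    """Hypothetical LoRA-wrapped linear layer: y = W x + (alpha / r) * B (A x)."""

    def __init__(self, weight, rank, alpha=1.0, seed=0):
        d_out, d_in = len(weight), len(weight[0])
        rng = random.Random(seed)
        self.weight = weight  # frozen base weight, shape (d_out, d_in)
        # A: (rank, d_in) with small random init; B: (d_out, rank) zero init,
        # so the adapted layer starts out identical to the base layer.
        self.lora_a = [[rng.gauss(0.0, 0.02) for _ in range(d_in)] for _ in range(rank)]
        self.lora_b = [[0.0] * rank for _ in range(d_out)]
        self.scale = alpha / rank

    def __call__(self, x):
        base = matvec(self.weight, x)
        delta = matvec(self.lora_b, matvec(self.lora_a, x))
        return [b + self.scale * d for b, d in zip(base, delta)]


def lora_param_count(d_in, d_out, rank):
    """(trainable LoRA parameters, parameters in the full weight matrix)."""
    return rank * (d_in + d_out), d_in * d_out


layer = LoRALinear([[1.0, 0.0], [0.0, 2.0]], rank=1)
print(layer([3.0, 4.0]))  # B is zero at init, so this equals W x: [3.0, 8.0]
print(lora_param_count(4096, 4096, rank=8))  # (65536, 16777216)
```

The parameter count shows why LoRA is used. For a 4096 x 4096 projection at rank 8, the adapter trains about 0.4% of the weights in the full matrix.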
- base_model_prefix: str
- compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, past_key_values: tp.Optional[TransformerCache] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, return_dict: bool = True, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) → tp.Tuple[FlaxCausalLMOutput, LossMetrics]
- compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, return_dict: bool = True, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) → tp.Tuple[FlaxSequenceClassifierOutput, LossMetrics]
- compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, output_router_logits: tp.Optional[bool] = None, past_key_values: tp.Optional[TransformerCache] = None, return_dict: bool = True, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) → tp.Tuple[MoeModelOutput, LossMetrics]
- compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, output_router_logits: tp.Optional[bool] = None, past_key_values: tp.Optional[TransformerCache] = None, return_dict: bool = True, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) → tp.Tuple[MoeCausalLMOutput, LossMetrics]
- compute_loss(*, labels: tp.Optional[chex.Array] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None, **batch) → tp.Tuple[tp.Any, LossMetrics]
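Runs the model on a batch and computes the loss against labels. The overloads differ only in their task-specific output type: FlaxCausalLMOutput, FlaxSequenceClassifierOutput, MoeModelOutput or MoeCausalLMOutput. Every form returns a tuple of that output and a LossMetrics object. The final keyword-only form forwards any other batch fields through **batch.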
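For the causal-LM overload, a common loss is next-token cross-entropy. The labels are shifted by one position and masked positions are skipped. The sketch below assumes that convention and uses -100 as the ignored label id, which is the PyTorch CrossEntropyLoss default. It is not taken from EasyDeL's LossConfig handling, and causal_lm_loss is a hypothetical helper.

```python
import math

IGNORE_INDEX = -100  # assumed convention for masked label positions


def causal_lm_loss(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean next-token cross-entropy for one sequence.

    logits: per-position score lists, shape (seq_len, vocab)
    labels: token ids, shape (seq_len,)
    Position t predicts labels[t + 1], so the last logit row and the first
    label are dropped (the usual causal-LM shift).
    """
    total, count = 0.0, 0
    for scores, target in zip(logits[:-1], labels[1:]):
        if target == ignore_index:
            continue
        # Log-sum-exp with max subtraction for numerical stability.
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[target]  # -log softmax(scores)[target]
        count += 1
    return total / max(count, 1)


logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
labels = [1, 0, IGNORE_INDEX]
# Uniform scores over a 2-token vocabulary give log(2) for the one counted position.
print(causal_lm_loss(logits, labels))
```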
- config: EasyDeLBaseConfig
- config_class: Type[EasyDeLBaseConfig]
- abstract gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None)
Gathers the model’s parameters based on the specified partitioning rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules for gathering.
mesh (jax.sharding.Mesh, optional) – The mesh to gather from.
- Returns
The gathered model.
- Return type
nn.Module
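Conceptually, gathering undoes sharding. The pieces of each parameter held by different devices of the mesh are reassembled into whole arrays. The toy below models a mesh as a plain count of devices and splits a matrix by rows. shard_rows and gather_rows are hypothetical helpers. Real JAX sharding is driven by the Mesh and partition specs, not by Python lists.

```python
def shard_rows(matrix, num_devices):
    """Split a matrix row-wise into equal contiguous shards, one per device."""
    if len(matrix) % num_devices:
        raise ValueError("rows must divide evenly across devices")
    step = len(matrix) // num_devices
    return [matrix[i * step:(i + 1) * step] for i in range(num_devices)]


def gather_rows(shards):
    """Reassemble row shards into the full matrix (the 'gather' step)."""
    return [row for shard in shards for row in shard]


full = [[float(r * 10 + c) for c in range(3)] for r in range(4)]
shards = shard_rows(full, num_devices=2)
print(len(shards), len(shards[0]))  # 2 shards of 2 rows each
assert gather_rows(shards) == full  # gathering recovers the original
```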
- abstract property graphdef: Union[NodeDef[Node], NodeRef[Node]]
- abstract property graphother: State[Key, VariableState[Any]]
- abstract property graphstate: State[Key, VariableState[Any]]
- abstract classmethod lazy_init(*args, **kwargs) → SELF
Initializes the class lazily under nnx.eval_shape. Parameters are created as abstract shape/dtype placeholders rather than allocated arrays.
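Lazy initialization means the model's parameters exist only as shape and dtype descriptions until real weights are loaded. This lets you size a model before committing memory. The stdlib sketch below imitates that idea. AbstractArray is a hypothetical placeholder, loosely analogous to jax.ShapeDtypeStruct. The MLP sizes are arbitrary, and the code does not call Flax.

```python
from dataclasses import dataclass
from math import prod

DTYPE_BYTES = {"float32": 4, "bfloat16": 2}


@dataclass(frozen=True)
class AbstractArray:
    """Shape/dtype placeholder: describes an array without allocating it."""
    shape: tuple
    dtype: str

    @property
    def nbytes(self) -> int:
        return prod(self.shape) * DTYPE_BYTES[self.dtype]


def abstract_mlp(d_model, d_ff, dtype="bfloat16"):
    """Describe a two-matrix MLP's parameters without materializing them."""
    return {
        "up": {"kernel": AbstractArray((d_model, d_ff), dtype)},
        "down": {"kernel": AbstractArray((d_ff, d_model), dtype)},
    }


params = abstract_mlp(4096, 14336)
total = sum(leaf["kernel"].nbytes for leaf in params.values())
print(f"{total / 2**20:.0f} MiB would be needed")  # 224 MiB would be needed
```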
- abstract merge_lora_params(pytree: Dict) → SELF
Merges the given pytree of LoRA parameters into the current LoRA module.
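Merging a parameter pytree into a module means walking two nested structures side by side and overriding the matching leaves. The sketch below does this for plain nested dicts. merge_pytree and the lora_a/lora_b key names are hypothetical. Rejecting unknown keys is a design choice of this sketch, not documented EasyDeL behavior.

```python
def merge_pytree(base, update):
    """Return a copy of nested dict `base` with leaves overridden by `update`.

    Keys present only in `update` raise KeyError, so a typo in a LoRA
    parameter path fails loudly instead of being silently added.
    """
    merged = dict(base)  # shallow copy; `base` itself is never mutated
    for key, value in update.items():
        if key not in base:
            raise KeyError(f"unknown parameter path component: {key!r}")
        if isinstance(value, dict) and isinstance(base[key], dict):
            merged[key] = merge_pytree(base[key], value)  # recurse into subtree
        else:
            merged[key] = value  # leaf: take the updated value
    return merged


module_params = {
    "q_proj": {"kernel": "W_q", "lora_a": "A0", "lora_b": "B0"},
    "v_proj": {"kernel": "W_v", "lora_a": "A0", "lora_b": "B0"},
}
trained_lora = {"q_proj": {"lora_a": "A1", "lora_b": "B1"}}
print(merge_pytree(module_params, trained_lora)["q_proj"])
```

Only q_proj's adapter leaves change. The frozen kernel and the untouched v_proj subtree are carried over unchanged.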
- abstract merge_params_dict(params_dict: Dict)
Merges the model parameters from a dictionary into the current model.
- Parameters
params_dict (tp.Dict) – A dictionary containing the parameters to merge.
- Returns
The model with merged parameters.
- Return type
nn.Module
- abstract property params_sharding: Dict
Returns the sharding of the model parameters.
- abstract property pure_transform_fn: Callable
Generates a pure transform function for converting PyTorch weights into an EasyDeL module.
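One step such a converter typically performs comes from a layout difference. PyTorch stores nn.Linear weights as (out_features, in_features), while Flax linear layers keep a kernel of shape (in_features, out_features). The pure function below renames and transposes such entries and leaves its input untouched. torch_linear_to_flax is a hypothetical helper. Treating every 2-D "weight" as a Linear weight is a toy rule, because real converters must special-case embeddings, norms and fused layers.

```python
def transpose(matrix):
    """Transpose a list-of-rows matrix."""
    return [list(col) for col in zip(*matrix)]


def torch_linear_to_flax(state_dict):
    """Pure function: map PyTorch Linear entries to Flax-style kernels.

    Returns a new dict keyed by path tuples; `state_dict` is not modified.
    """
    converted = {}
    for name, value in state_dict.items():
        path = tuple(name.split("."))
        # Toy rule: any 2-D "weight" is treated as an nn.Linear weight.
        if path[-1] == "weight" and isinstance(value[0], list):
            converted[path[:-1] + ("kernel",)] = transpose(value)
        else:
            converted[path] = value  # e.g. biases pass through unchanged
    return converted


torch_sd = {
    "mlp.up.weight": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],  # (out=2, in=3)
    "mlp.up.bias": [0.1, 0.2],
}
print(torch_linear_to_flax(torch_sd)[("mlp", "up", "kernel")])
# [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]  -> shape (in=3, out=2)
```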
- abstract quantize(method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: Optional[str] = None) → SELF
Quantizes the model’s linear layers.
- Parameters
method (EasyDeLQuantizationMethods, optional) – The quantization method to use.
block_size (int, optional) – The block size for quantization.
quantization_pattern (str, optional) – The quantization pattern to use.
- Returns
The quantized model.
- Return type
nn.Module
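block_size controls how many consecutive weights share one scale factor. Smaller blocks track local magnitudes more closely at the cost of storing more scales. The sketch below shows one common blockwise scheme: symmetric 8-bit absmax quantization. It is only an illustration, because EasyDeL's A8BIT method may use a different layout or rounding. quantize_blockwise and dequantize_blockwise are hypothetical helpers.

```python
def quantize_blockwise(values, block_size=128):
    """Symmetric 8-bit absmax quantization with one scale per block."""
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block)
        scale = absmax / 127 if absmax else 1.0  # avoid dividing by zero
        scales.append(scale)
        # Map each value to an integer in [-127, 127].
        codes.extend(max(-127, min(127, round(v / scale))) for v in block)
    return codes, scales


def dequantize_blockwise(codes, scales, block_size=128):
    """Recover approximate floats: code * scale of the code's block."""
    return [c * scales[i // block_size] for i, c in enumerate(codes)]


weights = [0.5, -1.0, 0.25, 0.0, 8.0, -4.0, 2.0, 1.0]
codes, scales = quantize_blockwise(weights, block_size=4)
restored = dequantize_blockwise(codes, scales, block_size=4)
print(codes)
print(max(abs(a - b) for a, b in zip(weights, restored)))
```

The large value 8.0 only inflates the scale of the second block. The first block keeps a fine step of 1/127, which is the motivation for per-block rather than per-tensor scales.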
- abstract shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None)
Shards the model’s parameters using the specified partitioning rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules for sharding.
mesh (jax.sharding.Mesh, optional) – The mesh to shard across.
- Returns
The sharded model.
- Return type
nn.Module
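Partition rules decide, per parameter, how each array dimension maps onto mesh axes. The exact form EasyDeL accepts is given by the PartitionLike type above. This sketch assumes ordered (regex, spec) pairs matched against slash-joined parameter paths, with first-match-wins semantics and a catch-all fallback. PARTITION_RULES, the axis names "tp"/"fsdp" and match_partition_spec are hypothetical. None stands in for "fully replicated".

```python
import re

# Hypothetical rules: (regex over the parameter path, partition spec).
# A spec names one mesh axis (or None) per array dimension.
PARTITION_RULES = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"o_proj/kernel", ("tp", "fsdp")),
    (r".*", None),  # fallback: replicate everything else
)


def match_partition_spec(path, rules=PARTITION_RULES):
    """Return the spec of the first rule whose regex matches `path`."""
    for pattern, spec in rules:
        if re.search(pattern, path):
            return spec
    raise ValueError(f"no partition rule matches {path!r}")


for path in ("model/layers/0/self_attn/q_proj/kernel",
             "model/layers/0/self_attn/o_proj/kernel",
             "model/norm/scale"):
    print(path, "->", match_partition_spec(path))
```

Because the first match wins, specific rules must come before broad ones, and the trailing ".*" rule guarantees every parameter receives some spec.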
- abstract split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) → Dict
Splits the model parameters and returns them as a plain dictionary, with the VariableState wrappers removed from the tree.
- Parameters
extract_fn (tp.Optional[tp.Callable], optional) – Function to extract values from the parameters.
remove_none (bool, optional) – Whether to remove None values from the dictionary.
- Returns
The dictionary of split parameters.
- Return type
tp.Dict
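The two parameters map onto a simple tree walk. extract_fn pulls the raw value out of each wrapped leaf, and remove_none drops leaves whose value is None. Wrapped is a hypothetical stand-in for a VariableState-like container, and split_params is an illustrative helper, not EasyDeL's code. Dropping subtrees that become empty is an extra choice made here.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class Wrapped:
    """Hypothetical stand-in for a VariableState-like container."""
    value: Any


def split_params(tree, extract_fn: Optional[Callable] = None, remove_none=True):
    """Unwrap every leaf of a nested dict and optionally drop None leaves."""
    extract_fn = extract_fn or (lambda leaf: leaf.value)  # default: unwrap .value
    out = {}
    for key, node in tree.items():
        if isinstance(node, dict):
            sub = split_params(node, extract_fn, remove_none)
            if sub or not remove_none:  # skip subtrees left empty
                out[key] = sub
        else:
            value = extract_fn(node)
            if value is not None or not remove_none:
                out[key] = value
    return out


tree = {"dense": {"kernel": Wrapped([[1.0]]), "bias": Wrapped(None)}}
print(split_params(tree))                     # {'dense': {'kernel': [[1.0]]}}
print(split_params(tree, remove_none=False))  # also keeps 'bias': None
```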
- abstract property transform_fn: Callable
Generates a transform function for converting PyTorch weights into an EasyDeL module.
- easydel.infra.mixins.protocol.get_module_repr(module: Module) → str
Returns a string representation of the module's parameters.
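The output format of get_module_repr is not specified here. The sketch below shows one plausible kind of parameter summary: one line per parameter with its shape and element count, plus a total. module_param_repr is a hypothetical helper that takes shapes directly instead of a real module.

```python
from math import prod


def module_param_repr(name, params):
    """Summarize a module's parameters as 'name: shape (count params)' lines.

    `params` maps parameter names to shapes; a real module would carry arrays.
    """
    lines = [f"{name}("]
    total = 0
    for pname, shape in params.items():
        count = prod(shape)
        total += count
        lines.append(f"  {pname}: {tuple(shape)} ({count:,} params)")
    lines.append(f") total={total:,}")
    return "\n".join(lines)


print(module_param_repr("Linear", {"kernel": (1024, 4096), "bias": (4096,)}))
# Linear(
#   kernel: (1024, 4096) (4,194,304 params)
#   bias: (4096,) (4,096 params)
# ) total=4,198,400
```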