easydel.infra.mixins.__init__
- class easydel.infra.mixins.__init__.BaseModuleProtocol
Bases: object
Protocol defining the common interface for EasyDeL modules.
- abstract apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = True, rngs: Optional[Rngs] = None) → SELF
Applies LoRA (Low-Rank Adaptation) to specified linear layers within a model.
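As a rough illustration of what LoRA does to a matched layer, the sketch below keeps a dense kernel W frozen and adds a trainable low-rank update A·B of rank lora_rank. It treats lora_pattern as a regular expression over layer names; that interpretation, the alpha / rank scaling (taken from the original LoRA paper, since this method takes no alpha argument), and the helpers `select_layers`, `init_lora`, and `lora_dense` are all hypothetical and not EasyDeL's implementation.

```python
import re
from typing import List, Optional, Sequence, Tuple

import numpy as np


def select_layers(layer_names: Sequence[str], lora_pattern: Optional[str] = None) -> List[str]:
    """Hypothetical helper: keep layer names matched by a regex (all layers if None)."""
    if lora_pattern is None:
        return list(layer_names)
    regex = re.compile(lora_pattern)
    return [name for name in layer_names if regex.search(name)]


def init_lora(kernel: np.ndarray, lora_rank: int, rng: np.random.Generator) -> Tuple[np.ndarray, np.ndarray]:
    """Create LoRA factors for a Flax-style kernel of shape (in_features, out_features).

    A gets small random values and B starts at zero, so the adapted layer
    initially computes exactly the same output as the frozen base layer.
    """
    in_features, out_features = kernel.shape
    lora_a = rng.normal(scale=0.01, size=(in_features, lora_rank))
    lora_b = np.zeros((lora_rank, out_features))
    return lora_a, lora_b


def lora_dense(x, kernel, lora_a, lora_b, alpha: float, lora_rank: int) -> np.ndarray:
    """y = x @ W + (alpha / r) * (x @ A) @ B, where W is frozen and only A, B are trained."""
    return x @ kernel + (alpha / lora_rank) * ((x @ lora_a) @ lora_b)


layers = ["model.layers.0.self_attn.q_proj", "model.layers.0.mlp.up_proj"]
print(select_layers(layers, r"q_proj"))  # ['model.layers.0.self_attn.q_proj']
```

Zero-initializing B is the usual design choice: training starts from the pretrained behavior, and only the small A and B factors receive gradients.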
- base_model_prefix: str
- compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, past_key_values: tp.Optional[TransformerCache] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, return_dict: bool = True, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) → tp.Tuple[FlaxCausalLMOutput, LossMetrics]
- compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, return_dict: bool = True, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) → tp.Tuple[FlaxSequenceClassifierOutput, LossMetrics]
- compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, output_router_logits: tp.Optional[bool] = None, past_key_values: tp.Optional[TransformerCache] = None, return_dict: bool = True, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) → tp.Tuple[MoeModelOutput, LossMetrics]
- compute_loss(input_ids: tp.Optional[chex.Array] = None, labels: tp.Optional[chex.Array] = None, inputs_embeds: tp.Optional[chex.Array] = None, attention_mask: tp.Optional[chex.Array] = None, position_ids: tp.Optional[chex.Array] = None, segment_ids: tp.Optional[chex.Array] = None, output_attentions: tp.Optional[bool] = None, output_hidden_states: tp.Optional[bool] = None, output_router_logits: tp.Optional[bool] = None, past_key_values: tp.Optional[TransformerCache] = None, return_dict: bool = True, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None) → tp.Tuple[MoeCausalLMOutput, LossMetrics]
- compute_loss(*, labels: tp.Optional[chex.Array] = None, loss_config: tp.Optional[LossConfig] = None, loss_kwargs: tp.Optional[tp.Dict] = None, **batch) → tp.Tuple[tp.Any, LossMetrics]
Computes the model outputs and the loss for a batch. The overloads cover causal language modeling, sequence classification, and mixture-of-experts models; each returns a tuple of the model output and a LossMetrics object.
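For the causal language-modeling overload, a common way to compute the loss is next-token cross-entropy: the logits at position t are scored against the label at position t + 1, and positions labeled with an ignore index are skipped. The NumPy sketch below shows that computation; the label shift, the ignore index of -100, the mean reduction, and the name `causal_lm_loss` are conventional choices assumed here, and LossConfig may configure EasyDeL's actual behavior differently.

```python
import numpy as np

IGNORE_INDEX = -100  # assumed convention for padded or masked label positions


def causal_lm_loss(logits: np.ndarray, labels: np.ndarray, ignore_index: int = IGNORE_INDEX) -> float:
    """Mean next-token cross-entropy.

    logits: (batch, seq_len, vocab); labels: (batch, seq_len) integer token ids.
    """
    # Predict token t+1 from position t: drop the last logit and the first label.
    shifted_logits = logits[:, :-1, :]
    shifted_labels = labels[:, 1:]
    # Numerically stable log-softmax over the vocabulary axis.
    max_logits = shifted_logits.max(axis=-1, keepdims=True)
    log_probs = shifted_logits - max_logits - np.log(
        np.exp(shifted_logits - max_logits).sum(axis=-1, keepdims=True)
    )
    valid = shifted_labels != ignore_index
    # Replace ignored labels with 0 so indexing stays in range; they are masked out below.
    safe_labels = np.where(valid, shifted_labels, 0)
    token_log_probs = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return -(token_log_probs * valid).sum() / max(valid.sum(), 1)
```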
- config: EasyDeLBaseConfig
- config_class: Type[EasyDeLBaseConfig]
- abstract gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None)
Gathers the model’s parameters based on the specified partitioning rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules for gathering.
mesh (jax.sharding.Mesh, optional) – The mesh to gather from.
- Returns
The gathered model.
- Return type
nn.Module
- abstract property graphdef: Union[NodeDef[Node], NodeRef[Node]]
- abstract property graphother: State[Key, VariableState[Any]]
- abstract property graphstate: State[Key, VariableState[Any]]
- abstract classmethod lazy_init(*args, **kwargs) → SELF
Initializes the model lazily with nnx.eval_shape, so that parameter shapes and dtypes are computed without allocating the parameter arrays.
- abstract merge_lora_params(pytree: Dict) → SELF
Merges the given pytree of LoRA parameters into the current LoRA module.
- abstract merge_params_dict(params_dict: Dict)
Merges the model parameters from a dictionary into the current model.
- Parameters
params_dict (tp.Dict) – A dictionary containing the parameters to merge.
- Returns
The model with merged parameters.
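Conceptually, merging a parameter dictionary walks a nested mapping and overwrites the matching leaves in the model's current parameter tree. The pure-Python sketch below shows that recursive merge on plain dicts; the helper name `merge_nested` is hypothetical, and the actual method presumably operates on the model's NNX state rather than on raw dictionaries.

```python
from typing import Any, Dict


def merge_nested(current: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `current` with the leaves from `updates` written over it."""
    merged = dict(current)  # shallow copy; nested dicts are rebuilt during recursion
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_nested(merged[key], value)  # descend into sub-modules
        else:
            merged[key] = value  # leaf (or new subtree): take the update as-is
    return merged


current = {"layer": {"kernel": 1, "bias": 2}, "head": {"kernel": 3}}
print(merge_nested(current, {"layer": {"kernel": 10}}))
```

Keys missing from `updates` are left untouched, so a partial checkpoint only replaces the parameters it actually contains.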
- abstract property params_sharding: Dict
Returns the sharding of the model parameters.
- abstract property pure_transform_fn: Callable
Generates a pure transform function for converting PyTorch weights to an EasyDeL module.
- abstract quantize(method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: Optional[str] = None) → SELF
Quantizes the model’s linear layers.
- Parameters
method (EasyDeLQuantizationMethods, optional) – The quantization method to use. Defaults to EasyDeLQuantizationMethods.A8BIT.
block_size (int, optional) – The block size for quantization. Defaults to 128.
quantization_pattern (str, optional) – The quantization pattern to use.
- Returns
The quantized model.
- Return type
nn.Module
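To make block_size concrete: block-wise 8-bit quantization typically splits a weight tensor into contiguous blocks of block_size values and stores int8 codes plus one floating-point scale per block. The NumPy sketch below implements symmetric absmax quantization under that assumption; it illustrates the general technique, and the names `quantize_blockwise` and `dequantize_blockwise` are hypothetical, not necessarily what EasyDeLQuantizationMethods.A8BIT does internally.

```python
from typing import Tuple

import numpy as np


def quantize_blockwise(weights: np.ndarray, block_size: int = 128) -> Tuple[np.ndarray, np.ndarray, Tuple[int, ...]]:
    """Symmetric absmax int8 quantization with one float scale per block of `block_size` values."""
    flat = weights.reshape(-1)
    pad = (-flat.size) % block_size  # zero-pad so the length is a multiple of block_size
    blocks = np.pad(flat, (0, pad)).reshape(-1, block_size)
    absmax = np.abs(blocks).max(axis=1, keepdims=True)
    # All-zero blocks get scale 1.0 to avoid dividing by zero; their codes are all 0 anyway.
    scales = np.where(absmax == 0, 1.0, absmax / 127.0)
    codes = np.clip(np.round(blocks / scales), -127, 127).astype(np.int8)
    return codes, scales, weights.shape


def dequantize_blockwise(codes: np.ndarray, scales: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray:
    """Rebuild an approximation of the original weights and drop the padding."""
    flat = (codes.astype(np.float32) * scales).reshape(-1)
    return flat[: int(np.prod(shape))].reshape(shape)
```

Smaller blocks track local outliers better but store more scales; the rounding error per element is at most half of its block's scale.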
- abstract shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None)
Shards the model’s parameters using the specified partitioning rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules for sharding.
mesh (jax.sharding.Mesh, optional) – The mesh to shard across.
- Returns
The sharded model.
- Return type
nn.Module
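The partition_rules argument of from_pretrained is typed as a tuple of (str, PartitionSpec) pairs, which suggests a table mapping name patterns to sharding specs. The sketch below resolves such a table against parameter names with first-match-wins regex semantics, using plain tuples in place of jax.sharding.PartitionSpec so that it runs without JAX. The matching order, the replicated fallback, and the helper name `match_partition_rules` are assumptions, not documented EasyDeL behavior.

```python
import re
from typing import Dict, Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec: one mesh-axis name (or None) per array dimension.
Spec = Tuple[object, ...]


def match_partition_rules(
    rules: Sequence[Tuple[str, Spec]],
    param_names: Sequence[str],
    default: Spec = (),
) -> Dict[str, Spec]:
    """Assign each parameter the spec of the first rule whose regex matches its name."""
    compiled = [(re.compile(pattern), spec) for pattern, spec in rules]
    assignment: Dict[str, Spec] = {}
    for name in param_names:
        for regex, spec in compiled:
            if regex.search(name):
                assignment[name] = spec
                break
        else:
            assignment[name] = default  # () is read as "fully replicated" in this sketch
    return assignment


rules = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"(q_proj|k_proj)/kernel", ("fsdp", "tp")),
)
print(match_partition_rules(rules, ["model/embed_tokens/embedding", "model/norm/scale"]))
```

Ordering matters with first-match semantics: put specific patterns before broad ones such as a catch-all `.*`.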
- abstract split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) → Dict
Splits the model parameters and returns them as a dictionary, removing VariableState from the tree.
- Parameters
extract_fn (tp.Optional[tp.Callable], optional) – Function to extract values from the parameters.
remove_none (bool, optional) – Whether to remove None values from the dictionary.
- Returns
The dictionary of split parameters.
- Return type
tp.Dict
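The splitting step can be pictured as unwrapping a nested tree of variable wrappers into a plain nested dictionary. In the sketch below a small `Variable` dataclass stands in for NNX's VariableState (a hypothetical substitute), `extract_fn` is applied to each non-None leaf value, and None leaves are dropped when remove_none is True. The sketch mirrors the documented parameters, not the exact implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class Variable:
    """Hypothetical stand-in for an NNX VariableState wrapping one leaf value."""
    value: Any


def split_params(
    tree: Dict[str, Any],
    extract_fn: Optional[Callable[[Any], Any]] = None,
    remove_none: bool = True,
) -> Dict[str, Any]:
    """Unwrap Variable leaves into a plain nested dict, mirroring split_params_dict's options."""
    out: Dict[str, Any] = {}
    for key, node in tree.items():
        if isinstance(node, dict):
            out[key] = split_params(node, extract_fn, remove_none)  # recurse into sub-modules
            continue
        value = node.value if isinstance(node, Variable) else node
        if value is not None and extract_fn is not None:
            value = extract_fn(value)
        if value is None and remove_none:
            continue  # drop empty slots such as a disabled bias
        out[key] = value
    return out


tree = {"dense": {"kernel": Variable([1.0, 2.0]), "bias": Variable(None)}, "scale": Variable(3.0)}
print(split_params(tree))
```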
- abstract property transform_fn: Callable
Generates a transform function for converting PyTorch weights to an EasyDeL module.
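A PyTorch-to-Flax weight transform usually has to reconcile the two frameworks' layout conventions. For example, torch.nn.Linear stores its weight as (out_features, in_features), while Flax dense layers use a kernel of shape (in_features, out_features). The NumPy sketch below converts a flat state dict under those conventions. The renaming rules (2-D weight to transposed kernel, embedding weights kept as embedding, 1-D weight to scale) and the name `torch_to_flax_state` are illustrative assumptions; the functions returned by transform_fn and pure_transform_fn may differ.

```python
from typing import Dict, Tuple

import numpy as np


def torch_to_flax_state(state_dict: Dict[str, np.ndarray]) -> Dict[Tuple[str, ...], np.ndarray]:
    """Convert a flat PyTorch-style state dict into Flax-style (path tuple -> array) entries."""
    converted: Dict[Tuple[str, ...], np.ndarray] = {}
    for name, array in state_dict.items():
        *prefix, leaf = name.split(".")
        if leaf == "weight" and array.ndim == 2 and "embed" in name:
            new_leaf, new_array = "embedding", array  # embeddings keep (vocab, hidden)
        elif leaf == "weight" and array.ndim == 2:
            new_leaf, new_array = "kernel", array.T  # Linear: (out, in) -> (in, out)
        elif leaf == "weight" and array.ndim == 1:
            new_leaf, new_array = "scale", array  # normalization weights
        else:
            new_leaf, new_array = leaf, array  # e.g. biases are copied unchanged
        converted[tuple(prefix) + (new_leaf,)] = new_array
    return converted
```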
- class easydel.infra.mixins.__init__.EasyBridgeMixin
Bases: PushToHubMixin
Mixin class for adding bridging functionality, such as saving, loading, and pushing models to the Hugging Face Hub.
- base_model_prefix: Optional[str] = None
- classmethod can_generate() → bool
Checks if the model can generate sequences with .generate().
- Returns
True if the model can generate, False otherwise.
- Return type
bool
- config: EasyDeLBaseConfig
- config_class: Optional[Type[EasyDeLBaseConfig]] = None
- classmethod from_pretrained(pretrained_model_name_or_path: Optional[Union[str, PathLike]], sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: PartitionAxis = PartitionAxis(batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', key_sequence_axis='sp', hidden_state_axis='tp', attention_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, generation_query_sequence_axis=None, generation_head_axis='tp', generation_key_sequence_axis='sp', generation_attention_dim_axis=None), dtype: dtype = jax.numpy.float32, param_dtype: dtype = jax.numpy.float32, precision: Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset] = Precision.DEFAULT, config_kwargs: Optional[dict[str, Any]] = None, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = 'jax', shard_fns: Optional[dict[Callable]] = None, auto_shard_model: bool = False, verbose: bool = True, mismatch_allowed: bool = True, *model_args, config: Optional[Union[EasyDeLBaseConfig, str, PathLike]] = None, cache_dir: Optional[Union[str, PathLike]] = None, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[bool, str]] = None, revision: str = 'main', vebose: bool = True, quantization_platform: Optional[EasyDeLPlatforms] = None, quantization_method: Optional[EasyDeLQuantizationMethods] = None, quantization_block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, **kwargs)
Loads an EasyDeL model from a pretrained model or path.
- Parameters
pretrained_model_name_or_path (str, optional) – The name or path of the pretrained model.
sharding_axis_dims (Sequence[int], optional) – The dimensions of sharding axes.
sharding_axis_names (Sequence[str], optional) – The names of sharding axes.
partition_axis (PartitionAxis, optional) – The partition axis configuration.
dtype (dtype, optional) – The data type of the model.
param_dtype (dtype, optional) – The data type of the parameters.
precision (PrecisionLike, optional) – The computation precision.
config_kwargs (dict[str, Any], optional) – Additional configuration parameters.
partition_rules (tuple, optional) – Custom partitioning rules for sharding.
backend (EasyDeLBackends, optional) – The backend to use.
platform (EasyDeLPlatforms, optional) – The platform to use.
shard_fns (dict[Callable], optional) – Custom shard functions for loading checkpoint.
auto_shard_model (bool, optional) – Whether to automatically shard the model.
verbose (bool, optional) – Whether to print verbose messages. Defaults to True.
mismatch_allowed (bool, optional) – If True, allows mismatch in parameters while loading. Defaults to True.
*model_args – Additional arguments for the model.
config (EasyDeLBaseConfig, str, or PathLike, optional) – The model configuration, or a name or path from which to load it.
cache_dir (str, optional) – The cache directory for the pretrained model.
force_download (bool, optional) – Whether to force download the model.
local_files_only (bool, optional) – Whether to use only local files.
token (str, optional) – The Hugging Face Hub token.
revision (str, optional) – The revision of the model to load.
**kwargs – Additional keyword arguments.
- Returns
The loaded EasyDeL model.
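The default sharding_axis_dims=(1, -1, 1, 1) pairs positionally with sharding_axis_names=('dp', 'fsdp', 'tp', 'sp'). Reading -1 as "use all remaining devices", the way NumPy's reshape treats -1, gives the mesh shapes sketched below. This interpretation and the helper `resolve_mesh_shape` are assumptions for illustration; the real code builds a jax.sharding.Mesh rather than a plain dict.

```python
import math
from typing import Dict, Sequence


def resolve_mesh_shape(axis_dims: Sequence[int], axis_names: Sequence[str], device_count: int) -> Dict[str, int]:
    """Map each axis name to its size, filling a single -1 with the remaining devices."""
    if len(axis_dims) != len(axis_names):
        raise ValueError("axis_dims and axis_names must have the same length")
    if list(axis_dims).count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(d for d in axis_dims if d != -1)
    if device_count % known != 0:
        raise ValueError(f"{device_count} devices cannot be split over fixed axes of size {known}")
    sizes = [device_count // known if d == -1 else d for d in axis_dims]
    if math.prod(sizes) != device_count:
        raise ValueError(f"mesh {sizes} does not use exactly {device_count} devices")
    return dict(zip(axis_names, sizes))


names = ("dp", "fsdp", "tp", "sp")
print(resolve_mesh_shape((1, -1, 1, 1), names, device_count=8))  # {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1}
print(resolve_mesh_shape((1, -1, 2, 1), names, device_count=8))  # {'dp': 1, 'fsdp': 4, 'tp': 2, 'sp': 1}
```

Under this reading, the default puts every device on the fsdp axis, i.e. fully sharded data parallelism with no tensor or sequence parallelism.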
- hf_torch_auto_loader: Optional[Any] = None
- push_to_hub(repo_id: str, use_temp_dir: Optional[bool] = None, commit_message: Optional[str] = None, private: Optional[bool] = None, token: Optional[Union[bool, str]] = None, create_pr: bool = False, gather_fns: Optional[dict[Callable]] = None, float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, revision: Optional[str] = None, commit_description: Optional[str] = None) → str
Pushes the model to the Hugging Face Hub.
- Parameters
repo_id (str) – The repository ID on Hugging Face Hub.
use_temp_dir (bool, optional) – If True, uses a temporary directory. Defaults to None.
commit_message (str, optional) – The commit message for the push.
private (bool, optional) – If True, creates a private repository.
token (str or bool, optional) – The Hugging Face Hub token.
create_pr (bool, optional) – If True, creates a pull request.
gather_fns (dict[Callable], optional) – Custom gather functions for checkpoint saving.
float_dtype (dtype, optional) – Data type for saving weights.
verbose (bool, optional) – Whether to print verbose messages. Defaults to True.
mismatch_allowed (bool, optional) – If True, allows mismatch in parameters while loading. Defaults to True.
revision (str, optional) – The revision to push to.
commit_description (str, optional) – The commit description for the push.
- Returns
The URL of the created repository.
- Return type
str
- save_pretrained(save_directory: Union[str, PathLike], push_to_hub: bool = False, token: Optional[Union[bool, str]] = None, gather_fns: Optional[dict[Callable]] = None, float_dtype=None, verbose: bool = True, mismatch_allowed: bool = True, enable: Optional[bool] = None, **kwargs)
Saves the model, its configuration, and optionally pushes it to the Hugging Face Hub.
- Parameters
save_directory (str or PathLike) – The directory where to save the model.
push_to_hub (bool, optional) – If True, pushes the model to the Hugging Face Hub.
token (str or bool, optional) – The Hugging Face Hub token.
gather_fns (dict[Callable], optional) – Custom gather functions for checkpoint saving.
float_dtype (dtype, optional) – Data type for saving weights.
verbose (bool, optional) – Whether to print verbose messages. Defaults to True.
mismatch_allowed (bool, optional) – If True, allows mismatch in parameters while loading. Defaults to True.
enable (bool, optional) – If True, allows this process to write the files (used when saving models from multiple hosts).
**kwargs – Additional keyword arguments for Hugging Face Hub.
- class easydel.infra.mixins.__init__.EasyGenerationMixin
Bases: object
- base_model_prefix: str
- config: EasyDeLBaseConfig
- config_class: Type[EasyDeLBaseConfig]
- generate(input_ids: Union[Array, ndarray, bool, number], generation_config: Optional[GenerationConfig] = None, prng_key: Optional[Union[Array, ndarray, bool, number]] = None, trace: bool = True, logits_processor: Optional[FlaxLogitsProcessorList] = None, **kwargs)
Generates sequences of token ids for models with a language modeling head.
- Parameters
input_ids (chex.Array of shape (batch_size, sequence_length)) – The sequence used as a prompt for the generation.
generation_config (GenerationConfig, optional) – The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate that match attributes of generation_config will override them. If generation_config is not provided, the default is used, with the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration. Note that unspecified parameters inherit GenerationConfig’s default values, whose documentation should be checked to parameterize generation.
trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.
logits_processor (FlaxLogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from arguments and generation config. If a logits processor that was already created from the arguments or the generation config is passed, an error is thrown. This feature is intended for advanced users.
kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generation_config and/or additional model-specific kwargs that will be forwarded to the forward function of the model. If the model is an encoder-decoder model, encoder-specific kwargs should not be prefixed and decoder-specific kwargs should be prefixed with decoder_.
- Returns
ModelOutput.
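At its core, generation repeatedly runs the model on the growing sequence and appends one token per step. The sketch below shows a minimal greedy loop in which a toy next-token function stands in for the model. The names `greedy_generate` and `toy_logits_fn` and the stopping rule (stop at max_length, or once a sequence emits eos_token_id) are illustrative assumptions. The real method also handles sampling, logits processors, padding, and a key-value cache, and per the trace parameter above it is traced by default for speed.

```python
from typing import Callable, List, Sequence


def greedy_generate(
    logits_fn: Callable[[List[int]], Sequence[float]],
    prompts: List[List[int]],
    max_length: int,
    eos_token_id: int,
) -> List[List[int]]:
    """Append the highest-scoring token until max_length or EOS, separately for each sequence."""
    sequences = [list(p) for p in prompts]
    finished = [False] * len(sequences)
    while not all(finished):
        for i, seq in enumerate(sequences):
            if finished[i]:
                continue
            if len(seq) >= max_length:
                finished[i] = True
                continue
            logits = logits_fn(seq)
            next_token = max(range(len(logits)), key=logits.__getitem__)  # greedy argmax
            seq.append(next_token)
            if next_token == eos_token_id:
                finished[i] = True
    return sequences


def toy_logits_fn(seq: List[int]) -> List[float]:
    """Toy 'model' over a 5-token vocabulary that always predicts (last token + 1) % 5."""
    vocab_size = 5
    target = (seq[-1] + 1) % vocab_size
    return [1.0 if t == target else 0.0 for t in range(vocab_size)]


print(greedy_generate(toy_logits_fn, [[0], [3]], max_length=6, eos_token_id=4))  # [[0, 1, 2, 3, 4], [3, 4]]
```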
- prepare_inputs_for_generation(input_ids, max_length, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None)
Prepares the model inputs for a generation task.
- Parameters
input_ids – The input token ids.
max_length – The maximum length of the sequence to be generated.
attention_mask (tp.Optional[chex.Array]) – Mask applied to the attention weights.
token_type_ids (tp.Optional[chex.Array]) – Token type ids.
- Returns
A dictionary containing past_key_values, attention_mask, and position_ids.
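One common way to implement this step, used for example by Hugging Face's Flax causal LMs, is to preallocate an attention mask of length max_length, copy the prompt mask into its front, and derive position ids from the cumulative sum of the prompt mask so that left padding does not shift positions. The NumPy sketch below follows that pattern. Whether EasyDeL does exactly this is an assumption, the helper name `prepare_generation_inputs` is hypothetical, and cache initialization (past_key_values) is omitted.

```python
from typing import Dict, Optional

import numpy as np


def prepare_generation_inputs(
    input_ids: np.ndarray, max_length: int, attention_mask: Optional[np.ndarray] = None
) -> Dict[str, np.ndarray]:
    batch_size, seq_length = input_ids.shape
    # Slots for tokens that will be generated later are pre-marked as attendable.
    extended_attention_mask = np.ones((batch_size, max_length), dtype=np.int32)
    if attention_mask is not None:
        # Positions count only real (unmasked) tokens, so left padding does not shift them.
        position_ids = attention_mask.cumsum(axis=-1) - 1
        extended_attention_mask[:, :seq_length] = attention_mask
    else:
        position_ids = np.broadcast_to(np.arange(seq_length, dtype=np.int32)[None, :], (batch_size, seq_length))
    return {"attention_mask": extended_attention_mask, "position_ids": position_ids}


ids = np.array([[0, 0, 5, 6], [7, 8, 9, 10]])  # first row is left-padded with two pad tokens
mask = np.array([[0, 0, 1, 1], [1, 1, 1, 1]])
print(prepare_generation_inputs(ids, max_length=6, attention_mask=mask)["position_ids"])
```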