easydel.infra.mixins.bridge

Bridge mixin for EasyDeL-HuggingFace interoperability.

This module provides the EasyBridgeMixin class that enables seamless integration between EasyDeL models and the HuggingFace ecosystem. It handles model serialization, loading from various formats, conversion between frameworks, and integration with the HuggingFace Hub.

The bridge supports:

  • Loading models from the HuggingFace Hub or local paths

  • Converting between PyTorch and JAX/Flax formats

  • Saving models in EasyDeL or HuggingFace formats

  • Pushing models to the HuggingFace Hub

  • Automatic weight format detection and loading

  • Quantization during loading

  • Distributed loading with sharding
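Converting PyTorch weights to JAX/Flax involves more than renaming keys. torch.nn.Linear stores its weight as (out_features, in_features), while a flax.linen.Dense kernel is (in_features, out_features), so linear weights must be transposed. The sketch below shows that idea on plain NumPy arrays. The rules it applies (weight renamed to kernel, embeddings detected by the substring "embed") are simplified assumptions and are not EasyDeL's actual converter, which also has to handle normalization layers, convolutions, and sharded loading.

```python
import numpy as np


def torch_to_flax_state(state_dict):
    """Convert a flat PyTorch-style state dict into a nested Flax-style tree.

    Simplified, hypothetical rules (not EasyDeL's converter):
      * a ``weight`` under a module whose name contains "embed" becomes
        ``embedding`` with its layout unchanged (vocab, features);
      * any other 2-D ``weight`` is treated as a linear layer: renamed to
        ``kernel`` and transposed from torch's (out, in) to Flax's (in, out);
      * every other tensor (biases, 1-D weights, ...) is copied unchanged.
    """
    tree = {}
    for flat_key, value in state_dict.items():
        *path, leaf = flat_key.split(".")
        if leaf == "weight" and path and "embed" in path[-1]:
            leaf = "embedding"
        elif leaf == "weight" and value.ndim == 2:
            leaf, value = "kernel", value.T
        # Walk/create the nested dicts for "a.b.c"-style module paths.
        node = tree
        for part in path:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree
```

For example, an "fc.weight" array of shape (2, 3) ends up at tree["fc"]["kernel"] with shape (3, 2).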

Classes:

EasyBridgeMixin: Main mixin class providing bridge functionality

Constants:

  • FLAX_WEIGHTS_NAME: Standard name for Flax model weights

  • SAFE_WEIGHTS_NAME: Standard name for SafeTensors weights

  • CANDIDATE_FILENAMES: List of possible weight file names to search
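These constants drive automatic weight-format detection. The following is a minimal sketch of such a lookup, assuming the candidates are tried in list order and the first file that exists wins. The filenames are HuggingFace's conventional names, used here as placeholders because this reference does not list EasyDeL's actual values.

```python
from __future__ import annotations

from pathlib import Path

# Placeholder values (assumption): HuggingFace-style names standing in for
# this module's FLAX_WEIGHTS_NAME / SAFE_WEIGHTS_NAME / CANDIDATE_FILENAMES.
SAFE_WEIGHTS_NAME = "model.safetensors"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
CANDIDATE_FILENAMES = [SAFE_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, "pytorch_model.bin"]


def find_weights_file(model_dir: str | Path) -> Path | None:
    """Return the first candidate weight file present in model_dir, or None."""
    model_dir = Path(model_dir)
    for name in CANDIDATE_FILENAMES:  # earlier entries take priority
        candidate = model_dir / name
        if candidate.is_file():
            return candidate
    return None
```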

Example

>>> from easydel.infra.base_module import EasyDeLBaseModule
>>> from easydel.infra.mixins import EasyBridgeMixin
>>>
>>> class MyModel(EasyDeLBaseModule, EasyBridgeMixin):
...     pass
>>>
>>> # Load from HuggingFace
>>> model = MyModel.from_pretrained("gpt2")
>>>
>>> # Save locally
>>> model.save_pretrained("./my_model")
>>>
>>> # Push to Hub
>>> model.push_to_hub("username/my-model")
class easydel.infra.mixins.bridge.EasyBridgeMixin

Bases: PushToHubMixin

Mixin class that adds bridging functionality such as saving, loading, and pushing models to the Hugging Face Hub.

base_model_prefix: str | None = None
classmethod can_generate() → bool

Checks if the model can generate sequences with .generate().

Returns

True if the model can generate, False otherwise.

Return type

bool
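This reference does not say how can_generate reaches its answer. Older HuggingFace transformers releases answered the same question by checking whether a class overrides the hooks that .generate() relies on. The following is a minimal sketch of that idea, with GenerationBase and prepare_inputs_for_generation as assumed stand-ins rather than EasyDeL's actual criterion.

```python
class GenerationBase:
    """Hypothetical stand-in for a base class that provides .generate()."""

    def prepare_inputs_for_generation(self, *args, **kwargs):
        raise NotImplementedError

    @classmethod
    def can_generate(cls) -> bool:
        # Accessing a method through the class yields the plain function, so
        # an identity check tells us whether a subclass overrode the hook.
        return (
            cls.prepare_inputs_for_generation
            is not GenerationBase.prepare_inputs_for_generation
        )
```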

config: EasyDeLBaseConfig
config_class: type[easydel.infra.base_config.EasyDeLBaseConfig] | None = None
classmethod from_pretrained(pretrained_model_name_or_path: str | os.PathLike | None, sharding_axis_dims: typing.Sequence[int] = (1, -1, 1, 1, 1), sharding_dcn_axis_dims: typing.Optional[typing.Sequence[int]] = None, sharding_axis_names: typing.Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp'), partition_axis: eformer.escale.partition.manager.PartitionAxis = PartitionAxis(data_parallel_axis='dp', fully_sharded_data_parallel_axis='fsdp', tensor_parallel_axis='tp', sequence_parallel_axis='sp', expert_parallel_axis='ep', batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', kv_head_axis='tp', key_sequence_axis='sp', hidden_state_axis='tp', mlp_intermediate_axis='tp', vocab_axis='tp', expert_axis='ep', expert_gate_axis=None, attention_dim_axis=None, attention_kv_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, decode_batch_axis=('fsdp', 'dp'), decode_query_sequence_axis=None, decode_head_axis='tp', decode_kv_head_axis='tp', decode_key_sequence_axis='sp', decode_attention_dim_axis=None, decode_attention_kv_dim_axis=None), dtype: numpy.dtype = <class 'jax.numpy.bfloat16'>, param_dtype: numpy.dtype = <class 'jax.numpy.bfloat16'>, precision: typing.Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], jax._src.lax.lax.DotAlgorithm, jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT, config_kwargs: dict[str, typing.Any] | None = None, partition_rules: tuple[tuple[str, jax.sharding.PartitionSpec]] | None = None, backend: easydel.infra.etils.EasyDeLBackends | None = None, platform: easydel.infra.etils.EasyDeLPlatforms | None = 'jax', shard_fns: dict[typing.Callable] | None = None, auto_shard_model: bool = True, verbose: bool = True, mismatch_allowed: bool = True, *model_args, config: easydel.infra.base_config.EasyDeLBaseConfig | str | os.PathLike | None = None, cache_dir: str | os.PathLike | None = None, force_download: bool = False, local_files_only: bool = False, token: str | bool | None = None, revision: str = 'main', vebose: bool = True, quantization_config: easydel.layers.quantization.quantizers.EasyDeLQuantizationConfig | None = None, quantize_tensors: bool = True, **kwargs)

Loads an EasyDeL model from a pretrained model name on the HuggingFace Hub or from a local path.

Parameters
  • pretrained_model_name_or_path – The name or path of the pretrained model.

  • sharding_axis_dims – The size of each device-mesh axis, in the same order as sharding_axis_names. A value of -1 takes all remaining devices.

  • sharding_axis_names – The names of the device-mesh axes (by default dp, fsdp, ep, tp, sp).

  • partition_axis – The partition axis configuration, which maps logical tensor axes (such as batch, sequence, and head axes) to mesh axis names.

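  • dtype – The data type used for computation. Defaults to jax.numpy.bfloat16.

  • param_dtype – The data type in which parameters are stored. Defaults to jax.numpy.bfloat16.

  • precision – The matrix-multiplication precision (a jax.lax.Precision value, its string name, or a dot algorithm). Defaults to Precision.DEFAULT.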

  • config_kwargs – Additional configuration parameters.

  • partition_rules – Custom partitioning rules for sharding.

  • backend – The backend to use.

  • platform – The platform to use.

  • shard_fns – Custom shard functions for loading the checkpoint.

  • auto_shard_model – Whether to automatically shard the model.

  • verbose – Whether to print verbose messages. Defaults to True.

  • mismatch_allowed – If True, allows mismatch in parameters while loading. Defaults to True.

  • quantization_config – Quantization config for loading. Pass None to disable.

  • quantize_tensors – Whether to quantize tensors during loading.

  • **kwargs – Additional keyword arguments.
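The -1 in the default sharding_axis_dims of (1, -1, 1, 1, 1) is read here as "use every remaining device", the same rule NumPy's reshape applies. The sketch below lays device indices out on that grid. Treating EasyDeL's semantics as identical is an assumption, and build_device_grid is a hypothetical helper; real code would reshape the list returned by jax.devices() and pass it to jax.sharding.Mesh together with the axis names.

```python
import numpy as np


def build_device_grid(n_devices, axis_dims=(1, -1, 1, 1, 1),
                      axis_names=("dp", "fsdp", "ep", "tp", "sp")):
    """Arrange device indices into a named grid, resolving a single -1."""
    if len(axis_dims) != len(axis_names):
        raise ValueError("axis_dims and axis_names must have the same length")
    # NumPy infers the size of the -1 axis and raises ValueError if the
    # device count is not divisible by the product of the fixed axes.
    grid = np.arange(n_devices).reshape(axis_dims)
    return grid, dict(zip(axis_names, grid.shape))
```

With 8 devices and the defaults, all of them land on the fsdp axis.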

Returns

The loaded EasyDeL model.
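This reference does not define what counts as a "mismatch" for mismatch_allowed. A minimal sketch, assuming it covers missing, unexpected, and wrongly shaped tensors: load_params is a hypothetical helper that keeps the tensors that match, collects the problems, and raises only when mismatches are not allowed.

```python
import numpy as np


def load_params(expected, checkpoint, mismatch_allowed=True):
    """Hypothetical loader: match checkpoint tensors to expected shapes.

    expected maps parameter names to shapes; checkpoint maps names to arrays.
    Returns (loaded, problems).
    """
    loaded, problems = {}, []
    for name, shape in expected.items():
        tensor = checkpoint.get(name)
        if tensor is None:
            problems.append(f"missing: {name}")
        elif tensor.shape != tuple(shape):
            problems.append(f"shape mismatch: {name} {tensor.shape} != {tuple(shape)}")
        else:
            loaded[name] = tensor
    # Tensors in the checkpoint that the model does not expect.
    problems += [f"unexpected: {k}" for k in checkpoint if k not in expected]
    if problems and not mismatch_allowed:
        raise ValueError("; ".join(problems))
    return loaded, problems
```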

classmethod get_torch_loader()
hf_torch_auto_loader: Any | None = None
push_to_hub(repo_id: str, use_temp_dir: bool | None = None, commit_message: str | None = None, private: bool | None = None, token: bool | str | None = None, create_pr: bool = False, gather_fns: dict[Callable] | None = None, float_dtype: numpy.dtype | None = None, verbose: bool = True, mismatch_allowed: bool = True, revision: str | None = None, commit_description: str | None = None) → str

Pushes the model to the Hugging Face Hub.

Parameters
  • repo_id (str) – The repository ID on Hugging Face Hub.

  • use_temp_dir (bool, optional) – If True, uses a temporary directory. Defaults to None.

  • commit_message (str, optional) – The commit message for the push.

  • private (bool, optional) – If True, creates a private repository.

  • token (str or bool, optional) – The Hugging Face Hub token.

  • create_pr (bool, optional) – If True, creates a pull request.

  • gather_fns (dict[Callable], optional) – Custom gather functions for checkpoint saving.

  • float_dtype (dtype, optional) – Data type for saving weights.

  • verbose (bool, optional) – Whether to print verbose messages. Defaults to True.

  • mismatch_allowed (bool, optional) – If True, allows mismatch in parameters while loading. Defaults to True.

  • revision (str, optional) – The revision to push to.

  • commit_description (str, optional) – The commit description for the push.

Returns

The URL of the created repository.

Return type

str
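EasyBridgeMixin builds on HuggingFace's PushToHubMixin, whose push flow serializes the model into a local folder (a temporary one when use_temp_dir is set) and then uploads that folder. The following is a minimal sketch of that save-then-upload pattern using injected callables. save_then_upload and its arguments are hypothetical; in practice upload_fn might wrap huggingface_hub's HfApi.upload_folder.

```python
from __future__ import annotations

import tempfile
from pathlib import Path
from typing import Callable


def save_then_upload(save_fn: Callable[[Path], None],
                     upload_fn: Callable[[Path], str],
                     working_dir: str | None = None,
                     use_temp_dir: bool = True) -> str:
    """Hypothetical push flow: serialize locally, then upload the folder."""
    if use_temp_dir:
        # The temporary folder exists only while save_fn and upload_fn run.
        with tempfile.TemporaryDirectory() as tmp:
            save_fn(Path(tmp))
            return upload_fn(Path(tmp))
    target = Path(working_dir or ".")
    target.mkdir(parents=True, exist_ok=True)
    save_fn(target)
    return upload_fn(target)
```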

save_pretrained(save_directory: str | os.PathLike, push_to_hub: bool = False, token: str | bool | None = None, gather_fns: dict[Callable] | None = None, float_dtype: numpy.dtype | None = None, step: int | None = None, **kwargs)

Saves the model, its configuration, and optionally pushes it to the Hugging Face Hub.

Parameters
  • save_directory (str or PathLike) – The directory in which to save the model.

  • push_to_hub (bool, optional) – If True, pushes the model to the Hugging Face Hub. Defaults to False.

  • token (str or bool, optional) – The Hugging Face Hub token.

  • gather_fns (dict[Callable], optional) – Custom gather functions for checkpoint saving.

  • float_dtype (dtype, optional) – Data type for saving weights.

  • **kwargs – Additional keyword arguments for the Hugging Face Hub.
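The gather_fns and float_dtype parameters suggest a per-tensor pipeline: each tensor is first gathered (for example, from a sharded device array to the host) and then cast before it is written. The following is a minimal sketch under those assumptions. save_params, the weights.npz filename, and the rule that only floating-point tensors are cast are all hypothetical, and the layout is not necessarily EasyDeL's on-disk format.

```python
from __future__ import annotations

import json
from pathlib import Path
from typing import Callable

import numpy as np


def save_params(save_directory, params: dict, config: dict,
                gather_fns: dict[str, Callable] | None = None,
                float_dtype=None) -> Path:
    """Hypothetical save: gather each tensor, optionally cast floats, write files."""
    out = Path(save_directory)
    out.mkdir(parents=True, exist_ok=True)
    to_save = {}
    for name, value in params.items():
        if gather_fns is not None and name in gather_fns:
            value = gather_fns[name](value)  # e.g. bring a sharded array to host
        value = np.asarray(value)
        if float_dtype is not None and np.issubdtype(value.dtype, np.floating):
            value = value.astype(float_dtype)  # integer tensors keep their dtype
        to_save[name] = value
    (out / "config.json").write_text(json.dumps(config))
    weights_path = out / "weights.npz"
    np.savez(weights_path, **to_save)
    return weights_path
```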