easydel.infra.mixins.bridge
Bridge mixin for EasyDeL-HuggingFace interoperability.
This module provides the EasyBridgeMixin class that enables seamless integration between EasyDeL models and the HuggingFace ecosystem. It handles model serialization, loading from various formats, conversion between frameworks, and integration with the HuggingFace Hub.
The bridge supports:
- Loading models from HuggingFace Hub or local paths
- Converting between PyTorch and JAX/Flax formats
- Saving models in EasyDeL or HuggingFace formats
- Pushing models to HuggingFace Hub
- Automatic weight format detection and loading
- Quantization during loading
- Distributed loading with sharding
- Classes:
EasyBridgeMixin: Main mixin class providing bridge functionality
- Constants:
FLAX_WEIGHTS_NAME: Standard name for Flax model weights
SAFE_WEIGHTS_NAME: Standard name for SafeTensors weights
CANDIDATE_FILENAMES: List of possible weight file names to search
Example
>>> from easydel.infra.base_module import EasyDeLBaseModule
>>> from easydel.infra.mixins import EasyBridgeMixin
>>>
>>> class MyModel(EasyDeLBaseModule, EasyBridgeMixin):
...     pass
>>>
>>> # Load from HuggingFace
>>> model = MyModel.from_pretrained("gpt2")
>>>
>>> # Save locally
>>> model.save_pretrained("./my_model")
>>>
>>> # Push to Hub
>>> model.push_to_hub("username/my-model")
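The automatic weight format detection mentioned above can be pictured as a search over CANDIDATE_FILENAMES. The following is a minimal sketch of that idea, not the module's actual implementation: the constant values are assumed to follow the usual Hugging Face file names, and `find_weight_file` is a hypothetical helper introduced only for illustration.

```python
import tempfile
from pathlib import Path

# Assumed values: these mirror conventional Hugging Face file names and may
# differ from the constants actually defined in easydel.infra.mixins.bridge.
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
SAFE_WEIGHTS_NAME = "model.safetensors"
CANDIDATE_FILENAMES = [SAFE_WEIGHTS_NAME, FLAX_WEIGHTS_NAME]


def find_weight_file(directory, candidates=CANDIDATE_FILENAMES):
    """Return the first candidate weight file found in `directory`, or None.

    Hypothetical helper: candidates are checked in list order, so earlier
    entries take priority when several formats are present.
    """
    root = Path(directory)
    for name in candidates:
        path = root / name
        if path.is_file():
            return path
    return None


# Usage: create a throwaway checkpoint directory holding only a Flax file.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / FLAX_WEIGHTS_NAME).write_bytes(b"")
    found = find_weight_file(tmp)
    found_name = found.name if found else None
    print(found_name)
```

Ordering the candidate list is the main design choice here: whichever format appears first wins when a directory contains more than one.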
- class easydel.infra.mixins.bridge.EasyBridgeMixin
Bases: PushToHubMixin
Mixin class that adds bridging functionality such as saving, loading, and pushing models to the Hugging Face Hub.
- classmethod can_generate() -> bool
Checks if the model can generate sequences with .generate().
- Returns
True if the model can generate, False otherwise.
- Return type
bool
- config: EasyDeLBaseConfig
- config_class: type[easydel.infra.base_config.EasyDeLBaseConfig] | None = None
- classmethod from_pretrained(pretrained_model_name_or_path: str | os.PathLike | None, sharding_axis_dims: ~typing.Sequence[int] = (1, -1, 1, 1, 1), sharding_dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = None, sharding_axis_names: ~typing.Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp'), partition_axis: ~eformer.escale.partition.manager.PartitionAxis = PartitionAxis(data_parallel_axis='dp', fully_sharded_data_parallel_axis='fsdp', tensor_parallel_axis='tp', sequence_parallel_axis='sp', expert_parallel_axis='ep', batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', kv_head_axis='tp', key_sequence_axis='sp', hidden_state_axis='tp', mlp_intermediate_axis='tp', vocab_axis='tp', expert_axis='ep', expert_gate_axis=None, attention_dim_axis=None, attention_kv_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, decode_batch_axis=('fsdp', 'dp'), decode_query_sequence_axis=None, decode_head_axis='tp', decode_kv_head_axis='tp', decode_key_sequence_axis='sp', decode_attention_dim_axis=None, decode_attention_kv_dim_axis=None), dtype: ~numpy.dtype = <class 'jax.numpy.bfloat16'>, param_dtype: ~numpy.dtype = <class 'jax.numpy.bfloat16'>, precision: ~typing.Union[None, str, ~jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], ~jax._src.lax.lax.DotAlgorithm, ~jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT, config_kwargs: dict[str, typing.Any] | None = None, partition_rules: tuple[tuple[str, jax.sharding.PartitionSpec]] | None = None, backend: easydel.infra.etils.EasyDeLBackends | None = None, platform: easydel.infra.etils.EasyDeLPlatforms | None = 'jax', shard_fns: dict[typing.Callable] | None = None, auto_shard_model: bool = True, verbose: bool = True, mismatch_allowed: bool = True, *model_args, config: easydel.infra.base_config.EasyDeLBaseConfig | str | os.PathLike | None = None, cache_dir: str | os.PathLike | None = None, force_download: bool = False, local_files_only: bool = False, token: str | bool | None = None, revision: str = 'main', vebose: bool = True, quantization_config: easydel.layers.quantization.quantizers.EasyDeLQuantizationConfig | None = None, quantize_tensors: bool = True, **kwargs)
Loads an EasyDeL model from a pretrained model or path.
- Parameters
pretrained_model_name_or_path – The name or path of the pretrained model.
sharding_axis_dims – The dimensions of sharding axes.
sharding_axis_names – The names of sharding axes.
partition_axis – The partition axis configuration.
dtype – The data type of the model.
param_dtype – The data type of the parameters.
precision – The computation precision.
config_kwargs – Additional configuration parameters.
partition_rules – Custom partitioning rules for sharding.
backend – The backend to use.
platform – The platform to use.
shard_fns – Custom shard functions for loading checkpoint.
auto_shard_model – Whether to automatically shard the model.
verbose – Whether to print verbose messages. Defaults to True.
mismatch_allowed – If True, allows mismatch in parameters while loading. Defaults to True.
quantization_config – Quantization config for loading. Pass None to disable.
quantize_tensors – Whether to quantize tensors during loading.
**kwargs – Additional keyword arguments.
- Returns
The loaded EasyDeL model.
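The default `sharding_axis_dims` of `(1, -1, 1, 1, 1)` pairs position by position with `sharding_axis_names`. The sketch below assumes that a `-1` entry behaves like `-1` in an array reshape, absorbing whatever devices remain after the other axes are fixed; that reading is an assumption about the library, and `resolve_axis_dims` is a hypothetical helper, not part of EasyDeL.

```python
import math


def resolve_axis_dims(axis_dims, num_devices):
    """Replace a single -1 entry so the product of all dims equals num_devices.

    Hypothetical helper illustrating the assumed reshape-like meaning of -1.
    """
    dims = list(axis_dims)
    if any(d == 0 or d < -1 for d in dims):
        raise ValueError("axis dims must be positive integers or -1")
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    # Product of the explicitly sized axes (1 if there are none).
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        if num_devices % known != 0:
            raise ValueError("device count is not divisible by the fixed axes")
        dims[dims.index(-1)] = num_devices // known
    elif known != num_devices:
        raise ValueError("axis dims do not multiply to the device count")
    return tuple(dims)


# Usage: the documented defaults on a hypothetical 8-device host.
axis_names = ("dp", "fsdp", "ep", "tp", "sp")
resolved = resolve_axis_dims((1, -1, 1, 1, 1), num_devices=8)
mesh_shape = dict(zip(axis_names, resolved))
print(mesh_shape)
```

Under this assumption the defaults place every device on the `fsdp` axis, while a value such as `(2, -1, 1, 2, 1)` would fix `dp` and `tp` and leave `fsdp` to take the remainder.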
- push_to_hub(repo_id: str, use_temp_dir: bool | None = None, commit_message: str | None = None, private: bool | None = None, token: bool | str | None = None, create_pr: bool = False, gather_fns: dict[Callable] | None = None, float_dtype: numpy.dtype | None = None, verbose: bool = True, mismatch_allowed: bool = True, revision: str | None = None, commit_description: str | None = None) -> str
Pushes the model to the Hugging Face Hub.
- Parameters
repo_id (str) – The repository ID on Hugging Face Hub.
use_temp_dir (bool, optional) – If True, uses a temporary directory. Defaults to None.
commit_message (str, optional) – The commit message for the push.
private (bool, optional) – If True, creates a private repository.
token (str or bool, optional) – The Hugging Face Hub token.
create_pr (bool, optional) – If True, creates a pull request.
gather_fns (dict[Callable], optional) – Custom gather functions for checkpoint saving.
float_dtype (dtype, optional) – Data type for saving weights.
verbose (bool, optional) – Whether to print verbose messages. Defaults to True.
mismatch_allowed (bool, optional) – If True, allows mismatch in parameters while loading. Defaults to True.
revision (str, optional) – The revision to push to.
commit_description (str, optional) – The commit description for the push.
- Returns
The URL of the created repository.
- Return type
str
- save_pretrained(save_directory: str | os.PathLike, push_to_hub: bool = False, token: str | bool | None = None, gather_fns: dict[Callable] | None = None, float_dtype: numpy.dtype | None = None, step: int | None = None, **kwargs)
Saves the model, its configuration, and optionally pushes it to the Hugging Face Hub.
- Parameters
save_directory (str or PathLike) – The directory where to save the model.
push_to_hub (bool, optional) – If True, pushes the model to the Hugging Face Hub.
token (str or bool, optional) – The Hugging Face Hub token.
gather_fns (dict[Callable], optional) – Custom gather functions for checkpoint saving.
float_dtype (dtype, optional) – Data type for saving weights.
**kwargs – Additional keyword arguments for Hugging Face Hub.