easydel.infra.mixins.bridge

class easydel.infra.mixins.bridge.EasyBridgeMixin

Bases: PushToHubMixin

Mixin that adds bridging functionality to models: saving, loading, and pushing them to the Hugging Face Hub.

base_model_prefix: Optional[str] = None
classmethod can_generate() → bool

Checks whether the model can generate sequences with .generate().

Returns

True if the model can generate, False otherwise.

Return type

bool
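The check itself is not described beyond its return value. As a minimal sketch, assuming (hypothetically) that a class counts as generative when it exposes a callable prepare_inputs_for_generation hook, a mixin classmethod could look like this. GenerationCheckMixin and both model classes are illustrative names, not EasyDeL code.

```python
class GenerationCheckMixin:
    """Hypothetical stand-in for a can_generate() check (not EasyDeL's code)."""

    @classmethod
    def can_generate(cls) -> bool:
        # Assumption: a model is generative if its class provides a callable
        # prepare_inputs_for_generation hook.
        hook = getattr(cls, "prepare_inputs_for_generation", None)
        return callable(hook)


class EncoderOnlyModel(GenerationCheckMixin):
    """No generation hook, so can_generate() returns False."""


class CausalLMModel(GenerationCheckMixin):
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # A real model would also build attention masks, caches, and so on.
        return {"input_ids": input_ids, **kwargs}


print(EncoderOnlyModel.can_generate())  # False
print(CausalLMModel.can_generate())     # True
```

Because can_generate is a classmethod, it can be queried on the class before any weights are loaded.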

config: EasyDeLBaseConfig
config_class: Optional[Type[EasyDeLBaseConfig]] = None
classmethod from_pretrained(pretrained_model_name_or_path: Optional[Union[str, PathLike]], sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: PartitionAxis = PartitionAxis(batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', key_sequence_axis='sp', hidden_state_axis='tp', attention_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, generation_query_sequence_axis=None, generation_head_axis='tp', generation_key_sequence_axis='sp', generation_attention_dim_axis=None), dtype: dtype = jax.numpy.float32, param_dtype: dtype = jax.numpy.float32, precision: Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset] = Precision.DEFAULT, config_kwargs: Optional[dict[str, Any]] = None, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = 'jax', shard_fns: Optional[dict[Callable]] = None, auto_shard_model: bool = False, verbose: bool = True, mismatch_allowed: bool = True, *model_args, config: Optional[Union[EasyDeLBaseConfig, str, PathLike]] = None, cache_dir: Optional[Union[str, PathLike]] = None, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[bool, str]] = None, revision: str = 'main', vebose: bool = True, quantization_platform: Optional[EasyDeLPlatforms] = None, quantization_method: Optional[EasyDeLQuantizationMethods] = None, quantization_block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, **kwargs)

Loads an EasyDeL model from a pretrained checkpoint, given a Hugging Face Hub model ID or a local path.

Parameters
  • pretrained_model_name_or_path (str or PathLike, optional) – The Hub model ID or local path of the pretrained model.

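  • sharding_axis_dims (Sequence[int], optional) – The sizes of the sharding (device-mesh) axes, in the order given by sharding_axis_names. Defaults to (1, -1, 1, 1).

  • sharding_dcn_axis_dims (Sequence[int], optional) – The sizes of the sharding axes across the data-center network (DCN), for hybrid multi-slice device meshes. Defaults to None.

  • sharding_axis_names (Sequence[str], optional) – The names of the sharding axes. Defaults to ('dp', 'fsdp', 'tp', 'sp').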

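  • partition_axis (PartitionAxis, optional) – The partition axis configuration: how logical tensor dimensions (batch, sequence, heads, hidden state, and so on) map onto the named mesh axes.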

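  • dtype (dtype, optional) – The data type used for computation. Defaults to jax.numpy.float32.

  • param_dtype (dtype, optional) – The data type in which parameters are stored. Defaults to jax.numpy.float32.

  • precision (PrecisionLike, optional) – The numerical precision of matrix computations. Defaults to Precision.DEFAULT.

  • config_kwargs (dict[str, Any], optional) – Additional parameters applied to the model configuration.

  • partition_rules (tuple, optional) – Custom sharding rules, given as a tuple of (pattern, PartitionSpec) pairs.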

  • backend (EasyDeLBackends, optional) – The backend to use.

  • platform (EasyDeLPlatforms, optional) – The platform to use. Defaults to 'jax'.

  • shard_fns (dict[Callable], optional) – Custom shard functions for loading the checkpoint.

  • auto_shard_model (bool, optional) – Whether to automatically shard the model. Defaults to False.

  • verbose (bool, optional) – Whether to print verbose messages. Defaults to True.

  • mismatch_allowed (bool, optional) – If True, allows parameter mismatches while loading. Defaults to True.

  • *model_args – Additional arguments for the model.

  • config (EasyDeLBaseConfig, str, or PathLike, optional) – The model configuration, or the name or path from which to load it.

  • cache_dir (str or PathLike, optional) – The directory in which to cache downloaded files.

  • force_download (bool, optional) – Whether to force downloading the model, even if it is already cached. Defaults to False.

  • local_files_only (bool, optional) – Whether to use only local files. Defaults to False.

  • token (str or bool, optional) – The Hugging Face Hub token.

  • revision (str, optional) – The revision of the model to load. Defaults to 'main'.

  • **kwargs – Additional keyword arguments.
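The default sharding_axis_dims of (1, -1, 1, 1) contains a -1. Assuming it behaves like a reshape dimension (all remaining devices are placed on that axis, here fsdp), the mesh shape for a given device count works out as in this sketch. resolve_mesh_shape is a hypothetical helper for illustration, not part of EasyDeL.

```python
from math import prod
from typing import Sequence, Tuple


def resolve_mesh_shape(axis_dims: Sequence[int], device_count: int) -> Tuple[int, ...]:
    """Hypothetical helper: replace a single -1 so the axis sizes multiply to device_count."""
    dims = list(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = prod(d for d in dims if d != -1)  # product of the explicitly sized axes
    if -1 not in dims:
        if known != device_count:
            raise ValueError(f"mesh {tuple(dims)} needs {known} devices, got {device_count}")
        return tuple(dims)
    if device_count % known != 0:
        raise ValueError(f"{device_count} devices cannot be divided into axes of size {known}")
    # The -1 axis absorbs every device not claimed by the other axes.
    return tuple(device_count // known if d == -1 else d for d in dims)


names = ("dp", "fsdp", "tp", "sp")  # default sharding_axis_names
print(dict(zip(names, resolve_mesh_shape((1, -1, 1, 1), device_count=8))))
# {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1}
print(resolve_mesh_shape((1, -1, 2, 1), device_count=8))  # (1, 4, 2, 1)
```

With the defaults, every device ends up on the fsdp axis; setting tp to 2 halves the fsdp axis instead.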

Returns

The loaded EasyDeL model.
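partition_rules is typed as a tuple of (str, PartitionSpec) pairs. A common convention in JAX codebases, assumed here rather than confirmed for EasyDeL, treats each string as a regular expression searched against the parameter's /-joined path, with the first match winning. The sketch uses plain tuples in place of jax.sharding.PartitionSpec so it runs without JAX; the patterns, specs, and parameter paths are hypothetical.

```python
import re
from typing import Optional, Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec: one entry per array dimension naming
# the mesh axis it is split over (None = not split); () = fully replicated.
Spec = Tuple[Optional[str], ...]

# Hypothetical rules; order matters because the first matching pattern wins.
RULES: Tuple[Tuple[str, Spec], ...] = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"o_proj/kernel", ("tp", "fsdp")),
    (r".*", ()),  # fallback: replicate anything not matched above
)


def match_rule(param_path: str, rules: Sequence[Tuple[str, Spec]]) -> Spec:
    """Return the spec of the first rule whose pattern is found in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


for path in (
    "model/embed_tokens/embedding",
    "model/layers/0/self_attn/q_proj/kernel",
    "model/norm/scale",
):
    print(path, "->", match_rule(path, RULES))
```

Keeping a catch-all pattern last guarantees that every parameter receives a spec.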

classmethod get_torch_loader()
hf_torch_auto_loader: Optional[Any] = None
push_to_hub(repo_id: str, use_temp_dir: Optional[bool] = None, commit_message: Optional[str] = None, private: Optional[bool] = None, token: Optional[Union[bool, str]] = None, create_pr: bool = False, gather_fns: Optional[dict[Callable]] = None, float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, revision: Optional[str] = None, commit_description: Optional[str] = None) → str

Pushes the model to the Hugging Face Hub.

Parameters
  • repo_id (str) – The repository ID on the Hugging Face Hub.

  • use_temp_dir (bool, optional) – If True, stages the files in a temporary directory before pushing. Defaults to None.

  • commit_message (str, optional) – The commit message for the push.

  • private (bool, optional) – If True, creates a private repository.

  • token (str or bool, optional) – The Hugging Face Hub token.

  • create_pr (bool, optional) – If True, opens a pull request instead of committing directly. Defaults to False.

  • gather_fns (dict[Callable], optional) – Custom gather functions that collect sharded parameters before the checkpoint is saved.

  • float_dtype (dtype, optional) – The data type to which weights are cast before saving.

  • verbose (bool, optional) – Whether to print verbose messages. Defaults to True.

  • mismatch_allowed (bool, optional) – If True, allows parameter mismatches. Defaults to True.

  • revision (str, optional) – The revision to push to.

  • commit_description (str, optional) – The commit description for the push.

Returns

The URL of the created repository.

Return type

str
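A usage sketch of push_to_hub, assuming model is an instance of a class that inherits EasyBridgeMixin. Only the keyword arguments come from the signature above; publish_private, the repository ID, and the token value are placeholders, and _StubModel exists only so the snippet runs offline without contacting the Hub.

```python
from typing import Any


def publish_private(model: Any, repo_id: str, token: str) -> str:
    """Hypothetical wrapper: push a model to a private Hub repo through a pull request."""
    return model.push_to_hub(
        repo_id,
        private=True,    # create the repository as private
        token=token,     # Hugging Face access token
        create_pr=True,  # open a PR instead of committing directly
        commit_message="Upload EasyDeL checkpoint",
    )


class _StubModel:
    """Offline stand-in that records the call instead of uploading anything."""

    def push_to_hub(self, repo_id: str, **kwargs: Any) -> str:
        self.last_call = (repo_id, kwargs)
        return f"https://huggingface.co/{repo_id}"


stub = _StubModel()
url = publish_private(stub, "my-org/my-model", token="hf_...")
print(url)  # https://huggingface.co/my-org/my-model
```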

save_pretrained(save_directory: Union[str, PathLike], push_to_hub: bool = False, token: Optional[Union[bool, str]] = None, gather_fns: Optional[dict[Callable]] = None, float_dtype=None, verbose: bool = True, mismatch_allowed: bool = True, enable: Optional[bool] = None, **kwargs)

Saves the model, its configuration, and optionally pushes it to the Hugging Face Hub.

Parameters
  • save_directory (str or PathLike) – The directory in which to save the model.

  • push_to_hub (bool, optional) – If True, also pushes the model to the Hugging Face Hub. Defaults to False.

  • token (str or bool, optional) – The Hugging Face Hub token.

  • gather_fns (dict[Callable], optional) – Custom gather functions for checkpoint saving.

  • float_dtype (dtype, optional) – Data type for saving weights.

  • verbose (bool, optional) – Whether to print verbose messages. Defaults to True.

  • mismatch_allowed (bool, optional) – If True, allows parameter mismatches. Defaults to True.

  • enable (bool, optional) – If True, allows the files to be written; used when saving models from multiple hosts.

  • **kwargs – Additional keyword arguments for Hugging Face Hub.