easydel.modules.auto.auto_modeling#
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModelForCausalLM[source]#
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It uses the EasyDeL library for distributed training and inference with JAX.
This class inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework; from_pretrained returns an EasyDeLBaseModule.
Examples
>>> import jax
>>> from easydel import AutoEasyDeLModelForCausalLM
>>> # Load a GPT-2 model on a single CPU
>>> model = AutoEasyDeLModelForCausalLM.from_pretrained(
...     "gpt2", device=jax.devices("cpu")[0]
... )
>>> # Load a GPT-2 model sharded across 8 GPUs with data parallelism (DP) and fully sharded data parallelism (FSDP)
>>> model = AutoEasyDeLModelForCausalLM.from_pretrained(
...     "gpt2",
...     sharding_axis_dims=(1, 8, 1, 1),
...     sharding_axis_names=("dp", "fsdp", "tp", "sp"),
...     device=jax.devices("cpu")[0],  # offload to CPU [OPTIONAL]
...     from_torch=True,
... )
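The sharding_axis_dims tuple assigns one size to each name in sharding_axis_names, and in a JAX device mesh the product of those sizes has to equal the number of devices. The following is a minimal sketch of how a -1 entry (as in the default (1, -1, 1, 1)) could be resolved against a device count. The helper resolve_axis_dims is hypothetical and is not part of EasyDeL; it only illustrates the arithmetic, assuming -1 means "use all remaining devices" in the same way as a -1 in a NumPy reshape.

```python
from math import prod
from typing import Sequence, Tuple


def resolve_axis_dims(axis_dims: Sequence[int], num_devices: int) -> Tuple[int, ...]:
    """Hypothetical helper: replace a single -1 so the product equals num_devices."""
    if any(d == 0 or d < -1 for d in axis_dims):
        raise ValueError("axis sizes must be positive integers or -1")

    unknown = [i for i, d in enumerate(axis_dims) if d == -1]
    if len(unknown) > 1:
        raise ValueError("at most one axis size may be -1")

    # Product of all explicitly given axis sizes.
    known = prod(d for d in axis_dims if d != -1)

    if unknown:
        if num_devices % known != 0:
            raise ValueError(f"{num_devices} devices cannot be split by {known}")
        resolved = list(axis_dims)
        resolved[unknown[0]] = num_devices // known
        return tuple(resolved)

    if known != num_devices:
        raise ValueError(f"axis sizes multiply to {known}, expected {num_devices}")
    return tuple(axis_dims)


axis_names = ("dp", "fsdp", "tp", "sp")

# Explicit sizes, as in the 8-GPU example: every device goes to the "fsdp" axis.
explicit = resolve_axis_dims((1, 8, 1, 1), num_devices=8)

# The default tuple: the -1 on "fsdp" absorbs whatever devices remain.
inferred = resolve_axis_dims((1, -1, 1, 1), num_devices=8)

print(dict(zip(axis_names, explicit)))
print(dict(zip(axis_names, inferred)))
```

Under these assumptions both calls describe the same layout on an 8-device host, which is why the explicit (1, 8, 1, 1) in the example and the default (1, -1, 1, 1) would be interchangeable there.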
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModelForImageTextToText[source]#
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It uses the EasyDeL library for distributed training and inference with JAX.
This class inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework; from_pretrained returns an EasyDeLBaseModule.
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModelForSeq2SeqLM[source]#
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It uses the EasyDeL library for distributed training and inference with JAX.
This class inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework; from_pretrained returns an EasyDeLBaseModule.
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModelForSequenceClassification[source]#
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It uses the EasyDeL library for distributed training and inference with JAX.
This class inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework; from_pretrained returns an EasyDeLBaseModule.
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModelForSpeechSeq2Seq[source]#
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It uses the EasyDeL library for distributed training and inference with JAX.
This class inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework; from_pretrained returns an EasyDeLBaseModule.
Examples
>>> import jax
>>> from easydel import AutoEasyDeLModelForSpeechSeq2Seq
>>> # Load openai/whisper-large-v3-turbo with automatic sharding
>>> model = AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...     "openai/whisper-large-v3-turbo",
...     auto_shard_model=True,
... )
>>> # Load openai/whisper-large-v3-turbo sharded across 8 GPUs with data parallelism (DP) and fully sharded data parallelism (FSDP)
>>> model = AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...     "openai/whisper-large-v3-turbo",
...     sharding_axis_dims=(1, 8, 1, 1),
...     sharding_axis_names=("dp", "fsdp", "tp", "sp"),
...     device=jax.devices("cpu")[0],  # offload to CPU [OPTIONAL]
...     from_torch=True,
... )
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModelForZeroShotImageClassification[source]#
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It uses the EasyDeL library for distributed training and inference with JAX.
This class inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework; from_pretrained returns an EasyDeLBaseModule.
- class easydel.modules.auto.auto_modeling.AutoStateForCausalLM[source]#
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.AutoStateForImageSequenceClassification[source]#
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.AutoStateForImageTextToText[source]#
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.AutoStateForSeq2SeqLM[source]#
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.AutoStateForSpeechSeq2Seq[source]#
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.AutoStateForZeroShotImageClassification[source]#
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.BaseAutoEasyModel[source]#
Bases:
object
- classmethod from_pretrained(pretrained_model_name_or_path: str, device: Optional[Device] = None, dtype: dtype = jax.numpy.float32, param_dtype: dtype = jax.numpy.float32, precision: Optional[Precision] = None, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, shard_fns: Optional[Union[Mapping[tuple, Callable], dict]] = None, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, config_kwargs: Optional[EasyDeLBaseConfigDict] = None, auto_shard_model: bool = False, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec], ...]] = None, quantization_method: Optional[EasyDeLQuantizationMethods] = None, quantization_block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, from_torch: Optional[bool] = None, **kwargs) -> EasyDeLBaseModule[source]#
Loads and shards a pretrained model from the Hugging Face Hub and converts it into an EasyDeL compatible model.
- Parameters
pretrained_model_name_or_path (str) – Path or name of the pretrained model on the Hugging Face Hub.
device (jax.Device, optional) – Device to load the model on. If None (the default), the first CPU device is used.
dtype (jax.numpy.dtype, optional) – Data type used for computation. Defaults to jax.numpy.float32.
param_dtype (jax.numpy.dtype, optional) – Data type of the model parameters. Defaults to jax.numpy.float32.
precision (jax.lax.Precision, optional) – Precision for computations. If None (the default), jax.lax.Precision("fastest") is used.
sharding_axis_dims (tp.Sequence[int], optional) – Size of each sharding axis. Defaults to (1, -1, 1, 1).
sharding_axis_names (tp.Sequence[str], optional) – Names of the sharding axes. Defaults to ("dp", "fsdp", "tp", "sp").
partition_axis (PartitionAxis, optional) – PartitionAxis configuration used to partition arrays in EasyDeL. Defaults to None.
shard_attention_computation (bool, optional) – Whether to shard attention computation. Defaults to True.
shard_fns (tp.Optional[tp.Mapping[tuple, tp.Callable] | dict], optional) – Sharding functions to use for the model. If None and auto_shard_model is True, auto-sharding is used. Defaults to None.
platform (tp.Optional[EasyDeLPlatforms], optional) – Platform to use for the model. Defaults to None.
backend (tp.Optional[EasyDeLBackends], optional) – Backend to use for the model. Defaults to None.
config_kwargs (tp.Optional[tp.Mapping[str, tp.Any] | EasyDeLBaseConfigDict], optional) – Configuration keyword arguments to pass to the model config. Defaults to None.
auto_shard_model (bool, optional) – Whether to automatically shard the model parameters. Defaults to False.
partition_rules (tp.Optional[tp.Tuple[tp.Tuple[str, PartitionSpec]]], optional) – Custom partition rules for parameter sharding. If not None, shard_fns should also be provided. Defaults to None.
quantization_method (EasyDeLQuantizationMethods, optional) – Quantization method used to quantize the model weights. Defaults to None.
quantization_block_size (int) – Block size used when quantizing arrays (NF4 only). Defaults to 128.
bit_targeted_params (tp.Optional[tp.List[str]], optional) – List of parameter names to convert to 8-bit precision. If None and 8bit is True, all kernels and embeddings are converted to 8-bit. Defaults to None.
from_torch (bool, optional) – Whether to load the model from a Transformers PyTorch checkpoint. Defaults to None.
**kwargs – Additional keyword arguments to pass to the model and config classes.
- Returns
The loaded and sharded EasyDeL model.
- Return type
EasyDeLBaseModule
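The partition_rules argument is typed as a tuple of (str, PartitionSpec) pairs. A minimal sketch of how such rules could be applied is shown here, assuming each string is a regular expression tested against a parameter path and that the first matching rule wins. Both of those behaviours, the helper match_partition_rule, and the parameter paths are assumptions made for illustration and are not taken from EasyDeL; plain tuples of axis names stand in for jax.sharding.PartitionSpec so the sketch runs without JAX installed.

```python
import re
from typing import Optional, Sequence, Tuple

# Stand-in for PartitionSpec: one mesh axis name (or None) per array dimension.
Spec = Tuple[Optional[str], ...]
Rule = Tuple[str, Spec]


def match_partition_rule(param_path: str, rules: Sequence[Rule]) -> Spec:
    """Hypothetical helper: return the spec of the first rule whose regex matches."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


# Illustrative rules only; real parameter paths depend on the model.
rules: Tuple[Rule, ...] = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"attention/(q|k|v)_proj/kernel", ("fsdp", "tp")),
    (r"attention/o_proj/kernel", ("tp", "fsdp")),
    (r".*", (None,)),  # catch-all: leave everything else replicated
)

for path in (
    "model/embed_tokens/embedding",
    "model/layers/0/attention/q_proj/kernel",
    "model/layers/0/attention/o_proj/kernel",
    "model/norm/scale",
):
    print(path, "->", match_partition_rule(path, rules))
```

Ordering matters under the first-match assumption: the catch-all pattern has to come last, otherwise it would shadow the more specific rules.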
- class easydel.modules.auto.auto_modeling.BaseAutoEasyState[source]#
Bases:
object
- classmethod from_pretrained(pretrained_model_name_or_path: str, device: Optional[Device] = None, dtype: dtype = jax.numpy.float32, param_dtype: dtype = jax.numpy.float32, precision: Optional[Precision] = None, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, shard_fns: Optional[Union[Mapping[tuple, Callable], dict]] = None, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, config_kwargs: Optional[EasyDeLBaseConfigDict] = None, auto_shard_model: bool = False, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec], ...]] = None, quantization_method: Optional[EasyDeLQuantizationMethods] = None, quantization_block_size: int = 128, from_torch: Optional[bool] = None, **kwargs) -> EasyDeLState[source]#
Loads and shards a pretrained model from the Hugging Face Hub and converts it into an EasyDeL compatible state.
- Parameters
pretrained_model_name_or_path (str) – Path or name of the pretrained model on the Hugging Face Hub.
device (jax.Device, optional) – Device to load the model on. If None (the default), the first CPU device is used.
dtype (jax.numpy.dtype, optional) – Data type used for computation. Defaults to jax.numpy.float32.
param_dtype (jax.numpy.dtype, optional) – Data type of the model parameters. Defaults to jax.numpy.float32.
precision (jax.lax.Precision, optional) – Precision for computations. If None (the default), jax.lax.Precision("fastest") is used.
sharding_axis_dims (tp.Sequence[int], optional) – Size of each sharding axis. Defaults to (1, -1, 1, 1).
sharding_axis_names (tp.Sequence[str], optional) – Names of the sharding axes. Defaults to ("dp", "fsdp", "tp", "sp").
partition_axis (PartitionAxis, optional) – PartitionAxis configuration used to partition arrays in EasyDeL. Defaults to None.
shard_attention_computation (bool, optional) – Whether to shard attention computation. Defaults to True.
shard_fns (tp.Optional[tp.Mapping[tuple, tp.Callable] | dict], optional) – Sharding functions to use for the model. If None and auto_shard_model is True, auto-sharding is used. Defaults to None.
backend (tp.Optional[str], optional) – Backend to use for the model. Defaults to None.
config_kwargs (tp.Optional[tp.Mapping[str, tp.Any]], optional) – Configuration keyword arguments to pass to the model config. Defaults to None.
auto_shard_model (bool, optional) – Whether to automatically shard the model parameters. Defaults to False.
partition_rules (tp.Optional[tp.Tuple[tp.Tuple[str, PartitionSpec]]], optional) – Custom partition rules for parameter sharding. If not None, shard_fns should also be provided. Defaults to None.
quantization_method (EasyDeLQuantizationMethods, optional) – Quantization method used to quantize the model weights. Defaults to None.
bit_targeted_params (tp.Optional[tp.List[str]], optional) – List of parameter names to convert to 8-bit precision. If None and 8bit is True, all kernels and embeddings are converted to 8-bit. Defaults to None.
verbose_params (bool) – Whether to log the number of parameters while converting the state.
safe (bool) – Whether to use safetensors to load the engine or parameters (requires that they were saved with safe=True).
from_torch (bool, optional) – Whether to load the model from a Transformers PyTorch checkpoint. Defaults to None.
**kwargs – Additional keyword arguments to pass to the model and config classes.
- Returns
The EasyDeL state, containing the loaded and sharded model parameters.
- Return type