easydel.modules.auto.auto_modeling
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModel
Bases:
BaseAutoEasyModel
This class provides a convenient way to load pretrained models from the Hugging Face Hub, shard them, and convert them into EasyDeL-compatible models for distributed training and inference with JAX.
Model loading and parameter sharding are provided by the from_config and from_pretrained class methods inherited from BaseAutoEasyModel, which return EasyDeLBaseModule instances.
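Auto classes of this kind usually pick the concrete model class from the checkpoint's configuration. The sketch below illustrates that general dispatch pattern with a hypothetical registry keyed on (model type, task). `register`, `ToyAutoModel`, and the task strings are illustrative assumptions, not EasyDeL's actual implementation.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry: (model_type, task) -> concrete model class.
# These names are placeholders, not EasyDeL internals.
_REGISTRY: Dict[Tuple[str, str], type] = {}


def register(model_type: str, task: str) -> Callable[[type], type]:
    """Class decorator that records a concrete model class for (model_type, task)."""
    def wrap(cls: type) -> type:
        _REGISTRY[(model_type, task)] = cls
        return cls
    return wrap


@register("gpt2", "causal-lm")
class ToyGPT2ForCausalLM:
    def __init__(self, config: dict):
        self.config = config


class ToyAutoModel:
    task = "base"  # each Auto subclass declares the task it serves

    @classmethod
    def from_config(cls, config: dict):
        # Combine the config's model type with this Auto class's task to find the class.
        key = (config["model_type"], cls.task)
        try:
            model_cls = _REGISTRY[key]
        except KeyError:
            raise ValueError(f"no model registered for {key}") from None
        return model_cls(config)


class ToyAutoModelForCausalLM(ToyAutoModel):
    task = "causal-lm"


model = ToyAutoModelForCausalLM.from_config({"model_type": "gpt2"})
print(type(model).__name__)  # ToyGPT2ForCausalLM
```

Keying on both the model type and the task is what lets one model family back several Auto classes (base model, causal LM, sequence classification) in this sketch.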
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModelForCausalLM
Bases:
BaseAutoEasyModel
This class provides a convenient way to load pretrained causal language models from the Hugging Face Hub, shard them, and convert them into EasyDeL-compatible models for distributed training and inference with JAX.
Model loading and parameter sharding are provided by the from_config and from_pretrained class methods inherited from BaseAutoEasyModel, which return EasyDeLBaseModule instances.
Examples
>>> import jax
>>> from easydel import AutoEasyDeLModelForCausalLM
>>> # Load a GPT-2 model on a single CPU
>>> model = AutoEasyDeLModelForCausalLM.from_pretrained(
...     "gpt2", device=jax.devices("cpu")[0]
... )
>>> # Load a GPT-2 model fully sharded (FSDP) across 8 GPUs
>>> model = AutoEasyDeLModelForCausalLM.from_pretrained(
...     "gpt2",
...     sharding_axis_dims=(1, 8, 1, 1),
...     sharding_axis_names=("dp", "fsdp", "tp", "sp"),
...     device=jax.devices("cpu")[0],  # offload to CPU (optional)
...     from_torch=True,
... )
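In the second example, `sharding_axis_dims=(1, 8, 1, 1)` pairs positionally with `("dp", "fsdp", "tp", "sp")`, so all 8 devices sit on the `fsdp` axis while `dp` keeps size 1. The stdlib-only helpers below (hypothetical names, not EasyDeL functions) make that pairing explicit. They also resolve a `-1` entry the way `numpy.reshape` infers a dimension, which is assumed here to match how the default `(1, -1, 1, 1)` absorbs the remaining devices.

```python
import math
from typing import Dict, Sequence, Tuple


def resolve_axis_dims(axis_dims: Sequence[int], num_devices: int) -> Tuple[int, ...]:
    """Fill in at most one -1 entry so that the axis sizes multiply to num_devices."""
    if any(d == 0 or d < -1 for d in axis_dims):
        raise ValueError("axis sizes must be positive, or -1 for 'the rest'")
    if list(axis_dims).count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    fixed = math.prod(d for d in axis_dims if d != -1)
    if -1 not in axis_dims:
        if fixed != num_devices:
            raise ValueError(f"axis sizes multiply to {fixed}, expected {num_devices}")
        return tuple(axis_dims)
    if num_devices % fixed != 0:
        raise ValueError(f"{num_devices} devices cannot be split evenly by fixed axes of size {fixed}")
    # The -1 axis absorbs whatever the fixed axes leave over (like numpy.reshape).
    return tuple(num_devices // fixed if d == -1 else d for d in axis_dims)


def describe_mesh(axis_dims: Sequence[int], axis_names: Sequence[str], num_devices: int) -> Dict[str, int]:
    """Pair each mesh axis name with its resolved size."""
    if len(axis_dims) != len(axis_names):
        raise ValueError("axis_dims and axis_names must have the same length")
    return dict(zip(axis_names, resolve_axis_dims(axis_dims, num_devices)))


names = ("dp", "fsdp", "tp", "sp")
print(describe_mesh((1, 8, 1, 1), names, 8))   # {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1}
print(describe_mesh((1, -1, 1, 1), names, 8))  # same layout: -1 takes the remaining devices
```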
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModelForImageTextToText
Bases:
BaseAutoEasyModel
This class provides a convenient way to load pretrained image-text-to-text models from the Hugging Face Hub, shard them, and convert them into EasyDeL-compatible models for distributed training and inference with JAX.
Model loading and parameter sharding are provided by the from_config and from_pretrained class methods inherited from BaseAutoEasyModel, which return EasyDeLBaseModule instances.
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModelForSeq2SeqLM
Bases:
BaseAutoEasyModel
This class provides a convenient way to load pretrained sequence-to-sequence language models from the Hugging Face Hub, shard them, and convert them into EasyDeL-compatible models for distributed training and inference with JAX.
Model loading and parameter sharding are provided by the from_config and from_pretrained class methods inherited from BaseAutoEasyModel, which return EasyDeLBaseModule instances.
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModelForSequenceClassification
Bases:
BaseAutoEasyModel
This class provides a convenient way to load pretrained sequence classification models from the Hugging Face Hub, shard them, and convert them into EasyDeL-compatible models for distributed training and inference with JAX.
Model loading and parameter sharding are provided by the from_config and from_pretrained class methods inherited from BaseAutoEasyModel, which return EasyDeLBaseModule instances.
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModelForSpeechSeq2Seq
Bases:
BaseAutoEasyModel
This class provides a convenient way to load pretrained speech sequence-to-sequence models from the Hugging Face Hub, shard them, and convert them into EasyDeL-compatible models for distributed training and inference with JAX.
Model loading and parameter sharding are provided by the from_config and from_pretrained class methods inherited from BaseAutoEasyModel, which return EasyDeLBaseModule instances.
Examples
>>> import jax
>>> from easydel import AutoEasyDeLModelForSpeechSeq2Seq
>>> # Load openai/whisper-large-v3-turbo with automatic sharding
>>> model = AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...     "openai/whisper-large-v3-turbo",
...     auto_shard_model=True,
... )
>>> # Load openai/whisper-large-v3-turbo fully sharded (FSDP) across 8 GPUs
>>> model = AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...     "openai/whisper-large-v3-turbo",
...     sharding_axis_dims=(1, 8, 1, 1),
...     sharding_axis_names=("dp", "fsdp", "tp", "sp"),
...     device=jax.devices("cpu")[0],  # offload to CPU (optional)
...     from_torch=True,
... )
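With `sharding_axis_dims=(1, 8, 1, 1)`, all 8 devices sit on the `fsdp` axis and `dp` has size 1. Combining data parallelism with FSDP on the same 8 devices means splitting them across both axes, for example `(2, 4, 1, 1)`. The sketch below lists which device ids end up grouped along each axis. It assumes devices are laid out in row-major order, as `numpy.arange(n).reshape(dims)` would do; real JAX meshes may reorder physical devices to match the hardware topology, and the helper names are hypothetical.

```python
import itertools
from typing import Dict, List, Sequence, Tuple


def device_grid(axis_dims: Sequence[int]) -> Dict[Tuple[int, ...], int]:
    """Map each mesh coordinate to a device id, filling the mesh in row-major order."""
    coords = itertools.product(*(range(d) for d in axis_dims))
    return {coord: device_id for device_id, coord in enumerate(coords)}


def groups_along(axis: int, axis_dims: Sequence[int]) -> List[List[int]]:
    """Group device ids that differ only in their coordinate along `axis`."""
    groups: Dict[Tuple[int, ...], List[int]] = {}
    for coord, device_id in device_grid(axis_dims).items():
        # Drop the chosen axis; devices sharing all other coordinates form one group.
        key = coord[:axis] + coord[axis + 1:]
        groups.setdefault(key, []).append(device_id)
    return list(groups.values())


dims = (2, 4, 1, 1)  # hypothetical ("dp", "fsdp", "tp", "sp") split of 8 devices
print(groups_along(1, dims))  # FSDP groups: [[0, 1, 2, 3], [4, 5, 6, 7]]
print(groups_along(0, dims))  # DP groups:   [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Each FSDP group shards one copy of the parameters, and the DP axis replicates that arrangement across groups.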
- class easydel.modules.auto.auto_modeling.AutoEasyDeLModelForZeroShotImageClassification
Bases:
BaseAutoEasyModel
This class provides a convenient way to load pretrained zero-shot image classification models from the Hugging Face Hub, shard them, and convert them into EasyDeL-compatible models for distributed training and inference with JAX.
Model loading and parameter sharding are provided by the from_config and from_pretrained class methods inherited from BaseAutoEasyModel, which return EasyDeLBaseModule instances.
- class easydel.modules.auto.auto_modeling.AutoEasyDeLVisionModel
Bases:
BaseAutoEasyModel
This class provides a convenient way to load pretrained vision models from the Hugging Face Hub, shard them, and convert them into EasyDeL-compatible models for distributed training and inference with JAX.
Model loading and parameter sharding are provided by the from_config and from_pretrained class methods inherited from BaseAutoEasyModel, which return EasyDeLBaseModule instances.
- class easydel.modules.auto.auto_modeling.AutoState
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.AutoStateForCausalLM
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.AutoStateForImageSequenceClassification
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.AutoStateForImageTextToText
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.AutoStateForSeq2SeqLM
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.AutoStateForSpeechSeq2Seq
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.AutoStateForZeroShotImageClassification
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.AutoStateVisionModel
Bases:
BaseAutoEasyState
- class easydel.modules.auto.auto_modeling.BaseAutoEasyModel
Bases:
object
Base class for all Auto EasyDeL model classes. Provides common class methods for loading models from configurations or pretrained checkpoints.
- classmethod from_config(config: EasyDeLBaseConfig, dtype: jnp.dtype = jnp.float32, param_dtype: jnp.dtype = jnp.float32, precision: Optional[jax.lax.Precision] = None, *, rngs: Optional[flax.nnx.Rngs] = None) -> EasyDeLBaseModule
Instantiates a model module directly from a configuration object.
- Parameters
config (EasyDeLBaseConfig) – The configuration object for the model.
dtype (jnp.dtype) – Data type used for computation. Defaults to jnp.float32.
param_dtype (jnp.dtype) – Data type of the parameters. Defaults to jnp.float32.
precision (Optional[jax.lax.Precision]) – JAX precision level. Defaults to None.
rngs (Optional[flax.nnx.Rngs]) – Random number generators. If None, Rngs(42) is used.
- Returns
An instance of the specific EasyDeL model module.
- Return type
EasyDeLBaseModule
- classmethod from_pretrained(pretrained_model_name_or_path: str, device: Optional[jax.Device] = None, dtype: jnp.dtype = jnp.float32, param_dtype: jnp.dtype = jnp.float32, precision: Optional[jax.lax.Precision] = None, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, shard_fns: Optional[Union[Mapping[tuple, Callable], dict]] = None, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, config_kwargs: Optional[EasyDeLBaseConfigDict] = None, auto_shard_model: bool = False, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec], ...]] = None, quantization_platform: Optional[EasyDeLPlatforms] = None, quantization_method: Optional[EasyDeLQuantizationMethods] = None, quantization_block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: bool = True, from_torch: Optional[bool] = None, **kwargs) -> EasyDeLBaseModule
Loads and shards a pretrained model from the Hugging Face Hub and converts it into an EasyDeL compatible model.
- Parameters
pretrained_model_name_or_path (str) – Model ID on the Hugging Face Hub or path to a local checkpoint.
device (jax.Device, optional) – Device to load the model on. Defaults to the first CPU.
dtype (jnp.dtype, optional) – Data type used for computation. Defaults to jnp.float32.
param_dtype (jnp.dtype, optional) – Data type of the model parameters. Defaults to jnp.float32.
precision (jax.lax.Precision, optional) – Precision for computations. Defaults to None.
sharding_axis_dims (tp.Sequence[int], optional) – Size of each sharding axis; -1 uses all remaining devices. Defaults to (1, -1, 1, 1).
sharding_axis_names (tp.Sequence[str], optional) – Names of the sharding axes. Defaults to ('dp', 'fsdp', 'tp', 'sp').
partition_axis (PartitionAxis, optional) – PartitionAxis object describing how arrays are partitioned across the mesh axes in EasyDeL. Defaults to None.
shard_attention_computation (bool, optional) – Whether to shard the attention computation. Defaults to True.
shard_fns (tp.Optional[tp.Mapping[tuple, tp.Callable] | dict], optional) – Sharding functions to use for the model. If None and auto_shard_model is True, sharding is derived automatically. Defaults to None.
backend (tp.Optional[EasyDeLBackends], optional) – Backend to use for the model. Defaults to None.
platform (tp.Optional[EasyDeLPlatforms], optional) – Platform to use for the model. Defaults to None.
config_kwargs (tp.Optional[tp.Mapping[str, tp.Any] | EasyDeLBaseConfigDict], optional) – Keyword arguments passed to the model config. Defaults to None.
auto_shard_model (bool, optional) – Whether to shard the model parameters automatically. Defaults to False.
partition_rules (tp.Optional[tp.Tuple[tp.Tuple[str, PartitionSpec], ...]], optional) – Custom partition rules for parameter sharding. If set, shard_fns should also be provided. Defaults to None.
quantization_method (EasyDeLQuantizationMethods, optional) – Method used to quantize the model weights. Defaults to None.
quantization_block_size (int, optional) – Block size used when quantizing arrays (NF4 only). Defaults to 128.
bit_targeted_params (tp.Optional[tp.List[str]], optional) – List of parameter names to convert to 8-bit precision. If None and 8-bit conversion is enabled, all kernels and embeddings are converted to 8-bit. Defaults to None. (Not listed in the signature above.)
from_torch (bool, optional) – Whether to load the model from a transformers PyTorch checkpoint. Defaults to None.
**kwargs – Additional keyword arguments passed to the model and config classes.
- Returns
The loaded EasyDeL model, sharded according to the arguments above.
- Return type
EasyDeLBaseModule
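partition_rules is a tuple of (regex, PartitionSpec) pairs. A common convention, assumed here rather than confirmed for EasyDeL, is that each parameter takes the spec of the first rule whose pattern matches its "/"-joined path, with a catch-all rule last. The sketch uses plain tuples in place of jax.sharding.PartitionSpec so it runs with the standard library only; `match_partition_rules` and the parameter paths are hypothetical.

```python
import re
from typing import Dict, Optional, Sequence, Tuple

# One entry per array dimension; () stands in for an empty PartitionSpec (replicated).
Spec = Tuple[Optional[str], ...]
Rule = Tuple[str, Spec]


def match_partition_rules(rules: Sequence[Rule], param_paths: Sequence[str]) -> Dict[str, Spec]:
    """Assign each parameter path the spec of the first rule whose regex matches it."""
    assigned: Dict[str, Spec] = {}
    for path in param_paths:
        for pattern, spec in rules:
            if re.search(pattern, path):
                assigned[path] = spec
                break  # first match wins, so rule order matters
        else:
            raise ValueError(f"no partition rule matches {path!r}")
    return assigned


rules = [
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r".*", ()),  # catch-all: replicate everything else
]
paths = [
    "model/embed_tokens/embedding",
    "model/layers/0/self_attn/q_proj/kernel",
    "model/norm/scale",
]
for path, spec in match_partition_rules(rules, paths).items():
    print(path, spec)
```

Putting the catch-all last guarantees every parameter receives some spec while the specific rules above it still take precedence.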
- class easydel.modules.auto.auto_modeling.BaseAutoEasyState
Bases:
object
Base class for Auto EasyDeL state classes. Provides common class methods for creating model states from configurations or pretrained checkpoints.
- _base
The corresponding Auto EasyDeL model class.
- Type
BaseAutoEasyModel
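Each AutoState* class is paired with an Auto EasyDeL model class through _base. A plausible reading, sketched below with hypothetical stand-in classes rather than EasyDeL's real ones, is that the state class methods load the model through _base and then wrap it in a state object.

```python
from typing import Any


class ToyModel:
    """Hypothetical stand-in for an Auto EasyDeL model class."""

    def __init__(self, name: str):
        self.name = name

    @classmethod
    def from_pretrained(cls, name: str, **kwargs: Any) -> "ToyModel":
        # A real loader would use kwargs (dtype, sharding, ...); this toy ignores them.
        return cls(name)


class ToyCausalLMModel(ToyModel):
    """Hypothetical task-specific model class."""


class ToyState:
    """Hypothetical stand-in for EasyDeLState: it simply holds the model."""

    def __init__(self, model: ToyModel):
        self.model = model


class ToyBaseAutoState:
    _base = ToyModel  # the model class this state class delegates to

    @classmethod
    def from_pretrained(cls, name: str, **kwargs: Any) -> ToyState:
        # Load through the paired model class, then wrap the result in a state.
        model = cls._base.from_pretrained(name, **kwargs)
        return ToyState(model)


class ToyAutoStateForCausalLM(ToyBaseAutoState):
    _base = ToyCausalLMModel  # subclasses only swap the paired model class


state = ToyAutoStateForCausalLM.from_pretrained("gpt2")
print(type(state.model).__name__)  # ToyCausalLMModel
```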
- classmethod from_config(config: EasyDeLBaseConfig, dtype: jnp.dtype = jnp.float32, param_dtype: jnp.dtype = jnp.float32, precision: Optional[jax.lax.Precision] = None, *, rngs: Optional[flax.nnx.Rngs] = None) -> EasyDeLState
Creates an EasyDeLState directly from a configuration object.
- Parameters
config (EasyDeLBaseConfig) – The configuration object for the model.
dtype (jnp.dtype) – Data type used for computation. Defaults to jnp.float32.
param_dtype (jnp.dtype) – Data type of the parameters. Defaults to jnp.float32.
precision (Optional[jax.lax.Precision]) – JAX precision level. Defaults to None.
rngs (Optional[flax.nnx.Rngs]) – Random number generators. If None, Rngs(42) is used.
- Returns
An initialized EasyDeLState for the model.
- Return type
EasyDeLState
- classmethod from_pretrained(pretrained_model_name_or_path: str, device: Optional[jax.Device] = None, dtype: jnp.dtype = jnp.float32, param_dtype: jnp.dtype = jnp.float32, precision: Optional[jax.lax.Precision] = None, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, shard_fns: Optional[Union[Mapping[tuple, Callable], dict]] = None, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, config_kwargs: Optional[EasyDeLBaseConfigDict] = None, auto_shard_model: bool = False, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec], ...]] = None, quantization_method: Optional[EasyDeLQuantizationMethods] = None, quantization_block_size: int = 128, from_torch: Optional[bool] = None, **kwargs) -> EasyDeLState
Loads and shards a pretrained model from the Hugging Face Hub and converts it into an EasyDeL compatible state.
- Parameters
pretrained_model_name_or_path (str) – Model ID on the Hugging Face Hub or path to a local checkpoint.
device (jax.Device, optional) – Device to load the model on. Defaults to the first CPU.
dtype (jnp.dtype, optional) – Data type used for computation. Defaults to jnp.float32.
param_dtype (jnp.dtype, optional) – Data type of the model parameters. Defaults to jnp.float32.
precision (jax.lax.Precision, optional) – Precision for computations. Defaults to None.
sharding_axis_dims (tp.Sequence[int], optional) – Size of each sharding axis; -1 uses all remaining devices. Defaults to (1, -1, 1, 1).
sharding_axis_names (tp.Sequence[str], optional) – Names of the sharding axes. Defaults to ('dp', 'fsdp', 'tp', 'sp').
partition_axis (PartitionAxis, optional) – PartitionAxis object describing how arrays are partitioned across the mesh axes in EasyDeL. Defaults to None.
shard_attention_computation (bool, optional) – Whether to shard the attention computation. Defaults to True.
shard_fns (tp.Optional[tp.Mapping[tuple, tp.Callable] | dict], optional) – Sharding functions to use for the model. If None and auto_shard_model is True, sharding is derived automatically. Defaults to None.
backend (tp.Optional[EasyDeLBackends], optional) – Backend to use for the model. Defaults to None.
platform (tp.Optional[EasyDeLPlatforms], optional) – Platform to use for the model. Defaults to None.
config_kwargs (tp.Optional[tp.Mapping[str, tp.Any] | EasyDeLBaseConfigDict], optional) – Keyword arguments passed to the model config. Defaults to None.
auto_shard_model (bool, optional) – Whether to shard the model parameters automatically. Defaults to False.
partition_rules (tp.Optional[tp.Tuple[tp.Tuple[str, PartitionSpec], ...]], optional) – Custom partition rules for parameter sharding. If set, shard_fns should also be provided. Defaults to None.
quantization_method (EasyDeLQuantizationMethods, optional) – Method used to quantize the model weights. Defaults to None.
quantization_block_size (int, optional) – Block size used when quantizing arrays (NF4 only). Defaults to 128.
bit_targeted_params (tp.Optional[tp.List[str]], optional) – List of parameter names to convert to 8-bit precision. If None and 8-bit conversion is enabled, all kernels and embeddings are converted to 8-bit. Defaults to None. (Not listed in the signature above.)
verbose_params (bool) – Whether to log the number of parameters while converting the state. (Not listed in the signature above.)
safe (bool) – Whether to load the engine or parameters with safetensors; they must have been saved with safe=True. (Not listed in the signature above.)
from_torch (bool, optional) – Whether to load the model from a transformers PyTorch checkpoint. Defaults to None.
**kwargs – Additional keyword arguments passed to the model and config classes.
- Returns
An EasyDeLState containing the loaded and sharded model.
- Return type
EasyDeLState
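quantization_block_size (default 128) sets how many consecutive weights share one scale factor. NF4, the format the parameter is documented for, normalizes each block by its absolute maximum and then snaps values to a fixed 16-level table, which is not reproduced here. As a simpler stand-in, the sketch below uses symmetric absmax int8 quantization purely to show what per-block scaling means; the function names are hypothetical and this is not EasyDeL's quantizer.

```python
from typing import List, Tuple


def quantize_blockwise(values: List[float], block_size: int = 128, bits: int = 8) -> Tuple[List[int], List[float]]:
    """Symmetric absmax quantization with one scale per block of `block_size` values."""
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits
    codes: List[int] = []
    scales: List[float] = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block)
        # The largest magnitude in the block maps to qmax; all-zero blocks get scale 1.
        scale = absmax / qmax if absmax > 0 else 1.0
        scales.append(scale)
        codes.extend(round(v / scale) for v in block)
    return codes, scales


def dequantize_blockwise(codes: List[int], scales: List[float], block_size: int = 128) -> List[float]:
    """Undo quantize_blockwise by multiplying each code by its block's scale."""
    return [c * scales[i // block_size] for i, c in enumerate(codes)]


weights = [0.01 * i - 1.0 for i in range(300)]
codes, scales = quantize_blockwise(weights, block_size=128)
print(len(scales))  # 3 blocks: 128 + 128 + 44 values
```

Smaller blocks track local value ranges more closely (lower rounding error) at the cost of storing more scales.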