easydel.modules.auto.__init__

class easydel.modules.auto.__init__.AutoEasyDeLConfig

Bases: object

classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, **kwargs) → EasyDeLBaseConfig

Loads the configuration of a pretrained model from the Hugging Face Hub (or a local path) and returns the matching EasyDeL config, set up with the sharding and backend options passed as arguments. It takes the name of the model (e.g., ‘gpt2’) as input. Unlike the model auto classes, it returns only a config; no weights are loaded.

Parameters
  • cls – The class the method is called on (passed implicitly).

  • pretrained_model_name_or_path – str: Model identifier on the Hugging Face Hub, or a local path.

  • sharding_axis_dims – tp.Sequence[int]: Size of each axis of the device mesh used to shard arrays in EasyDeL. A value of -1 takes all remaining devices. Defaults to (1, -1, 1, 1).

  • sharding_axis_names – tp.Sequence[str]: Names of the mesh axes, in the same order as sharding_axis_dims. Defaults to (‘dp’, ‘fsdp’, ‘tp’, ‘sp’).

  • shard_attention_computation – bool: Whether to use shard_map for the attention computation.

  • backend – tp.Optional[EasyDeLBackends]: Backend to use for the model.

  • model_task (TaskType) – Task type of the model to look up and load. Defaults to TaskType.CAUSAL_LM.

  • from_torch – Whether the config is loaded from a PyTorch model. Defaults to False.

  • **kwargs – Additional keyword arguments forwarded when loading the config.


Returns

The model's config, an instance of an EasyDeLBaseConfig subclass.
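With the default sharding_axis_dims=(1, -1, 1, 1), the -1 entry absorbs whatever devices the other axes leave over, so every device lands on the fsdp axis. The arithmetic can be sketched as below; resolve_axis_dims is a hypothetical helper written for illustration and assumes NumPy-reshape-style -1 semantics. It is not an EasyDeL function.

```python
from math import prod
from typing import Sequence, Tuple


def resolve_axis_dims(axis_dims: Sequence[int], num_devices: int) -> Tuple[int, ...]:
    """Hypothetical helper: replace a single -1 with the number of devices left over."""
    dims = list(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = prod(d for d in dims if d != -1)  # product of the fixed axis sizes
    if -1 in dims:
        if num_devices % known:
            raise ValueError(f"{num_devices} devices cannot be split as {tuple(dims)}")
        dims[dims.index(-1)] = num_devices // known
    if prod(dims) != num_devices:
        raise ValueError(f"mesh {tuple(dims)} does not use exactly {num_devices} devices")
    return tuple(dims)


names = ("dp", "fsdp", "tp", "sp")
# Default (1, -1, 1, 1) on 8 devices: every device ends up on the fsdp axis.
print(dict(zip(names, resolve_axis_dims((1, -1, 1, 1), 8))))
# {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1}
# 2-way data parallel and 2-way tensor parallel; fsdp takes what is left.
print(dict(zip(names, resolve_axis_dims((2, -1, 2, 1), 8))))
# {'dp': 2, 'fsdp': 2, 'tp': 2, 'sp': 1}
```

Because at most one axis may be -1, the mesh shape is always fully determined by the device count.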

class easydel.modules.auto.__init__.AutoEasyDeLModel

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models for distributed training and inference with JAX.

It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework.

model_task: TaskType = 'base-module'
class easydel.modules.auto.__init__.AutoEasyDeLModelForCausalLM

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models for distributed training and inference with JAX.

It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework.


Examples

>>> import jax
>>> from easydel import AutoEasyDeLModelForCausalLM
>>> # Load a GPT-2 model on a single CPU
>>> model = AutoEasyDeLModelForCausalLM.from_pretrained(
...     "gpt2", device=jax.devices("cpu")[0]
... )
>>> # Load a GPT-2 model sharded across 8 GPUs with fully sharded data parallelism (dp=1, fsdp=8)
>>> model = AutoEasyDeLModelForCausalLM.from_pretrained(
...     "gpt2",
...     sharding_axis_dims=(1, 8, 1, 1),
...     sharding_axis_names=("dp", "fsdp", "tp", "sp"),
...     device=jax.devices("cpu")[0],  # offload to CPU [OPTIONAL]
...     from_torch=True,
... )
model_task: TaskType = 'causal-language-model'
class easydel.modules.auto.__init__.AutoEasyDeLModelForImageTextToText

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models for distributed training and inference with JAX.

It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework.

model_task: TaskType = 'image-text-to-text'
class easydel.modules.auto.__init__.AutoEasyDeLModelForSeq2SeqLM

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models for distributed training and inference with JAX.

It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework.

model_task: TaskType = 'sequence-to-sequence'
class easydel.modules.auto.__init__.AutoEasyDeLModelForSequenceClassification

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models for distributed training and inference with JAX.

It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework.

model_task: TaskType = 'sequence-classification'
class easydel.modules.auto.__init__.AutoEasyDeLModelForSpeechSeq2Seq

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models for distributed training and inference with JAX.

It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework.


Examples

>>> import jax
>>> from easydel import AutoEasyDeLModelForSpeechSeq2Seq
>>> # Load openai/whisper-large-v3-turbo with automatic sharding
>>> model = AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...     "openai/whisper-large-v3-turbo",
...     auto_shard_model=True,
... )
>>> # Load openai/whisper-large-v3-turbo sharded across 8 GPUs with fully sharded data parallelism (dp=1, fsdp=8)
>>> model = AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...     "openai/whisper-large-v3-turbo",
...     sharding_axis_dims=(1, 8, 1, 1),
...     sharding_axis_names=("dp", "fsdp", "tp", "sp"),
...     device=jax.devices("cpu")[0],  # offload to CPU [OPTIONAL]
...     from_torch=True,
... )
model_task: TaskType = 'speech-sequence-to-sequence'
class easydel.modules.auto.__init__.AutoEasyDeLModelForZeroShotImageClassification

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models for distributed training and inference with JAX.

It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework.

model_task: TaskType = 'zero-shot-image-classification'
class easydel.modules.auto.__init__.AutoEasyDeLVisionModel

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models for distributed training and inference with JAX.

It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and integration with the EasyDeL framework.

model_task: TaskType = 'vision-module'
class easydel.modules.auto.__init__.AutoShardAndGatherFunctions

Bases: object

A class to automatically generate shard and gather functions for a given model configuration.

This class provides two methods to generate shard and gather functions:

  • from_config: Generates functions based on a provided EasyDeLBaseConfig object.

  • from_pretrained: Generates functions based on a pretrained model name or path.
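Shard and gather functions come in pairs, one pair per parameter: the shard function splits a full array into pieces and the gather function reassembles it. The toy sketch below mimics only that pairing and the two-mapping return shape documented for from_pretrained (Tuple[Mapping[str, Callable], Mapping[str, Callable]]). It uses plain Python lists instead of JAX arrays on a device mesh, and make_shard_and_gather_fns is a hypothetical name, not an EasyDeL API.

```python
from typing import Callable, Dict, List, Mapping, Tuple

Array = List[float]  # stand-in for a device array


def make_shard_and_gather_fns(
    num_shards: Mapping[str, int],
) -> Tuple[Dict[str, Callable[[Array], List[Array]]], Dict[str, Callable[[List[Array]], Array]]]:
    """Hypothetical toy: split each named parameter into equal chunks and join them back."""

    def make_shard(n: int) -> Callable[[Array], List[Array]]:
        # A closure per parameter, so each shard function keeps its own n.
        def shard(x: Array) -> List[Array]:
            if len(x) % n:
                raise ValueError(f"length {len(x)} is not divisible by {n}")
            size = len(x) // n
            return [x[i * size:(i + 1) * size] for i in range(n)]
        return shard

    def gather(chunks: List[Array]) -> Array:
        # Concatenate the chunks back into one flat array.
        return [v for chunk in chunks for v in chunk]

    shard_fns = {name: make_shard(n) for name, n in num_shards.items()}
    gather_fns = {name: gather for name in num_shards}
    return shard_fns, gather_fns


params = {"embed": [1.0, 2.0, 3.0, 4.0], "bias": [0.5, 0.5]}
shard_fns, gather_fns = make_shard_and_gather_fns({"embed": 2, "bias": 1})
pieces = {k: shard_fns[k](v) for k, v in params.items()}
print(pieces["embed"])  # [[1.0, 2.0], [3.0, 4.0]]
restored = {k: gather_fns[k](v) for k, v in pieces.items()}
assert restored == params  # gather undoes shard
```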


classmethod from_config(config: EasyDeLBaseConfig, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, flatten: bool = True, model_task: TaskType = TaskType.CAUSAL_LM, depth_target: Optional[List[str]] = None)

Generates shard and gather functions based on a provided EasyDeLBaseConfig object.

Parameters
  • config – An EasyDeLBaseConfig object containing the model configuration.

  • partition_rules – A tuple of tuples containing partition rule names and PartitionSpec objects. If None, uses the default partition rules from the config.

  • flatten – Whether to flatten the nested mapping of shard and gather functions into a single-level mapping keyed by parameter path. Defaults to True.

  • model_task (TaskType) – Task type of the model to look up and load. Defaults to TaskType.CAUSAL_LM.

  • depth_target – Keys under which to nest the sharding tree. For example, with depth_target = [“row”], {params: tensor} becomes {row: {params: tensor}}. Defaults to None.

Returns

A tuple containing the shard and gather functions.
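flatten and depth_target only change the shape of the tree that holds the functions, not the functions themselves. The sketch below illustrates both with plain dicts and strings standing in for PartitionSpecs. flatten_tree and pad_to_depth are hypothetical helpers; tuple keys for the flat form and outermost-first ordering for a multi-key depth_target are assumptions (the documented example uses a single key).

```python
import collections.abc
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple


def flatten_tree(tree: Mapping[str, Any], prefix: Tuple[str, ...] = ()) -> Dict[Tuple[str, ...], Any]:
    """Hypothetical helper: turn a nested dict into {path_tuple: leaf}."""
    flat: Dict[Tuple[str, ...], Any] = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, collections.abc.Mapping):
            flat.update(flatten_tree(value, path))  # recurse into sub-trees
        else:
            flat[path] = value
    return flat


def pad_to_depth(tree: Mapping[str, Any], depth_target: Optional[Sequence[str]]) -> Mapping[str, Any]:
    """Hypothetical helper: nest tree under the depth_target keys, outermost key first."""
    for key in reversed(depth_target or []):
        tree = {key: tree}
    return tree


# The documented example: {params: tensor} with depth_target=["row"].
print(pad_to_depth({"params": "spec-A"}, ["row"]))  # {'row': {'params': 'spec-A'}}

nested = {"params": {"embed": "spec-A", "mlp": {"w": "spec-B"}}}
print(flatten_tree(nested))
# {('params', 'embed'): 'spec-A', ('params', 'mlp', 'w'): 'spec-B'}
```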

static from_params(params, partition_rules, mesh)
classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, flatten: bool = True, config_kwargs: Optional[Mapping[str, Any]] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, trust_remote_code: bool = False) → Tuple[Mapping[str, Callable], Mapping[str, Callable]]

Generates shard and gather functions based on a pretrained model name or path.

Parameters
  • pretrained_model_name_or_path – The name or path of the pretrained model.

  • sharding_axis_dims – The dimensions of the sharding axes. Defaults to (1, -1, 1, 1).

  • sharding_axis_names – The names of the sharding axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).

  • partition_axis (PartitionAxis) – A PartitionAxis object describing how arrays are partitioned in EasyDeL.

  • shard_attention_computation – Whether to shard the attention computation. Defaults to True.

  • backend – The backend to use for custom kernels. Defaults to None.

  • partition_rules – A tuple of tuples containing partition rule names and PartitionSpec objects. If None, uses the default partition rules from the config.

  • flatten – Whether to flatten the shard and gather functions. Defaults to True.

  • config_kwargs – Additional keyword arguments used when loading the config through AutoEasyDeLConfig. Defaults to None.

  • model_task (TaskType) – Task type of the model to look up and load. Defaults to TaskType.CAUSAL_LM.

  • from_torch – Whether the config is loaded from a PyTorch model. Defaults to False.

  • trust_remote_code (bool) – Whether to trust remote code loaded from the Hugging Face Hub. Defaults to False.

Returns

A tuple containing the shard and gather functions.
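Each entry of partition_rules pairs a rule name with a PartitionSpec. A common convention in JAX codebases, assumed here rather than confirmed by this reference, is that the name is a regular expression searched against the “/”-joined parameter path and the first matching rule wins, so a catch-all “.*” goes last. The sketch uses tuples in place of jax.sharding.PartitionSpec so it runs without JAX; the parameter paths and rules are made up for illustration and are not EasyDeL's defaults.

```python
import re
from typing import Dict, Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec: one entry per array dimension.
Spec = Tuple[object, ...]

# Hypothetical rules in the documented shape: ((pattern, spec), ...).
PARTITION_RULES: Tuple[Tuple[str, Spec], ...] = (
    (r"embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    (r"mlp/down_proj/kernel", ("tp", ("fsdp", "sp"))),
    (r".*", ()),  # catch-all: empty spec means fully replicated
)


def match_spec(path: str, rules: Sequence[Tuple[str, Spec]]) -> Spec:
    """Return the spec of the first rule whose pattern matches somewhere in path."""
    for pattern, spec in rules:
        if re.search(pattern, path):
            return spec
    raise ValueError(f"no partition rule matches {path!r}")


paths = [
    "model/embed_tokens/embedding",
    "model/layers/0/self_attn/q_proj/kernel",
    "model/norm/scale",
]
specs: Dict[str, Spec] = {p: match_spec(p, PARTITION_RULES) for p in paths}
for p, s in specs.items():
    print(f"{p:45s} -> {s}")
```

Rule order matters: because matching stops at the first hit, more specific patterns must precede broader ones.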

class easydel.modules.auto.__init__.AutoState

Bases: BaseAutoEasyState

class easydel.modules.auto.__init__.AutoStateForCausalLM

Bases: BaseAutoEasyState

class easydel.modules.auto.__init__.AutoStateForImageSequenceClassification

Bases: BaseAutoEasyState

class easydel.modules.auto.__init__.AutoStateForImageTextToText

Bases: BaseAutoEasyState

class easydel.modules.auto.__init__.AutoStateForSeq2SeqLM

Bases: BaseAutoEasyState

class easydel.modules.auto.__init__.AutoStateForSpeechSeq2Seq

Bases: BaseAutoEasyState

class easydel.modules.auto.__init__.AutoStateForZeroShotImageClassification

Bases: BaseAutoEasyState

class easydel.modules.auto.__init__.AutoStateVisionModel

Bases: BaseAutoEasyState

easydel.modules.auto.__init__.get_modules_by_type(model_type: str, task_type: TaskType) → Tuple[Type[EasyDeLBaseConfig], Union[Type[EasyDeLBaseModule], Any]]
Returns a tuple of:
  1. The config class for the given model type (e.g., LlamaConfig, FalconConfig).

  2. The EasyDeL module class for the given model type and task_type (e.g., LlamaForCausalLM or FalconForCausalLM for TaskType.CAUSAL_LM).
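Conceptually, get_modules_by_type is a lookup keyed by model type and task, and each auto class above differs only in the model_task value it supplies. The sketch below assumes a plain dictionary registry to show that dispatch; the Toy* classes, REGISTRY, and lookup_modules are hypothetical stand-ins, and EasyDeL's actual registry is not described in this reference. The real auto classes also read the model type from the checkpoint's config instead of taking it as an argument.

```python
from typing import Dict, Tuple, Type


class ToyLlamaConfig:
    """Stand-in for a config class such as LlamaConfig."""


class ToyLlamaModel:
    """Stand-in for a base (headless) module class."""


class ToyLlamaForCausalLM:
    """Stand-in for a causal-LM module class such as LlamaForCausalLM."""


# Hypothetical registry keyed by (model_type, task); the task strings reuse
# the model_task values documented for the auto classes.
REGISTRY: Dict[Tuple[str, str], Tuple[Type, Type]] = {
    ("llama", "base-module"): (ToyLlamaConfig, ToyLlamaModel),
    ("llama", "causal-language-model"): (ToyLlamaConfig, ToyLlamaForCausalLM),
}


def lookup_modules(model_type: str, task_type: str) -> Tuple[Type, Type]:
    """Return (config_class, module_class), mirroring get_modules_by_type's return shape."""
    try:
        return REGISTRY[(model_type, task_type)]
    except KeyError:
        raise ValueError(f"{model_type!r} has no module registered for task {task_type!r}") from None


class ToyAutoModel:
    """Stand-in for an auto class: subclasses only override model_task."""

    model_task = "base-module"

    @classmethod
    def module_class_for(cls, model_type: str) -> Type:
        _, module_cls = lookup_modules(model_type, cls.model_task)
        return module_cls


class ToyAutoModelForCausalLM(ToyAutoModel):
    model_task = "causal-language-model"


print(ToyAutoModel.module_class_for("llama").__name__)             # ToyLlamaModel
print(ToyAutoModelForCausalLM.module_class_for("llama").__name__)  # ToyLlamaForCausalLM
```

Keeping the task on a class attribute means adding a new auto class is a two-line subclass, while all loading logic stays in the shared base.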