easydel.modules.auto.auto_configuration

class easydel.modules.auto.auto_configuration.AutoEasyDeLConfig

Bases: object

classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, **kwargs) → EasyDeLBaseConfig

Instantiates an EasyDeL configuration from a pretrained model name or path (e.g., ‘bert-base-uncased’). It loads the model's configuration, not its weights, and returns an instance of the EasyDeLBaseConfig subclass that corresponds to the model, set up with the sharding, partitioning and backend arguments passed in.

Parameters
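  • cls – The class on which the method is called (passed automatically for a classmethod).

  • pretrained_model_name_or_path (str) – The model identifier on the Hugging Face Hub, or a local path.

  • sharding_axis_dims (Sequence[int]) – The size of each axis of the device mesh used to shard arrays in EasyDeL. Defaults to (1, -1, 1, 1).

  • sharding_axis_names (Sequence[str]) – The names of the sharding axes. Defaults to (‘dp’, ‘fsdp’, ‘tp’, ‘sp’).

  • partition_axis (PartitionAxis) – The PartitionAxis used to partition arrays in EasyDeL. Defaults to None.

  • shard_attention_computation (bool) – Whether to use shard_map for the attention computation. Defaults to True.

  • backend (Optional[EasyDeLBackends]) – The backend to use for the model. Defaults to None.

  • model_task (TaskType) – The task type used to find and load the model. Defaults to TaskType.CAUSAL_LM.

  • from_torch (bool) – Whether the configuration should be loaded from a PyTorch model. Defaults to False.

  • **kwargs – Additional arguments passed to the model and config classes.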

Returns

An EasyDeLBaseConfig instance for the requested model.
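The sharding_axis_dims entries give the size of each mesh axis and are matched by position with sharding_axis_names. The sketch below shows one way a -1 entry could be resolved into a concrete mesh shape, assuming it absorbs all remaining devices the way numpy.reshape treats -1. resolve_axis_dims and mesh_layout are hypothetical helpers for illustration, not EasyDeL functions.

```python
import math
from typing import Dict, Sequence, Tuple


def resolve_axis_dims(axis_dims: Sequence[int], num_devices: int) -> Tuple[int, ...]:
    """Hypothetical helper: replace a single -1 with the number of leftover devices."""
    dims = list(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("at most one axis size may be -1")
    known = math.prod(d for d in dims if d != -1)  # product of the fixed axis sizes
    if -1 in dims:
        if num_devices % known != 0:
            raise ValueError(f"{num_devices} devices cannot be split by {known}")
        dims[dims.index(-1)] = num_devices // known
    if math.prod(dims) != num_devices:
        raise ValueError(f"mesh shape {tuple(dims)} does not use {num_devices} devices")
    return tuple(dims)


def mesh_layout(
    axis_dims: Sequence[int], axis_names: Sequence[str], num_devices: int
) -> Dict[str, int]:
    """Hypothetical helper: pair each mesh axis name with its resolved size."""
    if len(axis_dims) != len(axis_names):
        raise ValueError("axis_dims and axis_names must have the same length")
    return dict(zip(axis_names, resolve_axis_dims(axis_dims, num_devices)))


# With the defaults and 8 devices, every device lands on the "fsdp" axis.
print(mesh_layout((1, -1, 1, 1), ("dp", "fsdp", "tp", "sp"), 8))
# {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1}

# Asking for 2-way tensor parallelism leaves 4 devices for "fsdp".
print(mesh_layout((1, -1, 2, 1), ("dp", "fsdp", "tp", "sp"), 8))
# {'dp': 1, 'fsdp': 4, 'tp': 2, 'sp': 1}
```

Fixing the explicit axes and letting a single -1 absorb the rest means the same settings can be reused on machines with different device counts.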

class easydel.modules.auto.auto_configuration.AutoShardAndGatherFunctions

Bases: object

A class to automatically generate shard and gather functions for a given model configuration.

This class provides two classmethods to generate shard and gather functions:

  • from_config: Generates functions based on a provided EasyDeLBaseConfig object.

  • from_pretrained: Generates functions based on a pretrained model name or path.
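Conceptually, each shard function places one parameter onto devices according to its partitioning, and the matching gather function reassembles it into a single array. The pure-Python toy below mimics that pairing with lists split into contiguous chunks; it is an illustration under that simplifying assumption, not EasyDeL's implementation, which works on JAX arrays and PartitionSpec objects. toy_shard_and_gather_fns is a hypothetical name.

```python
from typing import Callable, Dict, List, Tuple

Array = List[float]   # toy stand-in for a 1-D array
Shards = List[Array]  # one chunk per (simulated) device


def toy_shard_and_gather_fns(
    num_shards: Dict[str, int],
) -> Tuple[Dict[str, Callable[[Array], Shards]], Dict[str, Callable[[Shards], Array]]]:
    """Build one shard function and one gather function per parameter name."""

    def make_shard_fn(n: int) -> Callable[[Array], Shards]:
        def shard_fn(x: Array) -> Shards:
            if len(x) % n != 0:
                raise ValueError(f"length {len(x)} is not divisible by {n}")
            size = len(x) // n
            # Split into n contiguous chunks, one per simulated device.
            return [x[i * size:(i + 1) * size] for i in range(n)]
        return shard_fn

    def gather_fn(shards: Shards) -> Array:
        # Concatenate the chunks back in device order.
        return [value for chunk in shards for value in chunk]

    shard_fns = {name: make_shard_fn(n) for name, n in num_shards.items()}
    gather_fns = {name: gather_fn for name in num_shards}
    return shard_fns, gather_fns


params = {"embed": [0.0, 1.0, 2.0, 3.0], "bias": [5.0, 6.0]}
shard_fns, gather_fns = toy_shard_and_gather_fns({"embed": 2, "bias": 1})
sharded = {name: shard_fns[name](value) for name, value in params.items()}
restored = {name: gather_fns[name](chunks) for name, chunks in sharded.items()}
print(sharded["embed"])    # [[0.0, 1.0], [2.0, 3.0]]
print(restored == params)  # True
```

Both mappings are keyed like the parameters themselves, so a shard function and its gather function can be looked up with the same name and applied as inverse steps.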

classmethod from_config(config: EasyDeLBaseConfig, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, flatten: bool = True, model_task: TaskType = TaskType.CAUSAL_LM, depth_target: Optional[List[str]] = None)

Generates shard and gather functions based on a provided EasyDeLBaseConfig object.

Parameters
  • config – An EasyDeLBaseConfig object containing the model configuration.

  • partition_rules – A tuple of (parameter-name pattern, PartitionSpec) pairs. If None, the default partition rules from the config are used.

  • flatten – Whether to flatten the shard and gather functions. Defaults to True.

  • model_task (TaskType) – The task type used to find and load the model. Defaults to TaskType.CAUSAL_LM.

  • depth_target – Pads the sharding functions to the given depth; for example, with depth_target = [“row”], {params: tensor} becomes {row: {params: tensor}}. Defaults to None.

Returns

A tuple containing the shard and gather functions.
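The depth_target and flatten options only change the layout of the returned mappings. The sketch below reproduces the documented depth_target example and then flattens the result into tuple-keyed form, assuming flattening follows the flax.traverse_util.flatten_dict convention of tuple path keys. pad_to_depth and flatten are hypothetical helpers, and strings stand in for the actual functions.

```python
from typing import Any, Dict, List, Optional, Tuple


def pad_to_depth(tree: Dict[str, Any], depth_target: Optional[List[str]]) -> Dict[str, Any]:
    """Wrap `tree` under each key in depth_target, outermost key first."""
    for key in reversed(depth_target or []):
        tree = {key: tree}
    return tree


def flatten(tree: Dict[str, Any], prefix: Tuple[str, ...] = ()) -> Dict[Tuple[str, ...], Any]:
    """Turn nested dicts into one dict keyed by tuples of path components."""
    flat: Dict[Tuple[str, ...], Any] = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, path))  # recurse into sub-trees
        else:
            flat[path] = value                 # leaf: record its full path
    return flat


fns = {"params": "shard_fn_for_params"}  # a string stands in for a callable
padded = pad_to_depth(fns, ["row"])
print(padded)           # {'row': {'params': 'shard_fn_for_params'}}
print(flatten(padded))  # {('row', 'params'): 'shard_fn_for_params'}
```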

static from_params(params, partition_rules, mesh)
classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, flatten: bool = True, config_kwargs: Optional[Mapping[str, Any]] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, trust_remote_code: bool = False) → Tuple[Mapping[str, Callable], Mapping[str, Callable]]

Generates shard and gather functions based on a pretrained model name or path.

Parameters
  • pretrained_model_name_or_path – The name or path of the pretrained model.

  • sharding_axis_dims – The dimensions of the sharding axes. Defaults to (1, -1, 1, 1).

  • sharding_axis_names – The names of the sharding axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).

  • partition_axis (PartitionAxis) – The PartitionAxis used to partition arrays in EasyDeL. Defaults to None.

  • shard_attention_computation – Whether to shard the attention computation. Defaults to True.

  • backend – The backend to use for custom kernels. Defaults to None.

  • partition_rules – A tuple of (parameter-name pattern, PartitionSpec) pairs. If None, the default partition rules from the config are used.

  • flatten – Whether to flatten the shard and gather functions. Defaults to True.

  • config_kwargs – Additional keyword arguments passed to AutoEasyDeLConfig when loading the configuration. Defaults to None.

  • model_task (TaskType) – The task type used to find and load the model. Defaults to TaskType.CAUSAL_LM.

  • from_torch (bool) – Whether the configuration should be loaded from a PyTorch model. Defaults to False.

  • trust_remote_code (bool) – Whether to trust remote code loaded from the Hugging Face Hub. Defaults to False.

Returns

A tuple (shard_fns, gather_fns) of mappings from parameter names to the corresponding shard and gather functions.
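The partition_rules argument pairs parameter-name patterns with PartitionSpec objects. A common way to apply such rules, and the assumption behind this sketch, is to test each regular expression against the “/”-joined parameter path and take the first match, with a “.*” catch-all at the end. The rule patterns and parameter names below are made up for illustration, and plain tuples stand in for PartitionSpec so the sketch runs without JAX.

```python
import re
from typing import Dict, Optional, Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec: mesh-axis names, None = not sharded.
Spec = Tuple[Optional[str], ...]


def match_partition_rules(
    rules: Sequence[Tuple[str, Spec]], param_names: Sequence[str]
) -> Dict[str, Spec]:
    """Assign each parameter the spec of the first rule whose pattern matches its name."""
    specs: Dict[str, Spec] = {}
    for name in param_names:
        for pattern, spec in rules:
            if re.search(pattern, name):
                specs[name] = spec
                break  # first matching rule wins
        else:
            raise ValueError(f"no partition rule matches {name!r}")
    return specs


rules = (
    ("embed_tokens/embedding", ("tp", "fsdp")),
    ("self_attn/(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (".*", (None,)),  # catch-all: replicate everything else
)
names = [
    "model/embed_tokens/embedding",
    "model/layers/0/self_attn/q_proj/kernel",
    "model/norm/kernel",
]
for name, spec in match_partition_rules(rules, names).items():
    print(name, spec)
# model/embed_tokens/embedding ('tp', 'fsdp')
# model/layers/0/self_attn/q_proj/kernel ('fsdp', 'tp')
# model/norm/kernel (None,)
```

Ordering matters under this scheme: specific patterns must come before the catch-all, otherwise “.*” would claim every parameter.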

easydel.modules.auto.auto_configuration.get_modules_by_type(model_type: str, task_type: TaskType) → Tuple[Type[EasyDeLBaseConfig], Union[Type[EasyDeLBaseModule], Any]]
The get_modules_by_type function returns a tuple of two classes for the given model type and task type:
  1. The config class for the model type specified (e.g., LlamaConfig, FalconConfig)

  2. The EasyDeL Model class for the model type specified (e.g., FlaxLlamaForCausalLM, FalconForCausalLM)
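The function behaves like a lookup keyed by model type and task. The sketch below assumes a simple dictionary registry; the registry, the placeholder classes, the lookup_modules name and the TaskType stand-in are all hypothetical, and EasyDeL's actual registration mechanism may differ.

```python
import enum
from typing import Dict, Tuple


class TaskType(enum.Enum):
    """Hypothetical stand-in for EasyDeL's TaskType enum."""
    CAUSAL_LM = enum.auto()
    SEQUENCE_CLASSIFICATION = enum.auto()


class LlamaConfig:
    """Placeholder config class."""


class LlamaForCausalLM:
    """Placeholder model class."""


# Hypothetical registry keyed by (model_type, task_type).
_REGISTRY: Dict[Tuple[str, TaskType], Tuple[type, type]] = {
    ("llama", TaskType.CAUSAL_LM): (LlamaConfig, LlamaForCausalLM),
}


def lookup_modules(model_type: str, task_type: TaskType) -> Tuple[type, type]:
    """Return (config_class, model_class) for a model type and task."""
    try:
        return _REGISTRY[(model_type, task_type)]
    except KeyError:
        raise ValueError(
            f"no module registered for model_type={model_type!r}, task={task_type.name}"
        ) from None


config_cls, model_cls = lookup_modules("llama", TaskType.CAUSAL_LM)
print(config_cls.__name__, model_cls.__name__)  # LlamaConfig LlamaForCausalLM
```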

easydel.modules.auto.auto_configuration.is_flatten(pytree: dict)
The is_flatten function checks if the pytree is flattened.

A flattened pytree has tuple keys, such as ('params', 'dense', 'kernel'), so the function checks whether the first key of the dictionary is a tuple. A nested pytree has non-tuple (typically string) keys at the top level.

Parameters

pytree (dict) – The pytree to check.

Returns

True if the pytree is flattened, False otherwise.
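The check described above can be sketched as follows. This is a minimal re-implementation under the stated first-key rule, named is_flattened to keep it distinct from the library function; treating an empty dict as not flattened is an added assumption.

```python
from typing import Any, Dict


def is_flattened(pytree: Dict[Any, Any]) -> bool:
    """Return True if the dict's first key is a tuple, i.e. the tree is flattened."""
    if not pytree:
        return False  # assumption: an empty dict counts as not flattened
    first_key = next(iter(pytree))
    return isinstance(first_key, tuple)


nested = {"params": {"dense": {"kernel": [[1.0]]}}}
flat = {("params", "dense", "kernel"): [[1.0]]}
print(is_flattened(nested), is_flattened(flat))  # False True
```

Inspecting only the first key is cheap, but it relies on the dict being consistently flattened or consistently nested.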