easydel.modules.auto.auto_configuration
- class easydel.modules.auto.auto_configuration.AutoEasyDeLConfig
Bases: object
- classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, **kwargs) → EasyDeLBaseConfig
The from_pretrained function is a helper that instantiates the EasyDeL configuration of a pretrained model. It takes the name or path of the model (e.g., "bert-base-uncased") and returns an instance of the EasyDeLBaseConfig subclass that corresponds to that model.
- Parameters
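cls: The class on which this classmethod is called.
pretrained_model_name_or_path (str): Identifies the model on the Hugging Face model hub, or a local path to it.
sharding_axis_dims (Sequence[int]): The size of each axis used to shard arrays in EasyDeL. Defaults to (1, -1, 1, 1).
sharding_axis_names (Sequence[str]): The names of the sharding axes. Defaults to ('dp', 'fsdp', 'tp', 'sp').
partition_axis (PartitionAxis): Describes how arrays are partitioned in EasyDeL. Defaults to None.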
shard_attention_computation (bool): Whether to use shard_map for the attention computation. Defaults to True.
backend (Optional[EasyDeLBackends]): The backend to use for the model. Defaults to None.
model_task (TaskType): The task type of the model to load. Defaults to TaskType.CAUSAL_LM.
from_torch (bool): Whether the config should be loaded from a PyTorch model. Defaults to False.
**kwargs: Additional arguments passed to the model and config classes.
- Returns
An EasyDeLBaseConfig instance for the model.
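The default sharding_axis_dims (1, -1, 1, 1) appears to pair positionally with the default sharding_axis_names ('dp', 'fsdp', 'tp', 'sp'). The sketch below shows how such a tuple could resolve to a concrete mesh shape. It assumes that -1 is inferred from the device count, the way NumPy's reshape infers a dimension. resolve_axis_dims is a hypothetical helper and not part of EasyDeL.

```python
from math import prod
from typing import Dict, Sequence


def resolve_axis_dims(
    axis_dims: Sequence[int],
    axis_names: Sequence[str],
    device_count: int,
) -> Dict[str, int]:
    """Hypothetical helper: map each mesh axis name to a concrete size.

    Assumption: at most one entry is -1, and it absorbs every device not
    claimed by the other axes (mirroring NumPy reshape semantics).
    """
    if len(axis_dims) != len(axis_names):
        raise ValueError("axis_dims and axis_names must have the same length")
    if list(axis_dims).count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    # Product of the explicitly sized axes (math.prod of nothing is 1).
    known = prod(d for d in axis_dims if d != -1)
    if device_count % known != 0:
        raise ValueError(f"{device_count} devices cannot be split by {known}")
    sizes = tuple(device_count // known if d == -1 else d for d in axis_dims)
    # Without a -1 entry, the explicit sizes must cover every device exactly.
    if prod(sizes) != device_count:
        raise ValueError(f"mesh shape {sizes} does not cover {device_count} devices")
    return dict(zip(axis_names, sizes))


# With the defaults on 8 devices, the 'fsdp' axis absorbs all of them.
print(resolve_axis_dims((1, -1, 1, 1), ("dp", "fsdp", "tp", "sp"), 8))
```

Under this assumption, the defaults place every device on the 'fsdp' axis. Changing the tuple to, for example, (2, -1, 2, 1) would split 8 devices into a 2 × 2 × 2 × 1 mesh.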
- class easydel.modules.auto.auto_configuration.AutoShardAndGatherFunctions
Bases: object
A class to automatically generate shard and gather functions for a given model configuration.
This class provides two methods to generate shard and gather functions:
from_config: Generates functions based on a provided EasyDeLBaseConfig object.
from_pretrained: Generates functions based on a pretrained model name or path.
- classmethod from_config(config: EasyDeLBaseConfig, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, flatten: bool = True, model_task: TaskType = TaskType.CAUSAL_LM, depth_target: Optional[List[str]] = None)
Generates shard and gather functions based on a provided EasyDeLBaseConfig object.
- Parameters
config (EasyDeLBaseConfig): The model configuration.
partition_rules: A tuple of (name pattern, PartitionSpec) pairs that assign partition specs to parameters. If None, the default partition rules from the config are used.
flatten (bool): Whether to flatten the shard and gather functions. Defaults to True.
model_task (TaskType): The task type of the model to load. Defaults to TaskType.CAUSAL_LM.
depth_target (Optional[List[str]]): Keys under which to nest the sharding tree. For example, with depth_target=["row"], {params: tensor} becomes {row: {params: tensor}}. Defaults to None.
- Returns
A tuple containing the shard and gather functions.
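The depth_target example amounts to wrapping the function tree in one nested dictionary per listed key. The sketch below performs that transformation on plain dictionaries. nest_under is a hypothetical name. The documented example only uses a single key, so it is an assumption that, with several keys, the first key becomes the outermost level.

```python
from typing import Any, Dict, Optional, Sequence


def nest_under(tree: Dict[str, Any], depth_target: Optional[Sequence[str]]) -> Dict[str, Any]:
    """Hypothetical helper: wrap `tree` under each key in `depth_target`.

    Assumption: the first key becomes the outermost level.
    With depth_target None or empty, the tree is returned unchanged.
    """
    if not depth_target:
        return tree
    # Wrap from the innermost key outward so depth_target[0] ends up on top.
    for key in reversed(depth_target):
        tree = {key: tree}
    return tree


print(nest_under({"params": "tensor"}, ["row"]))
```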
- classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, flatten: bool = True, config_kwargs: Optional[Mapping[str, Any]] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, trust_remote_code: bool = False) → Tuple[Mapping[str, Callable], Mapping[str, Callable]]
Generates shard and gather functions based on a pretrained model name or path.
- Parameters
pretrained_model_name_or_path (str): The name or path of the pretrained model.
sharding_axis_dims (Sequence[int]): The sizes of the sharding axes. Defaults to (1, -1, 1, 1).
sharding_axis_names (Sequence[str]): The names of the sharding axes. Defaults to ('dp', 'fsdp', 'tp', 'sp').
partition_axis (PartitionAxis): Describes how arrays are partitioned in EasyDeL. Defaults to None.
shard_attention_computation (bool): Whether to shard the attention computation. Defaults to True.
backend (Optional[EasyDeLBackends]): The backend to use for custom kernels. Defaults to None.
partition_rules: A tuple of (name pattern, PartitionSpec) pairs that assign partition specs to parameters. If None, the default partition rules from the config are used.
flatten (bool): Whether to flatten the shard and gather functions. Defaults to True.
config_kwargs (Optional[Mapping[str, Any]]): Additional keyword arguments passed to AutoEasyDeLConfig when loading the config. Defaults to None.
model_task (TaskType): The task type of the model to load. Defaults to TaskType.CAUSAL_LM.
from_torch (bool): Whether the config should be loaded from a PyTorch model. Defaults to False.
trust_remote_code (bool): Whether to trust remote code loaded from the Hugging Face Hub. Defaults to False.
- Returns
A tuple containing the shard and gather functions.
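partition_rules and flatten work together. Each rule pairs a name pattern with a PartitionSpec, and a flattened parameter tree exposes each parameter's full path as a single key. The sketch below shows rule matching under several assumptions:

- Patterns are regular expressions searched against the '/'-joined parameter path.
- The first matching rule wins.
- Specs are plain tuples standing in for jax.sharding.PartitionSpec, so the example runs without JAX.

flatten_tree and match_partition_rules are hypothetical helpers. The rule patterns and parameter names are illustrative only.

```python
import re
from typing import Any, Dict, Sequence, Tuple

Spec = Tuple[Any, ...]  # stand-in for jax.sharding.PartitionSpec (assumption)


def flatten_tree(tree: Dict[str, Any], prefix: Tuple[str, ...] = ()) -> Dict[Tuple[str, ...], Any]:
    """Flatten nested dicts into {path_tuple: leaf}, similar to flax.traverse_util.flatten_dict."""
    flat: Dict[Tuple[str, ...], Any] = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten_tree(value, path))
        else:
            flat[path] = value
    return flat


def match_partition_rules(
    rules: Sequence[Tuple[str, Spec]], tree: Dict[str, Any]
) -> Dict[Tuple[str, ...], Spec]:
    """Assign each leaf the spec of the first rule whose regex matches its path."""
    specs: Dict[Tuple[str, ...], Spec] = {}
    for path in flatten_tree(tree):
        name = "/".join(path)
        for pattern, spec in rules:
            if re.search(pattern, name):
                specs[path] = spec
                break
        else:
            raise ValueError(f"no partition rule matches {name!r}")
    return specs


rules = (
    ("embed_tokens/embedding", ("tp", "fsdp")),
    ("self_attn/.*/kernel", ("fsdp", "tp")),
    (".*", ()),  # catch-all: an empty spec means the array is replicated
)
params = {
    "model": {
        "embed_tokens": {"embedding": "E"},
        "layers": {"0": {"self_attn": {"q_proj": {"kernel": "W"}}}},
        "norm": {"scale": "s"},
    }
}
for path, spec in match_partition_rules(rules, params).items():
    print("/".join(path), spec)
```

Putting a catch-all pattern last keeps unmatched parameters from raising an error. Rule order matters because matching stops at the first hit.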
- easydel.modules.auto.auto_configuration.get_modules_by_type(model_type: str, task_type: TaskType) → Tuple[Type[EasyDeLBaseConfig], Union[Type[EasyDeLBaseModule], Any], functools.partial | Any]
- The get_modules_by_type function is a helper that returns a tuple of:
The config class for the specified model type (e.g., LlamaConfig, FalconConfig)
The model class for the specified model type and task (e.g., FlaxLlamaForCausalLM, FalconForCausalLM)
A function that converts a Hugging Face pretrained checkpoint into an EasyDeL checkpoint
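One way to organize such a lookup is a registry keyed by (model_type, task_type). The sketch below is hypothetical throughout:

- The registry and the placeholder classes stand in for EasyDeL's real config and model classes.
- convert_checkpoint and its key renaming stand in for the real converters.
- The string "CAUSAL_LM" stands in for TaskType.CAUSAL_LM.

It only illustrates the shape of the returned triple. It also suggests why a converter might be a functools.partial: model-specific arguments can be bound ahead of time.

```python
import functools
from typing import Any, Callable, Dict, Tuple, Type


class LlamaConfig:  # placeholder, not the real config class
    pass


class LlamaForCausalLM:  # placeholder, not the real model class
    pass


def convert_checkpoint(
    state_dict: Dict[str, Any], *, config_class: Type, rename: Dict[str, str]
) -> Dict[str, Any]:
    """Placeholder converter: only renames keys (real converters do far more)."""
    return {rename.get(k, k): v for k, v in state_dict.items()}


# Hypothetical registry keyed by (model_type, task_type).
_REGISTRY: Dict[Tuple[str, str], Tuple[Type, Type, Callable[..., Dict[str, Any]]]] = {
    ("llama", "CAUSAL_LM"): (
        LlamaConfig,
        LlamaForCausalLM,
        # Bind model-specific arguments now; callers only pass the checkpoint.
        functools.partial(
            convert_checkpoint,
            config_class=LlamaConfig,
            rename={"lm_head.weight": "lm_head.kernel"},
        ),
    ),
}


def get_modules_by_type_sketch(model_type: str, task_type: str):
    """Return (config class, model class, converter) for a model/task pair."""
    try:
        return _REGISTRY[(model_type, task_type)]
    except KeyError:
        raise ValueError(f"unsupported model/task pair: {model_type!r}, {task_type!r}") from None


config_cls, model_cls, convert = get_modules_by_type_sketch("llama", "CAUSAL_LM")
print(config_cls.__name__, model_cls.__name__, convert({"lm_head.weight": 0}))
```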
- easydel.modules.auto.auto_configuration.is_flatten(pytree: dict)
- The is_flatten function checks whether the pytree is flattened.
If it is, the first key in the dictionary is a tuple of path components (e.g., ("params", "dense", "kernel")). Otherwise, the first key is not a tuple, as in a nested dictionary keyed by strings.
- Parameters
pytree (dict): The pytree to check.
- Returns
True if the pytree is flattened, and False otherwise.
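The check described above only needs to inspect the first key. The sketch below implements that reading, assuming flattened trees use tuple path keys as produced by flax.traverse_util.flatten_dict. is_flatten_sketch is a hypothetical stand-in. Treating an empty dict as not flattened is our choice, since the documented function does not specify it.

```python
from typing import Any, Dict


def is_flatten_sketch(pytree: Dict[Any, Any]) -> bool:
    """Return True when the dict's first key is a tuple path (a flattened tree).

    An empty dict is treated as not flattened (assumption).
    """
    first_key = next(iter(pytree), None)
    return isinstance(first_key, tuple)


nested = {"model": {"dense": {"kernel": 1.0}}}
flat = {("model", "dense", "kernel"): 1.0}
print(is_flatten_sketch(nested), is_flatten_sketch(flat))
```

Inspecting only the first key is cheap but assumes the tree is uniform. A dict that mixes tuple and string keys would be classified by whichever key happens to come first.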