easydel.modules.auto.auto_configuration
- class easydel.modules.auto.auto_configuration.AutoEasyDeLConfig
Bases: object
Factory helpers to load EasyDeL configs from identifiers or checkpoints.
- classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp'), partition_axis: eformer.escale.partition.manager.PartitionAxis | None = None, backend: easydel.infra.etils.EasyDeLBackends | None = None, platform: easydel.infra.etils.EasyDeLPlatforms | None = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, **kwargs) -> EasyDeLBaseConfig
The from_pretrained function is a helper that builds a model configuration from a pretrained model identifier or checkpoint path (e.g., 'bert-base-uncased'). It returns the configuration object for that model (an EasyDeLBaseConfig), not the model or its weights.
- Parameters
cls – Create an instance of the class that called this function.
pretrained_model_name_or_path – str: Identifies the model on the Hugging Face model hub.
sharding_axis_dims – tp.Sequence[int]: Specifies the size of each sharding axis used for sharded arrays in easydel.
backend – tp.Optional[EasyDeLBackends]: Backend to use for the model.
model_task – TaskType: Task type used to find and load the model.
from_torch – Whether the config should be loaded from a torch model.
**kwargs – Additional arguments passed to the model and config classes.
- Returns
A Model Config
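The default sharding_axis_dims of (1, -1, 1, 1, 1) pairs with the axis names ('dp', 'fsdp', 'ep', 'tp', 'sp'). The sketch below illustrates one plausible reading of the -1 entry: it is assumed to act like the wildcard in an array reshape and absorb all devices not claimed by the other axes. The helper resolve_axis_dims is hypothetical and not part of EasyDeL.

```python
from __future__ import annotations

import math
from typing import Sequence


def resolve_axis_dims(axis_dims: Sequence[int], device_count: int) -> tuple[int, ...]:
    """Replace a single -1 entry with the number of devices left over.

    Assumption: -1 behaves like the wildcard dimension of an array reshape.
    """
    if list(axis_dims).count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    # Product of the explicitly sized axes (1 if there are none).
    known = math.prod(d for d in axis_dims if d != -1)
    if known <= 0 or device_count % known != 0:
        raise ValueError(f"{device_count} devices cannot be split as {tuple(axis_dims)}")
    resolved = tuple(device_count // known if d == -1 else d for d in axis_dims)
    if math.prod(resolved) != device_count:
        raise ValueError(f"{tuple(axis_dims)} does not cover {device_count} devices")
    return resolved


axis_names = ("dp", "fsdp", "ep", "tp", "sp")
mesh_shape = dict(zip(axis_names, resolve_axis_dims((1, -1, 1, 1, 1), device_count=8)))
print(mesh_shape)  # {'dp': 1, 'fsdp': 8, 'ep': 1, 'tp': 1, 'sp': 1}
```

Under this assumption, the default places every available device on the fsdp axis.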
- class easydel.modules.auto.auto_configuration.AutoShardAndGatherFunctions
Bases: object
A class to automatically generate shard and gather functions for a given model configuration.
This class provides two methods to generate shard and gather functions:
from_config: Generates functions based on a provided EasyDeLBaseConfig object.
from_pretrained: Generates functions based on a pretrained model name or path.
- classmethod from_config(config: EasyDeLBaseConfig, partition_rules: tuple[tuple[str, jax.sharding.PartitionSpec]] | None = None, flatten: bool = True, model_task: TaskType = TaskType.CAUSAL_LM, depth_target: list[str] | None = None)
Generates shard and gather functions based on a provided EasyDeLBaseConfig object.
- Parameters
config – An EasyDeLBaseConfig object containing the model configuration.
partition_rules – A tuple of tuples containing partition rule names and PartitionSpec objects. If None, uses the default partition rules from the config.
flatten – Whether to flatten the shard and gather functions. Defaults to True.
model_task – TaskType: Task type used to find and load the model.
depth_target – Keys under which to nest the sharding tree so that it reaches a target depth. For example, depth_target = ["row"] turns {params: tensor} into {row: {params: tensor}}. Defaults to None.
- Returns
A tuple containing the shard and gather functions.
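The depth_target padding described above can be sketched in a few lines. This is a minimal illustration, not EasyDeL's implementation: pad_to_depth is a hypothetical helper, and the choice to treat the first entry of depth_target as the outermost key is an assumption.

```python
from __future__ import annotations

from typing import Any


def pad_to_depth(tree: dict[str, Any], depth_target: list[str] | None) -> dict[str, Any]:
    """Nest `tree` under each key in `depth_target`.

    Assumption: the first key in `depth_target` becomes the outermost level.
    """
    if not depth_target:
        return tree
    # Wrap from the innermost key outward so the first key ends up on top.
    for key in reversed(depth_target):
        tree = {key: tree}
    return tree


padded = pad_to_depth({"params": "tensor"}, ["row"])
print(padded)  # {'row': {'params': 'tensor'}}
```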
- static from_params(params, partition_rules, mesh)
Generates shard and gather functions directly from model parameters, partition rules, and a mesh.
- Parameters
params – The model parameters (pytree) to generate functions for.
partition_rules – A tuple of tuples defining the partitioning strategy.
mesh – The JAX device mesh to use for sharding.
- Returns
A tuple containing the shard and gather functions.
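The partition_rules argument pairs rule names with PartitionSpec objects. One common convention, assumed here rather than confirmed by this page, treats each rule name as a regular expression that is searched in the slash-joined parameter path, with the first matching rule winning. In the sketch below, match_partition_rules is a hypothetical helper, plain tuples stand in for jax.sharding.PartitionSpec so the example runs without JAX, and the parameter names are made up for illustration.

```python
from __future__ import annotations

import re
from typing import Any, Mapping, Sequence


def match_partition_rules(
    rules: Sequence[tuple[str, Any]],
    flat_params: Mapping[tuple[str, ...], Any],
) -> dict[tuple[str, ...], Any]:
    """Assign each flattened parameter path the spec of the first matching rule."""
    specs = {}
    for path in flat_params:
        name = "/".join(path)
        for pattern, spec in rules:
            if re.search(pattern, name):
                specs[path] = spec
                break
        else:
            # No rule matched: fail loudly instead of guessing a layout.
            raise ValueError(f"no partition rule matches {name!r}")
    return specs


# Tuples stand in for PartitionSpec objects (assumption for illustration).
rules = (
    (r"embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    (r"self_attn/.*_proj/kernel", (("fsdp", "sp"), "tp")),
    (r".*", ()),  # catch-all rule: leave the array replicated
)
flat_params = {
    ("model", "embed_tokens", "embedding"): None,
    ("model", "layers", "0", "self_attn", "q_proj", "kernel"): None,
    ("model", "norm", "kernel"): None,
}
specs = match_partition_rules(rules, flat_params)
for path, spec in specs.items():
    print("/".join(path), "->", spec)
```

Ordering the rules from most to least specific matters under a first-match policy, which is why the catch-all pattern sits last.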
- classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp'), partition_axis: eformer.escale.partition.manager.PartitionAxis | None = None, backend: easydel.infra.etils.EasyDeLBackends | None = None, platform: easydel.infra.etils.EasyDeLPlatforms | None = None, partition_rules: tuple[tuple[str, jax.sharding.PartitionSpec]] | None = None, flatten: bool = True, config_kwargs: Optional[Mapping[str, Any]] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, trust_remote_code: bool = False) -> tuple[Mapping[str, Callable], Mapping[str, Callable]]
Generates shard and gather functions based on a pretrained model name or path.
- Parameters
pretrained_model_name_or_path – The name or path of the pretrained model.
sharding_axis_dims – The dimensions of the sharding axes. Defaults to (1, -1, 1, 1, 1).
sharding_axis_names – The names of the sharding axes. Defaults to (“dp”, “fsdp”, “ep”, “tp”, “sp”).
partition_axis (PartitionAxis) – PartitionAxis is a new module used for partitioning arrays in easydel.
backend – The backend to use for custom kernels. Defaults to None.
partition_rules – A tuple of tuples containing partition rule names and PartitionSpec objects. If None, uses the default partition rules from the config.
flatten – Whether to flatten the shard and gather functions. Defaults to True.
config_kwargs – Additional keyword arguments to pass to the AutoEasyDeLConfig constructor. Defaults to None.
model_task – TaskType: Task type used to find and load the model.
from_torch – Whether the config should be loaded from a torch model.
trust_remote_code (bool) – Whether to trust remote code loaded from HF.
- Returns
A tuple containing the shard and gather functions.
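The return annotation indicates two mappings from names to callables: one for sharding and one for gathering. A minimal sketch of how such a pair might be applied follows. It assumes that, with flatten=True, both mappings share their keys with a flattened parameter dictionary. The helper apply_fns and the stand-in lambdas are hypothetical and only mimic the shape of the returned objects; real shard and gather functions would move arrays across devices.

```python
from __future__ import annotations

from typing import Any, Callable, Mapping


def apply_fns(
    fns: Mapping[Any, Callable[[Any], Any]],
    flat_params: Mapping[Any, Any],
) -> dict[Any, Any]:
    """Apply the function stored under each key to the parameter with the same key."""
    return {key: fns[key](value) for key, value in flat_params.items()}


# Hypothetical stand-ins: "sharding" tags a value, "gathering" removes the tag.
keys = [("dense", "kernel"), ("dense", "bias")]
shard_fns = {key: (lambda x: ("sharded", x)) for key in keys}
gather_fns = {key: (lambda x: x[1]) for key in keys}

params = {("dense", "kernel"): [1.0, 2.0], ("dense", "bias"): [0.5]}
sharded = apply_fns(shard_fns, params)
gathered = apply_fns(gather_fns, sharded)
print(gathered == params)  # True
```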
- easydel.modules.auto.auto_configuration.get_modules_by_type(model_type: str, task_type: TaskType) -> tuple[type[easydel.infra.base_config.EasyDeLBaseConfig], type[easydel.infra.base_module.EasyDeLBaseModule] | Any]
- The get_modules_by_type function is a helper function that returns the following:
The config class for the model type specified (e.g., LlamaConfig, FalconConfig)
The EasyDeL Model class for the model type specified (e.g., FlaxLlamaForCausalLM, FalconForCausalLM)
- easydel.modules.auto.auto_configuration.is_flatten(pytree: dict)
- The is_flatten function checks if the pytree is flattened.
If it is, then the first key in the dictionary will be a tuple of (mpl, mpl_id). Otherwise, it will be an integer representing mpl_id.
- Parameters
pytree – dict: Pass the pytree to the function
- Returns
True if the pytree is a flattened tree, and False otherwise
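The check described above inspects the first key of the dictionary. A minimal sketch of that idea, with a small flattener for comparison, is shown below. Both flatten_dict and looks_flattened are hypothetical helpers written for this illustration, and returning False for an empty dictionary is an assumption, not documented behavior of is_flatten.

```python
from __future__ import annotations

from typing import Any


def flatten_dict(tree: dict[str, Any], prefix: tuple[str, ...] = ()) -> dict[tuple[str, ...], Any]:
    """Turn a nested dict into a flat dict keyed by tuples of path components."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict) and value:
            flat.update(flatten_dict(value, path))
        else:
            flat[path] = value
    return flat


def looks_flattened(pytree: dict) -> bool:
    """Return True when the first key is a tuple, i.e. the dict is already flat."""
    # Assumption: an empty dict is reported as not flattened.
    return bool(pytree) and isinstance(next(iter(pytree)), tuple)


nested = {"dense": {"kernel": 1, "bias": 2}}
flat = flatten_dict(nested)
print(looks_flattened(nested))  # False
print(looks_flattened(flat))    # True
```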