easydel.infra.__init__
- class easydel.infra.__init__.EasyDeLBaseConfig(axis_dims: ~typing.Sequence[int] = (1, -1, 1, 1), dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = None, axis_names: ~typing.Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa', 'autoregressive_decodeattn'] = 'vanilla', decode_attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa', 'autoregressive_decodeattn'] = None, blocksize_k: int = 128, blocksize_q: int = 128, blocksize_b: int = 1, partition_axis: ~eformer.escale.partition.manager.PartitionAxis = PartitionAxis(data_parallel_axis='dp', fully_sharded_data_parallel_axis='fsdp', tensor_parallel_axis='tp', sequence_parallel_axis='sp', expert_parallel_axis='ep', batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', kv_head_axis=None, key_sequence_axis='sp', hidden_state_axis='tp', mlp_intermediate_axis='tp', vocab_axis='tp', expert_axis='ep', expert_gate_axis=None, attention_dim_axis=None, attention_kv_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, decode_batch_axis=('fsdp', 'dp'), decode_query_sequence_axis=None, decode_head_axis='tp', decode_kv_head_axis=None, decode_key_sequence_axis='sp', decode_attention_dim_axis=None, decode_attention_kv_dim_axis=None), shard_attention_computation: bool = True, use_sharded_kv_caching: bool = False, use_sharding_constraint: bool = False, backend: ~typing.Optional[~easydel.infra.etils.EasyDeLBackends] = None, platform: ~typing.Optional[~easydel.infra.etils.EasyDeLPlatforms] = None, easy_method: ~typing.Literal['train', 'serve', 'convert'] = 'train', bits: ~typing.Optional[int] = None, scan_ring_attention: bool = True, scan_attention_layers: bool = False, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, sequence_axis_name: str = 'sp', gradient_checkpointing: ~easydel.infra.etils.EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, precompute_masks: bool = True, kv_cache_quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, kv_cache_quantization_blocksize: int = 64, quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, quantization_pattern: str = '.*', quantization_blocksize: int = 64, kv_cache_sharding_sequence_axis_name: ~typing.Union[str, ~typing.Tuple[str, ...]] = 'sp', flash_attention_backward_pass_impl: ~typing.Literal['triton', 'xla'] = 'triton', attn_dtype: ~numpy.dtype = jax.numpy.float32, attn_softmax_dtype: ~numpy.dtype = jax.numpy.float32, fcm_max_ratio: float = 0.0, fcm_min_ratio: float = 0.0, hardware_abstraction: bool = False, pallas_m_block_size: int = 128, pallas_k_block_size: int = 128, pallas_n_block_size: int = 128, **kwargs)
Bases: PretrainedConfig
Initialize the configuration for EasyDeL.
- Parameters
axis_dims (tp.Sequence[int]) – Size of each mesh axis; a single -1 entry is inferred from the number of available devices. Default is (1, -1, 1, 1).
axis_names (tp.Sequence[str]) – Names of the mesh axes. Default is (“dp”, “fsdp”, “tp”, “sp”).
attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS) – Attention mechanism to use. Default is “vanilla”.
decode_attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS) – Attention mechanism to use for the decode phase. Default is None.
blocksize_k (int) – Block size for key. Default is 128.
blocksize_q (int) – Block size for query. Default is 128.
blocksize_b (int) – Block size for batch. Default is 1.
partition_axis (PartitionAxis) – Partition axis configuration. Default is PartitionAxis().
shard_attention_computation (bool) – Whether to shard attention computation. Default is True.
use_sharded_kv_caching (bool) – Whether to use sharded key-value caching. Default is False.
use_sharding_constraint (bool) – Whether to use sharding constraint. Default is False.
backend (tp.Optional[EasyDeLBackends]) – Backend to use. Default is None.
platform (tp.Optional[EasyDeLPlatforms]) – Platform to use. Default is None.
easy_method (tp.Literal["train", "serve", "convert"]) – Mode the model is prepared for: training, serving, or conversion. Default is “train”.
bits (tp.Optional[int]) – Number of bits for quantization. Default is None.
scan_ring_attention (bool) – Whether to use a scan for ring attention. Default is True.
scan_attention_layers (bool) – Whether to use a scan for attention layers. Default is False.
use_scan_mlp (bool) – Whether to use a scanned MLP. Default is False.
scan_mlp_chunk_size (int) – Chunk size for scan MLP. Default is 1024.
sequence_axis_name (str) – Name of the mesh axis used to shard the sequence dimension in attention. Default is “sp”.
gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing method. Default is EasyDeLGradientCheckPointers.NONE.
kv_cache_quantization_method (EasyDeLQuantizationMethods) – Key-value cache quantization method. Default is EasyDeLQuantizationMethods.NONE.
kv_cache_quantization_blocksize (int) – Block size for key-value cache quantization. Default is 64.
quantization_method (EasyDeLQuantizationMethods) – Quantization method. Default is EasyDeLQuantizationMethods.NONE.
quantization_pattern (str) – Pattern for quantization. Default is “.*”.
quantization_blocksize (int) – Block size for quantization. Default is 64.
kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]]) – Name of the key-value cache sharding sequence axis. Default is “sp”.
flash_attention_backward_pass_impl (tp.Literal["triton", "xla"]) – Implementation for flash attention backward pass. Default is “triton”.
attn_dtype (jnp.dtype) – Data type for attention. Default is jnp.float32.
attn_softmax_dtype (jnp.dtype) – Data type for softmax ops in attention. Default is jnp.float32.
fcm_max_ratio (float) – Maximum ratio for forgetful causal masking (FCM). Default is 0.0.
fcm_min_ratio (float) – Minimum ratio for forgetful causal masking (FCM). Default is 0.0.
hardware_abstraction (bool) – Whether to switch to custom Pallas kernels instead of the standard JAX implementations. Default is False.
pallas_m_block_size (int) – Block size of the M dimension for Pallas matmul kernels. Default is 128.
pallas_k_block_size (int) – Block size of the K dimension for Pallas matmul kernels. Default is 128.
pallas_n_block_size (int) – Block size of the N dimension for Pallas matmul kernels. Default is 128.
**kwargs – Additional keyword arguments.
- Raises
Warning – Issued if kv_cache_quantization_method is not NONE while use_sharded_kv_caching is True.
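The fcm_min_ratio and fcm_max_ratio pair is assumed here to control forgetful causal masking (FCM), in which a random fraction of past key positions is hidden from each sequence during training. The NumPy sketch below illustrates that general idea under this assumption; the function forgetful_causal_mask is hypothetical and is not EasyDeL's implementation. With both ratios at their default of 0.0, no positions are hidden.

```python
import numpy as np


def forgetful_causal_mask(batch_size, seq_len, fcm_min_ratio, fcm_max_ratio, rng):
    """Hypothetical FCM sketch: a causal mask in which each sequence hides a random fraction of keys."""
    # Standard causal mask: query i may attend to keys 0..i.
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    # One drop ratio per sequence, drawn uniformly from [fcm_min_ratio, fcm_max_ratio].
    ratio = rng.uniform(fcm_min_ratio, fcm_max_ratio, size=(batch_size, 1, 1))
    # Keep a key position when its own uniform draw exceeds the sequence's ratio.
    keep = rng.uniform(size=(batch_size, 1, seq_len)) > ratio
    # Always keep the first token so every query can still attend to something.
    keep[:, :, 0] = True
    return causal[None, :, :] & keep  # shape: (batch, query, key)


rng = np.random.default_rng(0)
no_fcm = forgetful_causal_mask(2, 6, 0.0, 0.0, rng)    # ratio 0: no keys are hidden
with_fcm = forgetful_causal_mask(2, 6, 0.2, 0.5, rng)  # hides roughly 20-50% of keys
```

The FCM mask is always a subset of the plain causal mask, so it can only remove attention edges, never add them.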
- add_basic_configurations(axis_dims: ~typing.Sequence[int] = <eformer.common_types._Empty object>, dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = <eformer.common_types._Empty object>, axis_names: ~typing.Sequence[str] = <eformer.common_types._Empty object>, attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa', 'autoregressive_decodeattn'] = <eformer.common_types._Empty object>, decode_attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa', 'autoregressive_decodeattn'] = <eformer.common_types._Empty object>, blocksize_k: int = <eformer.common_types._Empty object>, blocksize_q: int = <eformer.common_types._Empty object>, blocksize_b: int = <eformer.common_types._Empty object>, partition_axis: ~eformer.escale.partition.manager.PartitionAxis = <eformer.common_types._Empty object>, shard_attention_computation: bool = <eformer.common_types._Empty object>, use_sharded_kv_caching: bool = <eformer.common_types._Empty object>, backend: ~typing.Optional[~easydel.infra.etils.EasyDeLBackends] = <eformer.common_types._Empty object>, platform: ~typing.Optional[~easydel.infra.etils.EasyDeLPlatforms] = <eformer.common_types._Empty object>, easy_method: ~typing.Literal['train', 'serve', 'convert'] = <eformer.common_types._Empty object>, bits: ~typing.Optional[int] = <eformer.common_types._Empty object>, scan_ring_attention: bool = <eformer.common_types._Empty object>, scan_attention_layers: bool = <eformer.common_types._Empty object>, use_sharding_constraint: bool = <eformer.common_types._Empty object>, use_scan_mlp: bool = <eformer.common_types._Empty object>, scan_mlp_chunk_size: int = <eformer.common_types._Empty object>, sequence_axis_name: str = <eformer.common_types._Empty object>, gradient_checkpointing: ~easydel.infra.etils.EasyDeLGradientCheckPointers = <eformer.common_types._Empty object>, precompute_masks: bool = <eformer.common_types._Empty object>, 
kv_cache_quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = <eformer.common_types._Empty object>, kv_cache_quantization_blocksize: int = <eformer.common_types._Empty object>, quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = <eformer.common_types._Empty object>, quantization_blocksize: int = <eformer.common_types._Empty object>, quantization_pattern: str = <eformer.common_types._Empty object>, kv_cache_sharding_sequence_axis_name: ~typing.Union[str, ~typing.Tuple[str, ...]] = <eformer.common_types._Empty object>, flash_attention_backward_pass_impl: ~typing.Literal['triton', 'xla'] = <eformer.common_types._Empty object>, attn_dtype: ~numpy.dtype = <eformer.common_types._Empty object>, attn_softmax_dtype: ~numpy.dtype = <eformer.common_types._Empty object>, hardware_abstraction: bool = <eformer.common_types._Empty object>, pallas_m_block_size: int = <eformer.common_types._Empty object>, pallas_k_block_size: int = <eformer.common_types._Empty object>, pallas_n_block_size: int = <eformer.common_types._Empty object>, **kwargs)
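Sets the basic EasyDeL configuration attributes on an existing config object. Only arguments that are passed explicitly are applied; arguments left at the _Empty sentinel default keep the attribute's current value. The defaults listed below are the values an attribute takes when it has not been set yet.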
- Parameters
axis_dims (tp.Sequence[int], optional) – Size of each mesh axis; a single -1 entry is inferred from the number of available devices. Defaults to (1, -1, 1, 1).
axis_names (tp.Sequence[str], optional) – Names of the mesh axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).
attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS, optional) – Attention mechanism to use. Defaults to “vanilla”.
decode_attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS, optional) – Attention mechanism to use for the decode phase. Defaults to None.
blocksize_k (int, optional) – Block size for key states. Defaults to 128.
blocksize_q (int, optional) – Block size for query states. Defaults to 128.
blocksize_b (int, optional) – Block size for bias. Defaults to 1.
partition_axis (PartitionAxis, optional) – Axis layout that EasyDeL uses when partitioning arrays. Defaults to PartitionAxis().
shard_attention_computation (bool, optional) – Whether to use shard_map for attention. Defaults to True.
use_sharded_kv_caching (bool, optional) – Whether to use shard_map and sharding for the key/value cache. Defaults to True.
backend (tp.Optional[EasyDeLBackends], optional) – Backend to use. Defaults to None.
platform (tp.Optional[EasyDeLPlatforms], optional) – Platform to use. Defaults to None.
easy_method (tp.Literal["train", "serve", "convert"], optional) – Mode the model is prepared for: training, serving, or conversion. Defaults to “train”.
bits (tp.Optional[int], optional) – Number of bits for quantization. Defaults to None.
scan_ring_attention (bool, optional) – Whether to use a scan for ring attention. Defaults to True.
scan_attention_layers (bool, optional) – Whether to use a scan for attention layers. Defaults to False.
use_sharding_constraint (bool, optional) – Whether to apply sharding constraints to arrays. Defaults to False.
use_scan_mlp (bool, optional) – Whether to use a scanned MLP. Defaults to False.
scan_mlp_chunk_size (int, optional) – Chunk size for the scanned MLP. Defaults to 1024.
sequence_axis_name (str, optional) – Name of the mesh axis used to shard the sequence dimension in attention. Defaults to “sp”.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing method for the created or loaded module (usually applied to the MLP and attention layers). Defaults to EasyDeLGradientCheckPointers.NONE.
kv_cache_quantization_method (EasyDeLQuantizationMethods, optional) – Quantization type for the key/value cache. Defaults to EasyDeLQuantizationMethods.NONE.
kv_cache_quantization_blocksize (int, optional) – Block size for key/value cache quantization. Defaults to 64.
quantization_method (EasyDeLQuantizationMethods, optional) – Quantization type for linear modules. Defaults to EasyDeLQuantizationMethods.NONE.
quantization_blocksize (int, optional) – Block size for linear-module quantization. Defaults to 64.
quantization_pattern (str) – Regular expression selecting which layers to quantize. Defaults to “.*”.
kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]], optional) – Axis name(s) to target when sharding the key/value cache along the sequence dimension. Defaults to “sp”.
flash_attention_backward_pass_impl (tp.Literal["triton", "xla"], optional) – Specify the backward pass kernel for flash attention. Defaults to “triton”.
attn_dtype (jnp.dtype, optional) – Data type for attention computations. Defaults to jnp.float32.
attn_softmax_dtype (jnp.dtype, optional) – Data type for softmax in attention op computations. Defaults to jnp.float32.
fcm_max_ratio (float, optional) – Maximum ratio for forgetful causal masking (FCM). Defaults to 0.0.
fcm_min_ratio (float, optional) – Minimum ratio for forgetful causal masking (FCM). Defaults to 0.0.
hardware_abstraction (bool, optional) – Whether to switch to custom Pallas kernels instead of the standard JAX implementations. Defaults to False.
pallas_m_block_size (int, optional) – Block size of the m dimension in the Pallas matmul kernel A(m, k) @ B(k, n) = C(m, n). Defaults to 128.
pallas_k_block_size (int, optional) – Block size of the k dimension in the Pallas matmul kernel A(m, k) @ B(k, n) = C(m, n). Defaults to 128.
pallas_n_block_size (int, optional) – Block size of the n dimension in the Pallas matmul kernel A(m, k) @ B(k, n) = C(m, n). Defaults to 128.
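The three pallas_*_block_size values describe how the matmul A(m, k) @ B(k, n) = C(m, n) is tiled. The pure-NumPy sketch below (a hypothetical blocked_matmul helper, not a Pallas kernel) shows how block sizes bm, bk, and bn split the work: each (bm, bn) output tile is accumulated over bk-wide slices of the shared K dimension.

```python
import numpy as np


def blocked_matmul(a, b, bm=128, bk=128, bn=128):
    """Compute C = A @ B one (bm, bn) output tile at a time, accumulating over bk-wide K slices."""
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, "inner dimensions must match"
    c = np.zeros((m, n), dtype=np.result_type(a, b))
    for i in range(0, m, bm):          # tile over rows of A and C (M dimension)
        for j in range(0, n, bn):      # tile over columns of B and C (N dimension)
            acc = np.zeros((min(bm, m - i), min(bn, n - j)), dtype=c.dtype)
            for p in range(0, k, bk):  # walk the shared K dimension
                acc += a[i:i + bm, p:p + bk] @ b[p:p + bk, j:j + bn]
            c[i:i + bm, j:j + bn] = acc
    return c


rng = np.random.default_rng(0)
a = rng.normal(size=(300, 200))
b = rng.normal(size=(200, 150))
c = blocked_matmul(a, b, bm=128, bk=128, bn=128)
```

Edge tiles are simply smaller here, so the block sizes need not divide the matrix shapes in this sketch; real kernels may impose divisibility or alignment constraints.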
- static create_mesh(axis_dims: Sequence[int] = (1, -1, 1, 1), axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), dcn_axis_dims: Optional[Sequence[int]] = None, process_is_granule: bool = False, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = True, backend: Optional[str] = None)
Creates a jax.sharding.Mesh from axis_dims and axis_names that can be used to shard arrays.
- Returns
A jax.sharding.Mesh
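A rough sketch of the kind of mesh create_mesh produces, assuming the devices are simply laid out in jax.devices() order. The real method also handles dcn_axis_dims, process granules, and physical topology, none of which are modeled in this hypothetical make_mesh helper.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def make_mesh(axis_dims=(1, -1, 1, 1), axis_names=("dp", "fsdp", "tp", "sp")):
    """Arrange the visible devices on a named mesh; the single -1 entry absorbs the rest."""
    devices = np.array(jax.devices())
    # NumPy's reshape resolves -1 so the product of axis sizes equals the device count.
    return Mesh(devices.reshape(axis_dims), axis_names)


mesh = make_mesh()
```

With the default axis_dims, every device lands on the "fsdp" axis; on a single CPU device, every axis has size 1.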
- classmethod from_pretrained(pretrained_model_name_or_path: Union[str, PathLike], cache_dir: Optional[Union[str, PathLike]] = None, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = 'main', **kwargs) → PretrainedConfig
Instantiate a PretrainedConfig (or a derived class) from a pretrained model configuration.
- Parameters
pretrained_model_name_or_path (str or os.PathLike) –
This can be either:
a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.
a path to a directory containing a configuration file saved using the PretrainedConfig.save_pretrained method, e.g., ./my_model_directory/.
a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.
cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.
resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
revision (str, optional, defaults to “main”) –
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
Tip: To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".
return_unused_kwargs (bool, optional, defaults to False) –
If False, then this function returns just the final configuration object.
If True, then this function returns a tuple (config, unused_kwargs), where unused_kwargs is a dictionary of the key/value pairs whose keys are not configuration attributes, i.e., the part of kwargs that was not used to update config and is otherwise ignored.
subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.
kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.
- Returns
The configuration object instantiated from this pretrained model.
- Return type
PretrainedConfig
Examples:
>>> from transformers import BertConfig
>>> # We can't instantiate the base class PretrainedConfig directly,
>>> # so the examples use a derived class: BertConfig.
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...     "./test/saved_model/"
... )  # E.g. config (or model) was saved using save_pretrained('./test/saved_model/')
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased",
...     output_attentions=True,
...     foo=False,
...     return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}
- get_axis_dims() → Sequence[int]
Returns the size of each mesh axis (axis_dims).
- Parameters
self – The config instance.
- Returns
The size of each mesh axis.
- get_axis_names() → Sequence[str]
Returns the names of the mesh axes (axis_names).
- Parameters
self – The config instance.
- Returns
The names of all mesh axes.
- get_backend() → str
Returns the backend currently in use. If no backend has been set, it returns the default JAX backend.
- Parameters
self – The config instance.
- Returns
The backend platform.
- get_basic_frequencies(head_size: Optional[int] = None, rotary_dim: Optional[int] = None, base: Optional[float] = None) → Any
Get basic frequencies for rotary embeddings.
- Parameters
head_size – Size of attention heads (defaults to self.head_dim)
rotary_dim – Dimension for rotary embeddings (defaults to head_size)
base – Base value for frequency computation (defaults to self.rope_theta)
- Returns
ModuleCaches instance containing computed frequencies
- get_basic_inv_frequencies(head_size: Optional[int] = None, rotary_dim: Optional[int] = None, base: Optional[float] = None, partial_rotary_factor: float = 1.0) → Any
Get basic inverse frequencies for rotary embeddings.
- Parameters
head_size – Size of attention heads (defaults to self.head_dim)
rotary_dim – Dimension for rotary embeddings (defaults to head_size)
base – Base value for frequency computation (defaults to self.rope_theta)
partial_rotary_factor – Fraction of the head dimension that receives rotary embeddings (defaults to 1.0)
- Returns
ModuleCaches instance containing the computed inverse frequencies
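For reference, the standard RoPE formula is inv_freq[i] = base ** (-2i / rotary_dim) for i in [0, rotary_dim / 2), and the per-position angles are position * inv_freq. The NumPy sketch below computes both; the helper names are hypothetical, and EasyDeL wraps its results in a ModuleCaches object rather than returning bare arrays.

```python
import numpy as np


def rope_inv_frequencies(rotary_dim, base=10000.0):
    """Standard RoPE inverse frequencies: base ** (-2i / rotary_dim), i = 0 .. rotary_dim / 2 - 1."""
    exponents = np.arange(0, rotary_dim, 2, dtype=np.float32) / rotary_dim
    return 1.0 / (base ** exponents)


def rope_frequencies(max_positions, rotary_dim, base=10000.0):
    """cos and sin of the angle table, each of shape (max_positions, rotary_dim // 2)."""
    positions = np.arange(max_positions, dtype=np.float32)
    angles = np.outer(positions, rope_inv_frequencies(rotary_dim, base))
    return np.cos(angles), np.sin(angles)


inv = rope_inv_frequencies(8)
cos, sin = rope_frequencies(16, 8)
```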
- get_basic_rope(dtype: Union[Array, ndarray, bool, number], head_size: int, rotary_dim: Optional[int] = None, is_neox_style: bool = True, base: Optional[float] = None)
Get basic rotary position embeddings.
- Parameters
dtype – Data type for the embeddings
head_size – Size of attention heads
rotary_dim – Dimension for rotary embeddings (defaults to head_size)
is_neox_style – Whether to use GPT-NeoX-style rotation (pairing the first half of the rotary dimensions with the second half) instead of GPT-J-style rotation of adjacent pairs
base – Base value for frequency computation (defaults to self.rope_theta)
- Returns
A function that applies rotary position embeddings
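The is_neox_style flag selects how dimensions are paired for rotation. The sketch below assumes the two common conventions (GPT-NeoX: first half paired with second half; GPT-J: adjacent even/odd pairs) and uses a hypothetical apply_rope helper; it is not the function returned by get_basic_rope.

```python
import numpy as np


def apply_rope(x, cos, sin, is_neox_style=True):
    """Rotate the last axis of x by per-position angles.

    x: (seq_len, rotary_dim); cos, sin: (seq_len, rotary_dim // 2).
    """
    if is_neox_style:
        # NeoX style: element i is paired with element i + rotary_dim // 2.
        x1, x2 = np.split(x, 2, axis=-1)
        return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    # GPT-J style: adjacent elements (0, 1), (2, 3), ... form the pairs.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out


seq_len, rotary_dim = 4, 8
inv_freq = 1.0 / (10000.0 ** (np.arange(0, rotary_dim, 2) / rotary_dim))
angles = np.outer(np.arange(seq_len), inv_freq)
x = np.random.default_rng(0).normal(size=(seq_len, rotary_dim))
neox = apply_rope(x, np.cos(angles), np.sin(angles), is_neox_style=True)
gptj = apply_rope(x, np.cos(angles), np.sin(angles), is_neox_style=False)
```

Both styles are pure rotations, so they preserve each vector's norm and leave position 0 unchanged; they differ only in which coordinates are rotated together.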
- get_partition_rules(*args, **kwargs)
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
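Partition rules are typically pairs of a regular expression and a PartitionSpec. The sketch below assumes first-match-wins semantics with re.search on a "/"-joined parameter path; the rule set and the spec_for helper are hypothetical, and real rules are model-specific.

```python
import re

from jax.sharding import PartitionSpec

# Hypothetical rules: (regex over the parameter path, PartitionSpec); the first match wins.
PARTITION_RULES = (
    (r"embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))),
    (r"(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    (r"o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
    (r".*", PartitionSpec()),  # fallback: replicate everything else
)


def spec_for(path, rules=PARTITION_RULES):
    """Return the PartitionSpec of the first rule whose pattern matches the path."""
    for pattern, spec in rules:
        if re.search(pattern, path):
            return spec
    raise ValueError(f"no partition rule matches {path!r}")
```

Keeping a catch-all ".*" rule last guarantees that every parameter receives some spec.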
- property granted_freq_max_position_embedding: int
- property granted_mask_max_position_embedding: int
- property mesh
The mesh property is a helper property that creates a Mesh object from the axis_dims and axis_names attributes of an object, which are assumed to be lists of integers and strings, respectively. The platform attribute is also used if it exists.
- Parameters
self – Refer to the object itself
- Returns
A jax.sharding.Mesh
- property partition_manager: PartitionManager
- read_basics_from_config(config: EasyDeLBaseConfig)
- class easydel.infra.__init__.EasyDeLBaseModule(*args: Any, **kwargs: Any)
Bases: Module, BaseModuleProtocol, EasyBridgeMixin, EasyGenerationMixin
Base class for EasyDeL modules, providing common functionality for model initialization, parameter handling, and integration with the EasyDeL ecosystem.
- apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) → SELF
Applies Low-Rank Adaptation (LoRA) layers to the specified linear layers within the module.
Replaces targeted flax.linen.Dense layers with easydel.layers.lora.LoraLinear layers, initializing the LoRA matrices (A and B).
- Parameters
lora_rank (int) – The rank of the LoRA decomposition.
lora_pattern (tp.Optional[str], optional) – A regular expression to match the names of the Dense layers to apply LoRA to. If None, applies to common attention and MLP layers. Defaults to None.
verbose (bool, optional) – If True, prints information about which layers are being modified. Defaults to False.
rngs (tp.Optional[nn.Rngs], optional) – JAX random number generators for initializing LoRA matrices. If None, default RNGs might be used. Defaults to None.
- Returns
The module instance with LoRA layers applied.
- Return type
SELF
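A minimal NumPy sketch of the standard LoRA forward pass, y = x @ W + (alpha / rank) * (x @ A) @ B, assuming A is randomly initialized and B starts at zero so the adapted layer initially matches the base layer. The LoraLinearSketch class and its alpha scaling are illustrative assumptions; easydel.layers.lora.LoraLinear may initialize or scale differently.

```python
import numpy as np


class LoraLinearSketch:
    """Hypothetical LoRA layer: y = x @ W + (alpha / rank) * (x @ A) @ B, with W frozen."""

    def __init__(self, kernel, rank, alpha=1.0, seed=0):
        in_features, out_features = kernel.shape
        rng = np.random.default_rng(seed)
        self.kernel = kernel                                        # frozen base weight (in, out)
        self.a = rng.normal(scale=0.01, size=(in_features, rank))  # trainable down-projection
        self.b = np.zeros((rank, out_features))                     # trainable up-projection, starts at zero
        self.scale = alpha / rank

    def __call__(self, x):
        return x @ self.kernel + self.scale * (x @ self.a) @ self.b


rng = np.random.default_rng(1)
kernel = rng.normal(size=(16, 8))
layer = LoraLinearSketch(kernel, rank=4, alpha=8.0)
x = rng.normal(size=(2, 16))
y = layer(x)
```

Only rank * (in_features + out_features) values are trained, instead of in_features * out_features for the full weight.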
- property causal_mask: Array
Retrieves or computes the basic causal attention mask from the configuration.
Uses self.config.get_basic_causal_mask() and caches the result.
- Returns
The causal attention mask, potentially cached.
- Return type
jnp.ndarray
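A sketch of a basic causal mask in jax.numpy, assuming a (1, 1, max_length, max_length) layout that broadcasts over batch and heads; the exact shape returned by config.get_basic_causal_mask() is not documented here, and basic_causal_mask is a hypothetical helper. The example also slices the mask to the current sequence length and combines it with a padding mask.

```python
import jax.numpy as jnp


def basic_causal_mask(max_length):
    """Boolean mask of shape (1, 1, max_length, max_length): True where key j <= query i."""
    mask = jnp.tril(jnp.ones((max_length, max_length), dtype=bool))
    # Leading singleton axes broadcast over batch and attention heads.
    return mask[None, None, :, :]


causal = basic_causal_mask(8)

# Combine with a padding mask: one 4-token sequence whose last token is padding.
attention_mask = jnp.array([[1, 1, 1, 0]], dtype=bool)  # (batch, seq)
seq_len = attention_mask.shape[1]
combined = causal[:, :, :seq_len, :seq_len] & attention_mask[:, None, None, :]
```

Computing the mask once for the maximum length and slicing it per call is what makes caching it worthwhile.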
- compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) → Tuple[Any, LossMetrics]
Computes the loss for the model given a batch of inputs and labels.
This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.
- Parameters
labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.
loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.
loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.
**batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).
- Returns
- A tuple containing:
The model’s output (a pytree, typically including logits, hidden states, etc.)
A LossMetrics object containing the calculated loss and potentially other metrics.
- Return type
tp.Tuple[tp.Any, LossMetrics]
- Raises
AssertionError – If labels are required for the loss function but are not provided or inferred.
AssertionError – If sequence classification loss is used without num_labels in the config.
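The causal-LM case can be illustrated with a standard next-token cross-entropy. This NumPy sketch assumes labels are shifted by one position and that -100 marks ignored tokens (a common convention); the causal_lm_loss helper is hypothetical and may differ from EasyDeL's ForCausalLMLoss in normalization and other details.

```python
import numpy as np


def causal_lm_loss(logits, labels, ignore_index=-100):
    """Mean next-token cross-entropy: the logits at position t are scored against labels[t + 1]."""
    logits = logits[:, :-1, :]
    targets = labels[:, 1:]
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    valid = targets != ignore_index
    safe_targets = np.where(valid, targets, 0)  # any in-range index; masked out below
    token_nll = -np.take_along_axis(log_probs, safe_targets[..., None], axis=-1)[..., 0]
    return (token_nll * valid).sum() / max(valid.sum(), 1)


rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=(2, 5))
uniform_logits = np.zeros((2, 5, 10))  # every token equally likely
loss = causal_lm_loss(uniform_logits, labels)
```

With uniform logits over a 10-token vocabulary, every prediction has probability 1/10, so the loss equals log(10).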
- float(change_runtime_dtype: bool = True) → SELF
Converts the module’s parameters to single-precision (float32).
Optionally also changes the runtime computation dtype (self.dtype) to float32.
- Parameters
change_runtime_dtype (bool) – If True, also sets self.dtype to jnp.float32. Defaults to True.
- Returns
The module instance with parameters (and potentially runtime dtype) set to float32.
- Return type
SELF
- property frequencies: Array
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
- fully_gather() → SELF
Applies JAX sharding constraints that gather every parameter into a fully replicated layout, so that each device holds a complete copy.
This function marks all parameters as unsharded (PartitionSpec()) and uses jax.jit with out_shardings to enforce these constraints.
- Returns
The model instance with gathering constraints applied.
- Return type
SELF
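The gathering constraint can be reproduced with plain JAX: an empty PartitionSpec() inside a NamedSharding means the array is not split along any mesh axis, so it is replicated on every device. This sketch uses a toy params dict and a one-axis mesh; it illustrates the mechanism, not fully_gather's actual code.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("fsdp",))
params = {"kernel": jnp.ones((8, 4)), "bias": jnp.zeros((4,))}

# An empty PartitionSpec: no dimension is split, i.e. fully replicated.
replicated = NamedSharding(mesh, PartitionSpec())
out_shardings = jax.tree_util.tree_map(lambda _: replicated, params)

# Identity function: jit only moves the data to match out_shardings.
gathered = jax.jit(lambda tree: tree, out_shardings=out_shardings)(params)
```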
- fully_shard(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None) → SELF
Applies JAX sharding constraints to all parameters based on the partition rules.
This function ensures that parameters are explicitly marked with their intended sharding, which can be useful for performance and correctness checks. It uses jax.jit with out_shardings to enforce the constraints.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses config rules. Defaults to None.
- Returns
The model instance with sharding constraints applied.
- Return type
SELF
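A sketch of the same jax.jit and out_shardings mechanism, this time driven by regex rules. The rules, the toy params tree, and the sharding_for helper are hypothetical; the kernel's first dimension is a multiple of the device count so the split is always valid.

```python
import re

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("fsdp",))
n = mesh.shape["fsdp"]
params = {"dense": {"kernel": jnp.ones((2 * n, 4)), "bias": jnp.zeros((4,))}}

# Hypothetical rules: the first regex that matches the "/"-joined parameter path wins.
rules = (
    (r"kernel$", PartitionSpec("fsdp", None)),  # split kernel rows across "fsdp"
    (r".*", PartitionSpec()),                   # replicate everything else
)


def sharding_for(path, leaf):
    # Dict entries in the path are DictKey objects; fall back to str() for other key types.
    name = "/".join(str(getattr(entry, "key", entry)) for entry in path)
    spec = next(s for pattern, s in rules if re.search(pattern, name))
    return NamedSharding(mesh, spec)


out_shardings = jax.tree_util.tree_map_with_path(sharding_for, params)
sharded = jax.jit(lambda tree: tree, out_shardings=out_shardings)(params)
```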
- gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) → SELF
Gathers the model’s parameters from potentially distributed devices to the host or a single device.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules used to determine how parameters were originally sharded. If None, uses config rules. Defaults to None.
mesh (tp.Optional[Mesh], optional) – JAX device mesh from which to gather. If None, uses config mesh. Defaults to None.
overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Additional functions to apply, potentially overriding default gathering for specific parameters. Defaults to None.
- Returns
The model instance with gathered parameters.
- Return type
SELF
- get_static_arguments() → Tuple
Returns a tuple of static arguments required by the module’s __call__ method.
Static arguments are those that don’t change across calls and can be potentially cached or handled differently by JIT compilation. This base implementation returns an empty tuple. Subclasses should override this if they have static arguments.
- Returns
A tuple containing static arguments.
- Return type
tp.Tuple
- property graphdef: Union[NodeDef[Node], NodeRef[Node]]
Returns the graph definition (structure without parameters) of the module.
Uses flax.nnx.split to separate the graph definition from the state (parameters).
- Returns
The graph definition of the module.
- Return type
nn.GraphDef
- property graphother: State[Key, VariableState[Any]]
Returns any other state variables in the module (non-parameters).
Uses flax.nnx.split to separate non-parameter state variables.
- Returns
The graph state containing non-parameter variables.
- Return type
nn.GraphState
- property graphstate: State[Key, VariableState[Any]]
Returns the graph state (parameters) of the module.
Uses flax.nnx.split to separate the state (parameters) from the graph definition.
- Returns
The graph state containing the module’s parameters.
- Return type
nn.GraphState
- property graphtree_params_shape: Dict
Computes and returns the shapes of the module’s parameters as a nested dictionary.
It uses nnx.eval_shape to determine the shapes without actual computation, then extracts the shape information from the resulting graph state.
- Returns
A nested dictionary mirroring the parameter structure, containing their shapes.
- Return type
tp.Dict
- property graphtree_shape: Dict
Computes and returns the shapes of all state variables (including non-parameters) in the module.
Uses nnx.eval_shape on the entire module state (parameters and others) and extracts the shape information.
- Returns
A nested dictionary mirroring the module’s state structure, containing the shapes.
- Return type
tp.Dict
- half(change_runtime_dtype: bool = True) → SELF
Converts the module’s parameters to half-precision (float16).
Optionally also changes the runtime computation dtype (self.dtype) to float16.
- Parameters
change_runtime_dtype (bool) – If True, also sets self.dtype to jnp.float16. Defaults to True.
- Returns
The module instance with parameters (and potentially runtime dtype) set to float16.
- Return type
SELF
- property inv_frequencies: Array
Retrieves or computes the inv-frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_inv_frequencies() and caches the result.
- Returns
The inv-frequency components, potentially cached.
- Return type
jnp.ndarray
- classmethod lazy_init(*args, **kwargs) → SELF
Performs a “lazy” initialization using nnx.eval_shape.
This initializes the module structure and determines parameter shapes without actually allocating memory for the parameters. Useful for inspecting the model structure or preparing for sharding.
- Parameters
*args – Positional arguments passed to the class constructor.
**kwargs – Keyword arguments passed to the class constructor.
- Returns
A module instance with initialized structure but potentially abstract parameters.
- Return type
SELF
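lazy_init relies on nnx.eval_shape; the underlying idea matches jax.eval_shape, which traces a function abstractly and returns jax.ShapeDtypeStruct leaves without allocating device memory. In this sketch, init_params is a hypothetical stand-in for a model constructor, sized like a large embedding table and LM head.

```python
import jax
import jax.numpy as jnp


def init_params(key):
    """Hypothetical initializer standing in for a model constructor."""
    k1, k2 = jax.random.split(key)
    return {
        "embedding": jax.random.normal(k1, (32000, 4096), dtype=jnp.bfloat16),
        "lm_head": jax.random.normal(k2, (4096, 32000), dtype=jnp.bfloat16),
    }


# eval_shape traces init_params abstractly: no parameter memory is allocated.
abstract = jax.eval_shape(init_params, jax.random.PRNGKey(0))
```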
- property loss_function
Determines and returns the appropriate loss function based on the configuration or model type.
It prioritizes config.loss_type, then self.loss_type, and finally tries to infer the loss type from the class name. If no suitable loss function is found, it defaults to ForCausalLMLoss and issues a warning.
- Returns
The selected loss function (e.g., ForCausalLMLoss, ForSequenceClassificationLoss).
- Return type
tp.Callable
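The resolution order described above can be sketched as follows. The registry keys, the placeholder loss functions, and the suffix-matching rule on the class name are all hypothetical stand-ins, not EasyDeL's actual lookup table.

```python
import warnings


def causal_lm_loss(*args, **kwargs):
    """Placeholder for a causal-LM loss."""


def sequence_classification_loss(*args, **kwargs):
    """Placeholder for a sequence-classification loss."""


# Hypothetical registry keyed by loss type.
LOSS_MAPPING = {
    "ForCausalLM": causal_lm_loss,
    "ForSequenceClassification": sequence_classification_loss,
}


def resolve_loss_function(config_loss_type, module_loss_type, class_name):
    """Config first, then the module attribute, then the class-name suffix, else causal LM."""
    loss_type = config_loss_type or module_loss_type
    if loss_type is None:
        # Infer from the class name, e.g. "LlamaForCausalLM" -> "ForCausalLM".
        loss_type = next((key for key in LOSS_MAPPING if class_name.endswith(key)), None)
    if loss_type not in LOSS_MAPPING:
        warnings.warn(f"Unknown loss type {loss_type!r}; defaulting to ForCausalLM.")
        loss_type = "ForCausalLM"
    return LOSS_MAPPING[loss_type]
```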
- merge_lora_params(pytree: Dict) → SELF
Merges LoRA parameters from a pytree into the base model’s parameters.
- Parameters
pytree (tp.Dict) – A dictionary (pytree) containing the LoRA parameters (A and B matrices) structured similarly to the base model’s parameters.
- Returns
The module instance with LoRA parameters merged into the base weights.
- Return type
SELF
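Merging folds the low-rank update into the base weight, so inference no longer needs the extra matmuls. This NumPy sketch assumes the update is exactly lora_a @ lora_b with no alpha/rank scaling (if a scale is used, it multiplies that product) and checks that merged and unmerged outputs agree.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features, rank = 16, 8, 2
kernel = rng.normal(size=(in_features, out_features))  # frozen base weight
lora_a = rng.normal(size=(in_features, rank))           # trained down-projection
lora_b = rng.normal(size=(rank, out_features))          # trained up-projection

# Fold the low-rank update into the base weight.
merged_kernel = kernel + lora_a @ lora_b

x = rng.normal(size=(3, in_features))
unmerged_out = x @ kernel + (x @ lora_a) @ lora_b
merged_out = x @ merged_kernel
```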
- static merge_module(graphdef: Union[NodeDef[Node], NodeRef[Node]], graphstate: State[Key, VariableState[Any]], graphother: State[Key, VariableState[Any]])
- merge_params(tree)
Merges a given parameter state tree back into the module.
Reconstructs the module using its existing graph definition and ‘other’ state, but replaces the parameter state with the provided tree.
- Parameters
tree – A pytree (likely a nn.GraphState) containing the parameters to merge.
- Returns
The module instance with the new parameters merged in.
- Return type
SELF
- merge_params_dict(params_dict: Dict) → SELF
Merges parameters from a dictionary back into the module’s state.
Updates the module’s current parameter state with values from the provided dictionary.
- Parameters
params_dict (tp.Dict) – A nested dictionary containing the parameters to merge. The structure should match the module’s parameter structure.
- Returns
The module instance with the parameters from the dictionary merged in.
- Return type
SELF
- Raises
KeyError – If a key from params_dict is not found in the module’s current state.
- property mesh: Mesh
Retrieves the JAX device mesh from the module’s configuration.
- Returns
The device mesh defined in self.config.mesh.
- Return type
Mesh
- property model_task: Optional[str]
Returns the specific task associated with this model instance (e.g., ‘causal-language-model’).
- Returns
The model task identifier, or None if not set.
- Return type
tp.Optional[str]
- property model_type: Optional[str]#
Returns the specific type of this model instance (e.g., ‘llama’, ‘mistral’).
- Returns
The model type identifier, or None if not set.
- Return type
tp.Optional[str]
- property module_dtype: dtype#
Determines the data type of the module’s parameters.
It inspects the flattened parameter state to find the dtype of the first parameter encountered.
- Returns
The data type of the module’s parameters.
- Return type
jnp.dtype
- property parameters: Dict#
Retrieves the parameters of the module as a dictionary.
This property iterates through the module and its submodules, extracting variables marked as nn.Param and returning them in a flat dictionary where keys represent the parameter path.
- Returns
A dictionary containing the module’s parameters.
- Return type
tp.Dict
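The path-keyed flattening described above can be pictured with nested dicts standing in for the module tree. This sketch assumes tuple paths as keys. The real property walks the nnx module graph and keeps only nn.Param variables, which plain dicts do not model. `flatten_params` is a hypothetical helper.

```python
from typing import Any, Dict, Tuple


def flatten_params(tree: Dict[str, Any], prefix: Tuple[str, ...] = ()) -> Dict[Tuple[str, ...], Any]:
    """Flatten a nested dict into {path_tuple: leaf}."""
    flat: Dict[Tuple[str, ...], Any] = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten_params(value, path))  # descend into submodules
        else:
            flat[path] = value  # leaf parameter
    return flat
```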
- property params: Dict#
Returns the parameters and other state variables of the module as a dictionary.
Uses flax.nnx.split to get the combined state (parameters and others).
- Returns
A dictionary containing all state variables of the module.
- Return type
tp.Dict
- property params_sharding: Dict#
Retrieves the sharding annotation for each parameter in the module.
- Returns
A nested dictionary mirroring the parameter structure, containing the sharding information (e.g., NamedSharding, PartitionSpec) for each parameter, or None if unsharded.
- Return type
tp.Dict
- prepare_inputs_for_call(**kwargs)[source]#
Prepares keyword arguments before passing them to the module’s __call__ method.
This base implementation simply returns the kwargs as is. Subclasses can override this to modify or add arguments as needed (e.g., for generation).
- Parameters
**kwargs – The keyword arguments intended for __call__.
- Returns
The prepared keyword arguments.
- Return type
dict
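A subclass typically overrides this hook and delegates to the base implementation before adjusting the arguments. `BaseModule`, `GenerationModule`, and the `use_cache` default below are hypothetical names chosen only for illustration.

```python
from typing import Any, Dict


class BaseModule:
    def prepare_inputs_for_call(self, **kwargs: Any) -> Dict[str, Any]:
        # Base behavior: pass the keyword arguments through unchanged.
        return kwargs


class GenerationModule(BaseModule):
    def prepare_inputs_for_call(self, **kwargs: Any) -> Dict[str, Any]:
        kwargs = super().prepare_inputs_for_call(**kwargs)
        # Hypothetical tweak: default to using a KV cache unless the caller disables it.
        kwargs.setdefault("use_cache", True)
        return kwargs
```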
- property pure_transform_fn#
Returns a pure transformation function for PyTorch state dicts to EasyDeL parameters.
Similar to transform_fn, but this version does not include sharding functions. It identifies embedding and LayerNorm layers and returns a partial function (torch_dict_to_easydel_params) configured only with layer names and dtype.
- Returns
A partial function for converting a PyTorch state dict without applying sharding.
- Return type
tp.Callable
- quantize(method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: Optional[bool] = None) SELF[source]#
Applies quantization to the module’s linear layers or tensors.
- Parameters
method (EasyDeLQuantizationMethods, optional) – The quantization algorithm to use (e.g., A8BIT, NF4). Defaults to EasyDeLQuantizationMethods.A8BIT.
block_size (int, optional) – The block size for quantization methods that support it. Defaults to 128.
quantization_pattern (tp.Optional[str], optional) – A regular expression to match parameter names that should be quantized. If None, uses a default pattern. Defaults to None.
quantize_tensors (bool, optional) – If True, quantizes the tensor values directly; if False, replaces Linear layers with their quantized equivalents. Defaults to True in the signature; note that the current implementation may instead follow the layer-replacement (False) path by default.
verbose (tp.Optional[bool], optional) – If True, logs information during the quantization process. If None (the default), logging is enabled only on process index 0.
- Returns
The quantized model instance.
- Return type
SELF
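Block-wise quantization keeps one scale per `block_size` contiguous values, so an outlier only affects the precision of its own block. The sketch below shows generic symmetric absmax int8 quantization in pure Python. It illustrates the idea behind `block_size` and is not EasyDeL's A8BIT or NF4 implementation. `quantize_blockwise` and `dequantize_blockwise` are hypothetical helpers.

```python
from typing import List, Sequence, Tuple


def quantize_blockwise(values: Sequence[float], block_size: int = 128) -> Tuple[List[int], List[float]]:
    """Symmetric absmax int8 quantization with one scale per block."""
    codes: List[int] = []
    scales: List[float] = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        amax = max(abs(v) for v in block)
        scale = amax / 127 if amax > 0 else 1.0  # map the largest magnitude to +/-127
        scales.append(scale)
        codes.extend(max(-127, min(127, round(v / scale))) for v in block)
    return codes, scales


def dequantize_blockwise(codes: Sequence[int], scales: Sequence[float], block_size: int = 128) -> List[float]:
    """Recover approximate values by multiplying each code by its block's scale."""
    return [c * scales[i // block_size] for i, c in enumerate(codes)]
```

The reconstruction error per value is at most half of its block's scale, which is why smaller blocks generally trade extra scale storage for better accuracy.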
- shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF[source]#
Shards the model’s parameters according to the specified rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses config rules. Defaults to None.
mesh (tp.Optional[Mesh], optional) – JAX device mesh. If None, uses config mesh. Defaults to None.
overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Additional functions to apply, potentially overriding default sharding for specific parameters. Defaults to None.
- Returns
The sharded model instance.
- Return type
SELF
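Partition rules are commonly expressed as (regex, spec) pairs matched against each parameter's path. The sketch below assumes first-match-wins semantics over a "/"-joined path. It uses plain tuples in place of `jax.sharding.PartitionSpec` so it runs without JAX. The rule set and the `spec_for` helper are hypothetical.

```python
import re
from typing import Sequence, Tuple

# Hypothetical rules: (regex over the "/"-joined parameter path, partition spec).
RULES: Tuple[Tuple[str, Tuple[str, ...]], ...] = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"o_proj/kernel", ("tp", "fsdp")),
    (r".*", ()),  # catch-all: replicate everything else
)


def spec_for(path: Sequence[str], rules=RULES) -> Tuple[str, ...]:
    """Return the spec of the first rule whose regex matches the parameter path."""
    name = "/".join(path)
    for pattern, spec in rules:
        if re.search(pattern, name):
            return spec
    raise ValueError(f"no partition rule matches {name}")
```

Ending the rules with a catch-all pattern keeps every parameter covered, so unmatched tensors are replicated rather than raising.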
- split_lora_params() Dict[source]#
Splits merged LoRA parameters back out from the base model’s parameters.
This function assumes LoRA parameters were previously merged using merge_lora_params or a similar process that stored the original base weights and LoRA weights appropriately.
- Returns
A pytree containing the extracted LoRA parameters (A and B matrices). The base model parameters are restored to their original (pre-merge) state.
- Return type
tp.Dict
- split_params()[source]#
Splits the module and returns the parameter state.
Uses nnx.split to extract the GraphState containing the parameters.
- Returns
The parameter state of the module.
- Return type
nn.GraphState
- split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) Dict[source]#
Splits the module parameters and returns them as a nested dictionary.
Extracts the parameter state, converts it to a plain dictionary (removing VariableState wrappers), and optionally removes entries with None values.
- Parameters
extract_fn (tp.Optional[tp.Callable], optional) – A function to apply to each parameter during extraction. Defaults to None.
remove_none (bool, optional) – If True, removes key-value pairs where the value is None. Defaults to True.
- Returns
A nested dictionary containing the module’s parameters.
- Return type
tp.Dict
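The unwrap-and-filter step can be sketched with a stand-in wrapper class. `Wrapped` substitutes for nnx's VariableState, and `to_plain_dict` is a hypothetical helper. The sketch applies `extract_fn` only to non-None leaves, which is an assumption about the real behavior.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class Wrapped:
    """Stand-in for a VariableState-like wrapper holding a `.value`."""
    value: Any


def to_plain_dict(tree: Dict[str, Any], extract_fn: Optional[Callable] = None, remove_none: bool = True) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for key, node in tree.items():
        if isinstance(node, dict):
            value = to_plain_dict(node, extract_fn, remove_none)
        else:
            value = node.value if isinstance(node, Wrapped) else node  # strip the wrapper
            if extract_fn is not None and value is not None:
                value = extract_fn(value)
        if remove_none and value is None:
            continue  # drop empty entries such as a missing bias
        out[key] = value
    return out
```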
- property static_arguments: Tuple#
Retrieves or computes static arguments needed for the module’s __call__ method.
Uses self.get_static_arguments() and caches the result. Static arguments are typically those that don’t change during execution and can be pre-computed.
- Returns
A tuple of static arguments.
- Return type
tp.Tuple
- to_dtype(dtype: dtype) SELF[source]#
Converts the module’s parameters to the specified data type.
It iterates through the module’s parameters (excluding quantization-related ones) and casts them to the target dtype. It also updates the param_dtype attribute of the module and its submodules if they exist.
- Parameters
dtype (jnp.dtype) – The target data type for the parameters.
- Returns
The module instance with parameters converted to the specified dtype.
- Return type
SELF
- to_state() Any[source]#
Converts the current module instance into an EasyDeLState object.
This is useful for saving and managing the model’s state, including parameters and potentially optimizer state (though optimizer state is typically added later).
- Returns
An EasyDeLState object representing the current model state.
- Return type
EasyDeLState
- to_torch(**kwargs)[source]#
Converts the EasyDeL module to its equivalent Hugging Face PyTorch model.
Requires the corresponding PyTorch model class to be available and registered. Uses utility functions to transfer parameters from JAX to PyTorch format.
- Parameters
**kwargs – Additional keyword arguments passed to the parameter transformation function.
- Returns
The equivalent Hugging Face PyTorch model with loaded weights.
- Return type
torch.nn.Module
- property transform_fn#
Returns a partial function for transforming PyTorch state dicts to EasyDeL parameters.
This function identifies embedding and LayerNorm layers within the module and creates a transformation function (torch_dict_to_easydel_params) pre-configured with these layer names, the target parameter dtype, and the module’s sharding functions.
- Returns
A partial function ready to convert a PyTorch state dict.
- Return type
tp.Callable
- unwrap_lora_to_layers(verbose: bool = False) SELF[source]#
Reverts the application of LoRA layers, restoring the original linear layers.
Replaces easydel.layers.lora.LoraLinear layers with their original flax.linen.Dense counterparts, discarding the LoRA matrices.
- Parameters
verbose (bool, optional) – If True, prints information about which layers are being reverted. Defaults to False.
- Returns
The module instance with LoRA layers removed and original layers restored.
- Return type
SELF
- class easydel.infra.__init__.EasyDeLState(step: int | jax.Array, graphdef: nn.GraphDef, graphstate: nn.GraphState, graphother: nn.GraphState, tx: optax.GradientTransformation, opt_state: tp.Optional[optax.OptState], apply_fn: tp.Optional[tp.Callable] = None)[source]#
Bases: PyTreeNode
Represents the state of an EasyDeL model during training or inference.
This class encapsulates the model’s parameters, optimizer state, training step, and potentially other metadata. It provides methods for applying gradients, managing sharding, saving, and loading the state.
- graphdef#
The definition of the model’s computation graph (structure).
- Type
nn.GraphDef
- graphstate#
The state of the model’s parameters.
- Type
nn.GraphState
- graphother#
The state of non-parameter variables within the model.
- Type
nn.GraphState
- tx#
The optimizer transformation (e.g., AdamW, SGD). Marked as a non-pytree node.
- Type
optax.GradientTransformation
- opt_state#
The state of the optimizer (e.g., moments). Marked as a pytree node.
- Type
tp.Optional[optax.OptState]
- apply_fn#
A function to apply the model (often model.__call__). Typically not directly part of the state but can be associated.
- Type
tp.Optional[tp.Callable]
- apply_fn: tp.Optional[tp.Callable] = None#
- apply_gradients(*, grads)[source]#
Updates the model’s parameters and optimizer state based on calculated gradients.
- Parameters
grads – A pytree matching the structure of self.graphstate containing the gradients.
- Returns
A new state object with the updated parameters (graphstate), optimizer state (opt_state), and incremented step count.
- Return type
EasyDeLState
- Raises
AssertionError – If opt_state or tx is not initialized.
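`apply_gradients` follows the usual functional update pattern. It asks the optimizer for updates, adds them to the parameters, and returns a new immutable state with `step + 1`. The sketch mirrors optax's `init` / `update(updates, state, params=None)` shape, using a toy SGD over a flat dict of floats. `SGD` and `TrainState` are hypothetical stand-ins for an optax transformation and EasyDeLState.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict, Optional, Tuple


@dataclass(frozen=True)
class SGD:
    """Toy stand-in for an optax GradientTransformation."""
    learning_rate: float

    def init(self, params: Dict[str, float]) -> Dict[str, float]:
        return {}  # plain SGD keeps no optimizer state

    def update(self, grads: Dict[str, float], opt_state: Any, params=None) -> Tuple[Dict[str, float], Any]:
        # Like optax.sgd, emit negative scaled gradients as additive updates.
        return {k: -self.learning_rate * g for k, g in grads.items()}, opt_state


@dataclass(frozen=True)
class TrainState:
    step: int
    params: Dict[str, float]
    tx: Optional[SGD]
    opt_state: Optional[Any]

    def apply_gradients(self, *, grads: Dict[str, float]) -> "TrainState":
        assert self.tx is not None and self.opt_state is not None, "optimizer not initialized"
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = {k: self.params[k] + updates[k] for k in self.params}
        # Return a new object; the old state is left untouched.
        return replace(self, step=self.step + 1, params=new_params, opt_state=new_opt_state)
```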
- classmethod create(*, step: tp.Optional[int] = None, graphdef: tp.Optional[nn.GraphDef] = None, graphstate: tp.Optional[nn.GraphState] = None, graphother: tp.Optional[nn.GraphState] = None, model: tp.Optional[nn.Module] = None, tx: tp.Optional[optax.GradientTransformation] = None, opt_state: tp.Optional[optax.OptState] = None, init_opt_state: bool = False) EasyDeLState[source]#
Creates a new EasyDeLState instance.
This class method provides a flexible way to initialize the state, either from an existing nn.Module or by providing the graph components (graphdef, graphstate, graphother) directly. It also handles optimizer state initialization.
- Parameters
step (tp.Optional[int]) – The initial training step. If None (the default), the step starts at 0.
graphdef (tp.Optional[nn.GraphDef]) – The model’s graph definition.
graphstate (tp.Optional[nn.GraphState]) – The model’s parameter state.
graphother (tp.Optional[nn.GraphState]) – The model’s non-parameter state.
model (tp.Optional[nn.Module]) – An EasyDeL module instance. If provided, graphdef, graphstate, and graphother are derived from it. Cannot be provided simultaneously with graph components.
tx (tp.Optional[optax.GradientTransformation]) – The optimizer transformation.
opt_state (tp.Optional[optax.OptState]) – The initial optimizer state. Cannot be provided if init_opt_state is True.
init_opt_state (bool) – If True, initializes the optimizer state using tx.init(graphstate). Requires tx to be provided. Defaults to False.
- Returns
A new instance of the state.
- Return type
EasyDeLState
- Raises
ValueError – If model and graph components are provided simultaneously.
ValueError – If graph components are provided partially.
ValueError – If init_opt_state is True and opt_state is also provided.
ValueError – If init_opt_state is True but tx is not provided.
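The four `ValueError` conditions listed above amount to a short argument check. `validate_create_args` is a hypothetical helper that restates them. The real method may check them in a different order or with different messages.

```python
def validate_create_args(*, model=None, graphdef=None, graphstate=None,
                         graphother=None, tx=None, opt_state=None,
                         init_opt_state=False) -> None:
    """Raise ValueError for the argument combinations documented for create()."""
    given = [c is not None for c in (graphdef, graphstate, graphother)]
    if model is not None and any(given):
        raise ValueError("Pass either `model` or graph components, not both.")
    if any(given) and not all(given):
        raise ValueError("graphdef, graphstate and graphother must be given together.")
    if init_opt_state and opt_state is not None:
        raise ValueError("Cannot pass opt_state when init_opt_state=True.")
    if init_opt_state and tx is None:
        raise ValueError("init_opt_state=True requires tx.")
```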
- gather_model(partition_rules: Any = None, mesh: Optional[Any] = None) EasyDeLState[source]#
Gathers the model parameters (graphstate and graphother) from distributed devices.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules used for the original sharding. If None, uses model config rules. Defaults to None.
mesh (tp.Optional[Mesh], optional) – The JAX device mesh to gather from. If None, uses model’s mesh. Defaults to None.
- Returns
A new state object with gathered graphstate and graphother.
- Return type
EasyDeLState
- gather_optimizer_state(partition_rules=None)[source]#
Gathers the optimizer state from potentially distributed devices.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules used to determine how the state was sharded. If None, uses rules from the model’s config. Defaults to None.
- Returns
A new state object with the gathered opt_state.
- Return type
EasyDeLState
- Raises
AssertionError – If opt_state is not initialized.
- gather_state()[source]#
Gathers the entire state (model parameters and optimizer state) from distributed devices.
This is a convenience method that calls gather_model and gather_optimizer_state.
- Returns
A new state object with both model and optimizer states gathered.
- Return type
EasyDeLState
- graphdef: nn.GraphDef#
- graphother: nn.GraphState#
- graphstate: nn.GraphState#
- init_tx(tx: GradientTransformation, partition_rules: Any = None) EasyDeLState[source]#
Initializes the optimizer state (opt_state) for the current graphstate using the provided optimizer transformation (tx). It automatically handles sharding based on the model’s partition rules.
- Parameters
tx (optax.GradientTransformation) – The optimizer transformation to initialize with.
partition_rules (PartitionLike, optional) – Partitioning rules for the optimizer state. If None, uses the rules from the associated model’s config. Defaults to None.
- Returns
A new state object with the initialized and potentially sharded opt_state and the provided tx.
- Return type
EasyDeLState
- load_optimizer(load_directory: Union[str, PathLike])[source]#
Loads the optimizer state from saved files.
Reads the optimizer state structure from a pickle file (OPTIMIZER_STRUCT_NAME) and the tensor data from a SafeTensors file (OPTIMIZER_NAME) within the specified directory.
- Parameters
load_directory (tp.Union[str, os.PathLike]) – The directory containing the saved optimizer state files.
- Returns
A new state object with the loaded opt_state.
- Return type
EasyDeLState
- Raises
FileNotFoundError – If the required optimizer files are not found.
Exception – If any error occurs during loading or deserialization.
- classmethod load_state(load_directory: ~typing.Union[str, ~os.PathLike], device: ~typing.Optional[~jaxlib.xla_extension.Device] = 'cpu', dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, param_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, precision: ~typing.Optional[~jax._src.lax.lax.Precision] = None, sharding_axis_dims: ~typing.Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = None, sharding_axis_names: ~typing.Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: ~typing.Optional[~eformer.escale.partition.manager.PartitionAxis] = None, shard_attention_computation: bool = True, shard_fns: ~typing.Optional[~typing.Union[~typing.Mapping[tuple, ~typing.Callable], dict]] = None, backend: ~typing.Optional[~typing.Any] = None, platform: ~typing.Optional[~typing.Any] = None, config_kwargs: ~typing.Optional[~typing.Any] = None, model_task: ~easydel.infra.factory.TaskType = TaskType.AUTO_BIND, auto_shard_model: bool = False, partition_rules: ~typing.Optional[~typing.Tuple[~typing.Tuple[str, ~jax._src.partition_spec.PartitionSpec], ...]] = None, quantization_platform: ~typing.Optional[~typing.Any] = None, quantization_method: ~typing.Optional[~typing.Any] = None, quantization_block_size: int = 128, quantization_pattern: ~typing.Optional[str] = None, quantize_tensors: bool = True, verbose: bool = True, **kwargs)[source]#
Loads an EasyDeLState from a saved checkpoint directory.
This class method reconstructs the model configuration, loads the model parameters, and optionally loads the optimizer state from files saved previously using save_state. It handles various configurations for device placement, data types, sharding, and quantization.
- Parameters
load_directory – Path to the directory containing the saved state (configuration, model weights, and potentially optimizer state).
device – The JAX device (e.g., ‘cpu’, ‘gpu’, ‘tpu’) to load the model onto. Defaults to ‘cpu’.
dtype – The data type to use for computation (e.g., jax.numpy.float32). Defaults to jax.numpy.float32.
param_dtype – The data type for the model parameters (e.g., jax.numpy.bfloat16). Defaults to jax.numpy.float32.
precision – The JAX precision level (e.g., jax.lax.Precision.HIGHEST). Defaults to None.
sharding_axis_dims – A sequence defining the dimensions of the device mesh for sharding (e.g., (1, -1, 1, 1)). Defaults to (1, -1, 1, 1).
sharding_dcn_axis_dims – Optional mesh dimensions across the data-center network (DCN), i.e., between slices or hosts, for building hybrid device meshes. Defaults to None.
sharding_axis_names – Names corresponding to the sharding axes (e.g., (“dp”, “fsdp”, “tp”, “sp”)). Defaults to (“dp”, “fsdp”, “tp”, “sp”).
partition_axis – Configuration object for partitioning specific axes. Defaults to None.
shard_attention_computation – If True, shards the attention computation across devices. Defaults to True.
shard_fns – Optional mapping of parameter path tuples to custom sharding functions. Defaults to None.
backend – The backend framework to use (e.g., EasyDeLBackends.JAX). Defaults to None (auto-detected).
platform – The hardware platform (e.g., EasyDeLPlatforms.TPU). Defaults to None (auto-detected).
config_kwargs – Optional dictionary of keyword arguments to override in the loaded model configuration. Defaults to None.
model_task – The specific task type for the model (e.g., TaskType.CAUSAL_LM). Defaults to TaskType.AUTO_BIND.
auto_shard_model – If True, automatically shards the loaded model and optimizer state based on the provided sharding configuration. Defaults to False.
partition_rules – Optional tuple of partition rules (regex, PartitionSpec) to explicitly define sharding. Defaults to None (uses model config).
quantization_platform – Platform for quantization (e.g., EasyDeLPlatforms.TPU). Defaults to None.
quantization_method – Quantization method (e.g., EasyDeLQuantizationMethods.AQT). Defaults to None.
quantization_block_size – Block size for quantization methods like GPTQ. Defaults to 128.
quantization_pattern – Regex pattern to match tensor names for quantization. Defaults to None.
quantize_tensors – If True, applies quantization to the loaded tensors. Defaults to True.
verbose – If True, logs detailed information during loading. Defaults to True.
**kwargs – Additional keyword arguments passed directly to the underlying EasyDeLBaseModule.from_pretrained method.
- Returns
An EasyDeLState instance containing the loaded model, optimizer state (if found and loaded), and associated configuration.
- Raises
FileNotFoundError – If the load_directory or essential files within it (like configuration or model weights) are not found.
ValueError – If there are inconsistencies in the provided arguments or loaded configuration.
Note – Other exceptions raised by underlying calls, such as AutoEasyDeLConfig or EasyDeLBaseModule.from_pretrained, may also propagate.
- merge(tree) Any[source]#
Merges a given state tree (usually parameters) with the graph definition and other state components to reconstruct the full model module.
- Parameters
tree – The pytree (e.g., nn.GraphState) containing the parameters to merge.
- Returns
The reconstructed model module.
- Return type
EasyDeLBaseModule
- merge_to_state(tree) EasyDeLState[source]#
Creates a new EasyDeLState by replacing the current graphstate with the provided tree.
- Parameters
tree – The pytree (e.g., nn.GraphState) containing the new parameters.
- Returns
A new state object with the updated graphstate.
- Return type
EasyDeLState
- property model: Any#
Reconstructs and returns the full EasyDeL model module from the state components.
- Returns
The model module instance.
- Return type
EasyDeLBaseModule
- opt_state: tp.Optional[optax.OptState]#
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- save_state(save_directory: Union[str, PathLike], float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, save_optimizer: bool = True, enable: Optional[bool] = None)[source]#
Saves the entire EasyDeLState to a directory.
This includes saving the model parameters (using model.save_pretrained) and optionally the optimizer state.
- Parameters
save_directory (tp.Union[str, os.PathLike]) – The directory to save the state to.
float_dtype (tp.Optional[jax.numpy.dtype]) – Optional dtype to cast floating-point parameters to before saving. Defaults to None.
verbose (bool) – If True, logs information during saving. Defaults to True.
mismatch_allowed (bool) – Passed to model.save_pretrained, allows saving even if the model structure differs slightly from expected. Defaults to True.
save_optimizer (bool) – If True, saves the optimizer state. Defaults to True.
enable (tp.Optional[bool]) – If set, controls whether saving happens (True) or is skipped (False). If None, saving typically occurs only on JAX process index 0. Defaults to None.
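The `enable` rule is an explicit override that falls back to a process-index check, so that only one host writes files in multi-host runs. In real JAX code the index would come from `jax.process_index()`. Here it is passed in so that the hypothetical `should_save` helper runs on its own.

```python
from typing import Optional


def should_save(enable: Optional[bool], process_index: int) -> bool:
    """An explicit `enable` wins; otherwise only process 0 writes."""
    if enable is not None:
        return enable
    return process_index == 0
```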
- shard_model(partition_rules: Any = None, mesh: Optional[Any] = None) EasyDeLState[source]#
Shards the model parameters (graphstate and graphother) based on partition rules.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses model config rules. Defaults to None.
mesh (tp.Optional[Mesh], optional) – The JAX device mesh to shard across. If None, uses model’s mesh. Defaults to None.
- Returns
A new state object with sharded graphstate and graphother.
- Return type
EasyDeLState
- shard_optimizer_state(opt_state: Optional[Any] = None, partition_rules: Any = None) Any[source]#
Applies sharding to the optimizer state based on partition rules.
- Parameters
opt_state (tp.Optional[tp.Any]) – The optimizer state pytree to shard. If None, uses self.opt_state. Defaults to None.
partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses rules from the model’s config. Defaults to None.
- Returns
A new state object with the sharded opt_state.
- Return type
EasyDeLState
- Raises
ValueError – If optimizer state is not initialized (neither opt_state argument nor self.opt_state is available).
- shard_state(partition_rules: Any = None) EasyDeLState[source]#
Shards the entire state (model parameters and optimizer state) based on partition rules.
This is a convenience method that calls shard_model and shard_optimizer_state.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses rules from the model’s config. Defaults to None.
- Returns
A new state object with both model and optimizer states sharded.
- Return type
EasyDeLState
- shard_with_shape(shape) EasyDeLState[source]#
Applies sharding constraints to the entire state based on a reference shape pytree.
This method takes a pytree shape which has the same structure as the EasyDeLState but contains sharding annotations (e.g., NamedSharding) instead of actual array data. It applies these shardings as constraints to the corresponding arrays in the current state.
- Parameters
shape – A pytree with the same structure as self, containing sharding annotations.
- Returns
A new state object with sharding constraints applied.
- Return type
EasyDeLState
- property shardings#
Retrieves the sharding annotations (e.g., NamedSharding) for all components of the EasyDeLState pytree.
- Returns
A pytree with the same structure as self, containing sharding annotations or None for components without sharding.
- property size: int#
Calculates the total size in bytes of the model parameters (graphstate) and the optimizer state (opt_state).
- Returns
The total size in bytes.
- Return type
int
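The byte count is the sum, over every array leaf, of the element count times the bytes per element. With JAX arrays, one could sum each leaf's `nbytes` over `jax.tree_util.tree_leaves(...)`. The sketch below instead uses (shape, dtype-name) pairs and a small hypothetical itemsize table so that it runs without JAX.

```python
import math
from typing import Iterable, Tuple

# Bytes per element for a few common dtypes (illustrative subset).
ITEMSIZE = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}


def tree_nbytes(leaves: Iterable[Tuple[Tuple[int, ...], str]]) -> int:
    """Sum element_count * itemsize over (shape, dtype) leaves."""
    return sum(math.prod(shape) * ITEMSIZE[dtype] for shape, dtype in leaves)
```

For example, a single 4096 × 4096 bfloat16 kernel contributes 4096 · 4096 · 2 = 33,554,432 bytes. An Adam-style optimizer state that keeps two float32 moments per parameter adds four times that again.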
- tx: optax.GradientTransformation#
- class easydel.infra.__init__.LossConfig(ignore_index: int = -100, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS', num_labels: Optional[str] = None, problem_type: Optional[str] = None, divide_weight_sum: bool = False, shift_tokens: bool = True, break_on_nan: bool = True, reduction: Optional[Literal['none', 'mean', 'sum']] = None, num_classification_labels: Optional[int] = None, classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None)[source]#
Bases: object
Configuration class for customizing loss computation behavior.
- ignore_index#
Specifies a target value that is ignored and does not contribute to the loss. Defaults to -100.
- Type
int
- label_smoothing#
Amount of label smoothing to apply. 0.0 means no smoothing. Defaults to 0.0.
- Type
float
- z_loss#
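Coefficient for the z-loss regularization term, which penalizes the squared log of the softmax normalizer, (log Z)², keeping log Z close to zero and discouraging the logits from drifting to large magnitudes. Defaults to 0.0.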
- Type
float
- loss_normalizing_factor#
How to normalize the loss. Can be a constant float/int, a string representation of a SpecialLossNormalizingFactor enum, or the enum itself. Defaults to “NUM_REAL_TARGET_TOKENS”.
- Type
FACTOR_TYPE
- num_labels#
The number of labels for classification tasks. Used in ForSequenceClassificationLoss. Defaults to None.
- Type
tp.Optional[int]
- problem_type#
Specifies the problem type for sequence classification (e.g., “single_label_classification”, “multi_label_classification”). Defaults to None.
- Type
tp.Optional[str]
- divide_weight_sum#
If True, divides the loss by the sum of weights, in addition to the loss_normalizing_factor. Defaults to False.
- Type
bool
- shift_tokens#
If True (typically for Causal LM), shifts the logits and labels so that the model predicts the next token. Defaults to True.
- Type
bool
- break_on_nan#
If True, raises an EasyDeLBreakRequest if a NaN is encountered during loss computation. Defaults to True.
- Type
bool
- reduction#
Specifies the reduction to apply to the loss. If None, the default reduction of the specific loss function is used. Defaults to None.
- Type
tp.Optional[tp.Literal[“none”, “mean”, “sum”]]
- num_classification_labels#
Number of labels specifically for sequence classification. Alias for num_labels. Defaults to None.
- Type
tp.Optional[int]
- classification_problem_type#
Problem type specifically for sequence classification. Alias for problem_type. Defaults to None.
- Type
tp.Optional[tp.Literal[“regression”, “single_label_classification”, “multi_label_classification”]]
- break_on_nan: bool = True#
- classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None#
- divide_weight_sum: bool = False#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- ignore_index: int = -100#
- label_smoothing: float = 0.0#
- loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS'#
- num_classification_labels: Optional[int] = None#
- num_labels: Optional[str] = None#
- problem_type: Optional[str] = None#
- reduction: Optional[Literal['none', 'mean', 'sum']] = None#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- shift_tokens: bool = True#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- z_loss: float = 0.0#
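Several of the LossConfig fields interact in a single causal-LM loss. The sketch below combines `shift_tokens`, `ignore_index`, `label_smoothing`, `z_loss`, and a mean over non-ignored targets (in the spirit of NUM_REAL_TARGET_TOKENS), all in pure Python. It assumes the smoothing convention (1 - eps) · NLL + eps · (cross-entropy against a uniform target) and a per-token z-loss of `z_loss * (log Z) ** 2`. EasyDeL's ForCausalLMLoss operates on JAX arrays and may use different conventions. `causal_lm_loss` is a hypothetical helper.

```python
import math
from typing import Sequence


def causal_lm_loss(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    *,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    z_loss: float = 0.0,
    shift_tokens: bool = True,
) -> float:
    """Mean per-token loss over targets that are not `ignore_index`."""
    if shift_tokens:
        # Position t predicts token t + 1: drop the last logit row and the first label.
        logits, labels = logits[:-1], labels[1:]
    total, count = 0.0, 0
    for row, target in zip(logits, labels):
        if target == ignore_index:
            continue  # ignored targets add nothing to the loss or to the count
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        log_probs = [x - log_z for x in row]
        nll = -log_probs[target]
        uniform_ce = -sum(log_probs) / len(row)  # cross-entropy against a uniform target
        total += (1.0 - label_smoothing) * nll + label_smoothing * uniform_ce
        total += z_loss * log_z ** 2  # pushes log Z toward 0
        count += 1
    # Normalize by the number of real (non-ignored) target tokens.
    return total / max(count, 1)
```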
- class easydel.infra.__init__.PartitionAxis(data_parallel_axis: str = 'dp', fully_sharded_data_parallel_axis: str = 'fsdp', tensor_parallel_axis: str = 'tp', sequence_parallel_axis: str = 'sp', expert_parallel_axis: str = 'ep', batch_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, sequence_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, query_sequence_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, head_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, kv_head_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, key_sequence_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, hidden_state_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, mlp_intermediate_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, vocab_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, expert_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, expert_gate_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = None, attention_dim_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = None, attention_kv_dim_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = None, bias_head_sequence_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = None, bias_key_sequence_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = None, decode_batch_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, decode_query_sequence_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = None, decode_head_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, decode_kv_head_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, decode_key_sequence_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = <eformer.common_types._Empty object>, decode_attention_dim_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = None, decode_attention_kv_dim_axis: ~typing.Optional[~typing.Union[~typing.Tuple[str, ...], str, ~typing.Any]] = None)[source]#
Bases: xTree
Configuration for partitioning model axes across a device mesh.
Defines the mesh dimension names for standard parallelism strategies and maps logical model axes to these dimensions. Allows overriding defaults.
- Mesh Dimensions Attributes:
data_parallel_axis: Name for the data-parallel mesh dimension. Default: “dp”.
fully_sharded_data_parallel_axis: Name for the FSDP mesh dimension. Default: “fsdp”.
tensor_parallel_axis: Name for the tensor-parallel mesh dimension. Default: “tp”.
sequence_parallel_axis: Name for the sequence-parallel mesh dimension. Default: “sp”.
expert_parallel_axis: Name for the expert-parallel (MoE) mesh dimension. Default: “ep”.
- Logical Model Axes Attributes:
Maps logical tensor axes (like batch, sequence, hidden) to one or more mesh dimension names defined above, or None if not partitioned. Defaults are derived from the standard mesh dimension names but can be overridden during instantiation. For example, head_axis defaults to the value of tensor_parallel_axis (‘tp’).
batch_axis: Mesh axis for the batch dimension.
sequence_axis: Mesh axis for the general sequence-length dimension.
query_sequence_axis: Mesh axis for the query sequence-length dimension.
head_axis: Mesh axis for the attention-head dimension.
key_sequence_axis: Mesh axis for the key/value sequence-length dimension.
hidden_state_axis: Mesh axis for the embedding (hidden-state) dimension.
mlp_intermediate_axis: Mesh axis for the intermediate dimension of MLP layers.
vocab_axis: Mesh axis for the vocabulary dimension.
expert_axis: Mesh axis for the expert dimension.
expert_gate_axis: Mesh axis for the expert-gate dimension.
attention_dim_axis: Mesh axis for the feature dimension within each attention head.
bias_head_sequence_axis: Mesh axis for attention-bias dimensions related to heads and sequence length.
bias_key_sequence_axis: Mesh axis for attention-bias dimensions related to the key/value sequence length.
decode_batch_axis: Mesh axis for the batch dimension during decoding.
decode_query_sequence_axis: Mesh axis for the query sequence length during decoding.
decode_head_axis: Mesh axis for the attention-head dimension during decoding.
decode_key_sequence_axis: Mesh axis for the key/value sequence length during decoding.
decode_attention_dim_axis: Mesh axis for the feature dimension within each attention head during decoding.
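The way unset logical axes fall back to the mesh dimension names can be sketched with a small stand-alone dataclass. This is a hypothetical illustration, not eformer's implementation: MiniPartitionAxis and the _EMPTY sentinel (standing in for eformer.common_types._Empty) are invented here, only a few fields are shown, and the derived values mirror the resolved defaults shown in the EasyDeLBaseConfig signature (batch_axis=('fsdp', 'dp'), sequence_axis='sp', head_axis='tp').

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

AxisRule = Optional[Union[Tuple[str, ...], str, Any]]


class _Empty:
    """Hypothetical sentinel meaning "derive this rule from the mesh names"."""

    def __repr__(self) -> str:
        return "<EMPTY>"


_EMPTY = _Empty()


@dataclass
class MiniPartitionAxis:
    # Mesh dimension names.
    data_parallel_axis: str = "dp"
    fully_sharded_data_parallel_axis: str = "fsdp"
    tensor_parallel_axis: str = "tp"
    sequence_parallel_axis: str = "sp"
    # Logical axes: _EMPTY means "fill in later", None means "not partitioned".
    batch_axis: AxisRule = _EMPTY
    sequence_axis: AxisRule = _EMPTY
    head_axis: AxisRule = _EMPTY
    attention_dim_axis: AxisRule = None

    def __post_init__(self) -> None:
        derived = {
            "batch_axis": (self.fully_sharded_data_parallel_axis, self.data_parallel_axis),
            "sequence_axis": self.sequence_parallel_axis,
            "head_axis": self.tensor_parallel_axis,
        }
        # Only fields still holding the sentinel are filled in; explicit values win.
        for name, value in derived.items():
            if getattr(self, name) is _EMPTY:
                setattr(self, name, value)


axes = MiniPartitionAxis(tensor_parallel_axis="model")
print(axes.head_axis)   # model
print(axes.batch_axis)  # ('fsdp', 'dp')
print(MiniPartitionAxis(head_axis=None).head_axis)  # None
```

A dedicated sentinel is needed because None already carries a meaning (do not partition this axis), so it cannot also signal "use the derived default".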
- attention_dim_axis: Optional[Union[Tuple[str, ...], str, Any]] = None
- attention_kv_dim_axis: Optional[Union[Tuple[str, ...], str, Any]] = None
- batch_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None
- bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None
- data_parallel_axis: str = 'dp'
- decode_attention_dim_axis: Optional[Union[Tuple[str, ...], str, Any]] = None
- decode_attention_kv_dim_axis: Optional[Union[Tuple[str, ...], str, Any]] = None
- decode_batch_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- decode_head_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- decode_key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- decode_kv_head_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- decode_query_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None
- expert_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- expert_gate_axis: Optional[Union[Tuple[str, ...], str, Any]] = None
- expert_parallel_axis: str = 'ep'
- fully_sharded_data_parallel_axis: str = 'fsdp'
- head_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- kv_head_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- mlp_intermediate_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- query_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- replace(**updates)
Returns a new instance of the dataclass with the specified fields updated; the original instance is left unchanged.
- Parameters
**updates – Keyword arguments where keys are field names and values are the new values for those fields.
- Returns
A new instance of the dataclass with the updated fields.
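The documented contract of replace (a new instance with some fields changed) matches the standard library's dataclasses.replace. The sketch below shows that behavior on a hypothetical MeshNames stand-in; it assumes, without confirmation from this page, that PartitionAxis.replace behaves the same way.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class MeshNames:
    # Hypothetical stand-in holding two of PartitionAxis's mesh-name fields.
    data_parallel_axis: str = "dp"
    tensor_parallel_axis: str = "tp"

    def replace(self, **updates) -> "MeshNames":
        # Build a new instance from the current values plus the updates;
        # the receiver itself is never mutated.
        return dataclasses.replace(self, **updates)


base = MeshNames()
renamed = base.replace(tensor_parallel_axis="model")
print(base.tensor_parallel_axis)     # tp
print(renamed.tensor_parallel_axis)  # model
print(renamed.data_parallel_axis)    # dp
```

In this stdlib-based stand-in, passing a name that is not a field raises TypeError, because dataclasses.replace forwards the updates to the class's __init__.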
- resolve_spec(axes: Sequence[Optional[str]], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']) → PartitionSpec
Generates a PartitionSpec from a sequence of semantic axis names and a mode.
Maps a sequence of semantic axis name strings (like BATCH, LENGTH) to the actual mesh axis names defined in this PartitionAxis instance, considering the current runtime mode (e.g., training vs. generation).
- Parameters
axes – A sequence of semantic axis name strings (e.g., [BATCH, LENGTH, HEAD]) or None (or “_”) for axes that shouldn’t be sharded.
mode – The current operational mode (e.g., MODE_TRAIN, MODE_DECODE), which determines whether generation-specific rules are applied.
- Returns
A jax.sharding.PartitionSpec instance representing the sharding for the given sequence of axes.
- Raises
ValueError – If an unknown semantic axis name is encountered, or if a resolved axis rule is still NOT_GIVEN (this should already be caught by _safety_check and is checked again for robustness).
LookupError – If an internal attribute name derived from the semantic map is not found on the instance (this should not happen with a correct class definition).
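The mapping that resolve_spec performs can be approximated in plain Python. Everything in this sketch is hypothetical: the values of BATCH, LENGTH, and HEAD, the two rule tables, and the assumption that only the '__autoregressive__' mode switches to the decode-time rules are illustrative choices, not eformer's actual constants or logic. The rule values are copied from the resolved defaults in the EasyDeLBaseConfig signature (for example decode_query_sequence_axis=None), and the sketch returns a plain tuple rather than a jax.sharding.PartitionSpec.

```python
from typing import Dict, Optional, Sequence, Tuple, Union

AxisRule = Optional[Union[str, Tuple[str, ...]]]

# Hypothetical semantic axis names; the real constants live in eformer.
BATCH, LENGTH, HEAD = "__BATCH__", "__LENGTH__", "__HEAD__"

# Training-time rules (batch_axis, query_sequence_axis, head_axis defaults).
TRAIN_RULES: Dict[str, AxisRule] = {BATCH: ("fsdp", "dp"), LENGTH: "sp", HEAD: "tp"}
# Decode-time rules (decode_batch_axis, decode_query_sequence_axis, decode_head_axis).
DECODE_RULES: Dict[str, AxisRule] = {BATCH: ("fsdp", "dp"), LENGTH: None, HEAD: "tp"}


def resolve_spec_sketch(axes: Sequence[Optional[str]], mode: str) -> Tuple[AxisRule, ...]:
    # Assumption: only autoregressive decoding uses the decode-time rules.
    rules = DECODE_RULES if mode == "__autoregressive__" else TRAIN_RULES
    spec = []
    for axis in axes:
        if axis is None or axis == "_":
            spec.append(None)  # this tensor dimension is not sharded
        elif axis in rules:
            spec.append(rules[axis])
        else:
            raise ValueError(f"Unknown semantic axis name: {axis!r}")
    return tuple(spec)


print(resolve_spec_sketch([BATCH, LENGTH, HEAD], "__train__"))           # (('fsdp', 'dp'), 'sp', 'tp')
print(resolve_spec_sketch([BATCH, LENGTH, HEAD], "__autoregressive__"))  # (('fsdp', 'dp'), None, 'tp')
print(resolve_spec_sketch([BATCH, "_", HEAD], "__train__"))              # (('fsdp', 'dp'), None, 'tp')
```

With JAX available, such a tuple could be turned into a sharding spec with jax.sharding.PartitionSpec(*spec).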
- sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- sequence_parallel_axis: str = 'sp'
- tensor_parallel_axis: str = 'tp'
- vocab_axis: Optional[Union[Tuple[str, ...], str, Any]] = <eformer.common_types._Empty object>
- class easydel.infra.__init__.PyTree
Bases: _PyTreeNodeBase
Base class for mutable PyTree dataclasses.
Inheriting from this class automatically applies the auto_pytree decorator to the subclass, registering it as a JAX PyTree.
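One common way to make inheritance apply a decorator automatically is Python's __init_subclass__ hook. The sketch below assumes that mechanism (this page does not say how PyTree implements it) and uses a hypothetical toy_auto_pytree that only turns the subclass into a dataclass and records it in a list, standing in for real JAX PyTree registration.

```python
import dataclasses
from typing import List

# Hypothetical stand-in for JAX's PyTree registry.
REGISTERED: List[type] = []


def toy_auto_pytree(cls: type) -> type:
    # dataclasses.dataclass adds __init__/__repr__/__eq__ to cls in place
    # (no slots), then the class is "registered".
    cls = dataclasses.dataclass(cls)
    REGISTERED.append(cls)
    return cls


class ToyPyTree:
    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        # Runs once per subclass definition, so no explicit decorator is needed.
        toy_auto_pytree(cls)


class Params(ToyPyTree):
    weight: float
    bias: float = 0.0


p = Params(weight=1.5)
print(p.weight, p.bias)      # 1.5 0.0
print(Params in REGISTERED)  # True
```

The base class itself is never decorated; only classes that inherit from it are.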
- easydel.infra.__init__.auto_pytree(cls: Optional[Type[T]] = None, meta_fields: Optional[Tuple[str, ...]] = None, json_serializable: bool = True, frozen: bool = False)
A class decorator that automatically registers a dataclass as a JAX PyTree.
It uses dataclasses.dataclass to make the class a dataclass if it isn’t already, determines which fields are data (PyTree children) and which are metadata, and registers the class with jax.tree_util.register_dataclass.
Fields are considered metadata if any of the following holds:
  - They are explicitly listed in meta_fields.
  - They are marked with field(pytree_node=False).
  - Their type hint suggests a non-JAX type (checked by _is_non_jax_type).
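The three metadata rules can be expressed as a small field-partitioning function. This is a hypothetical sketch: it reads field(pytree_node=False) as a standard dataclasses field with metadata={"pytree_node": False}, and it replaces _is_non_jax_type with an assumed heuristic that treats str and bool annotations as static. The two returned lists have the shape that jax.tree_util.register_dataclass accepts through its data_fields and meta_fields arguments.

```python
import dataclasses
import typing
from typing import List, Optional, Tuple

# Assumed heuristic: annotation types treated as static metadata.
_ASSUMED_NON_JAX_TYPES = (str, bool)


def split_fields(cls: type, meta_fields: Optional[Tuple[str, ...]] = None) -> Tuple[List[str], List[str]]:
    """Partition a dataclass's fields into (data_fields, meta_fields)."""
    explicit = set(meta_fields or ())
    hints = typing.get_type_hints(cls)
    data, meta = [], []
    for f in dataclasses.fields(cls):
        is_meta = (
            f.name in explicit                                # rule 1: listed in meta_fields
            or f.metadata.get("pytree_node", True) is False   # rule 2: pytree_node=False
            or hints.get(f.name) in _ASSUMED_NON_JAX_TYPES    # rule 3: non-JAX type hint
        )
        (meta if is_meta else data).append(f.name)
    return data, meta


@dataclasses.dataclass
class Layer:
    kernel: object
    bias: object
    name: str = "dense"
    use_bias: bool = True
    step: int = dataclasses.field(default=0, metadata={"pytree_node": False})
    scale: float = 1.0


data, meta = split_fields(Layer, meta_fields=("scale",))
print(data)  # ['kernel', 'bias']
print(meta)  # ['name', 'use_bias', 'step', 'scale']
```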
- Parameters
cls – The class to be decorated.
meta_fields – A tuple of field names to always treat as metadata.
json_serializable – If True (default), adds to_dict, from_dict, to_json, and from_json methods to the class.
frozen – If True, makes the dataclass frozen (immutable). Defaults to False.
- Returns
The decorated class, registered as a PyTree.
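The json_serializable option is described only by the method names it adds. A minimal sketch of what such methods could look like follows; add_json_methods is hypothetical, assumes every field holds a plain JSON-compatible value, and is not the library's implementation.

```python
import dataclasses
import json
from typing import Any, Dict, Type, TypeVar

T = TypeVar("T")


def add_json_methods(cls: Type[T]) -> Type[T]:
    """Hypothetical stand-in for the methods json_serializable=True adds."""

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(klass, data: Dict[str, Any]):
        return klass(**data)

    def to_json(self, **kwargs) -> str:
        # Extra keyword arguments (e.g. indent=2) are passed to json.dumps.
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(klass, text: str):
        return klass.from_dict(json.loads(text))

    cls.to_dict = to_dict
    cls.from_dict = from_dict
    cls.to_json = to_json
    cls.from_json = from_json
    return cls


@add_json_methods
@dataclasses.dataclass
class TrainState:
    step: int
    lr: float


s = TrainState(step=3, lr=1e-3)
print(s.to_json())                     # {"step": 3, "lr": 0.001}
print(TrainState.from_json(s.to_json()) == s)  # True
```

A JSON round trip turns tuples into lists, so fields such as axis tuples would need extra handling in anything beyond this sketch.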