easydel.infra.base_config

Base configuration classes for EasyDeL models.

This module provides the foundational configuration system for all EasyDeL models, extending HuggingFace’s PretrainedConfig with EasyDeL-specific features like attention mechanisms, quantization, gradient checkpointing, and hardware optimization.

Classes:

  • EasyDeLBaseConfig: Main configuration class with all EasyDeL features

  • EasyDeLBaseConfigDict: Simplified dictionary-based configuration

Key Features:
  • Multiple attention mechanism support (flash, ring, etc.)

  • Quantization configuration

  • Gradient checkpointing policies

  • Hardware abstraction and optimization

  • RoPE (Rotary Position Embedding) configuration

  • Custom kernel support

Example

>>> from easydel.infra import EasyDeLBaseConfig
>>> from easydel.infra.etils import EasyDeLGradientCheckPointers
>>> config = EasyDeLBaseConfig(
...     hidden_size=768,
...     num_attention_heads=12,
...     attn_mechanism="flash_attn2",
...     gradient_checkpointing=EasyDeLGradientCheckPointers.NONE,
...     hardware_abstraction=True,
... )

class easydel.infra.base_config.EasyDeLBaseConfig(sharding_axis_dims: tp.Sequence[int] = (1, -1, 1, 1, 1), sharding_dcn_axis_dims: tp.Sequence[int] | None = None, sharding_axis_names: tp.Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp'), attn_mechanism: AVAILABLE_ATTENTION_MECHANISMS = 'vanilla', decode_attn_mechanism: AVAILABLE_ATTENTION_MECHANISMS = None, blocksize_k: int = 128, blocksize_q: int = 128, blocksize_b: int = 1, moe_tiling_size_batch: int = 4, moe_tiling_size_seqlen: int = 128, moe_tiling_size_dim: int = 128, partition_axis: PartitionAxis = PartitionAxis(data_parallel_axis='dp', fully_sharded_data_parallel_axis='fsdp', tensor_parallel_axis='tp', sequence_parallel_axis='sp', expert_parallel_axis='ep', batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', kv_head_axis='tp', key_sequence_axis='sp', hidden_state_axis='tp', mlp_intermediate_axis='tp', vocab_axis='tp', expert_axis='ep', expert_gate_axis=None, attention_dim_axis=None, attention_kv_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, decode_batch_axis=('fsdp', 'dp'), decode_query_sequence_axis=None, decode_head_axis='tp', decode_kv_head_axis='tp', decode_key_sequence_axis='sp', decode_attention_dim_axis=None, decode_attention_kv_dim_axis=None), use_sharded_kv_caching: bool = False, use_sharding_constraint: bool = False, backend: EasyDeLBackends | None = None, platform: EasyDeLPlatforms | None = None, easy_method: tp.Literal['train', 'serve', 'convert'] = 'train', bits: int | None = None, scan_ring_attention: bool = True, scan_attention_layers: bool = False, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, sequence_axis_name: str = 'sp', gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, gradient_checkpointing_targets: list[AVAILABLE_GRADIENT_CHECKPOINT_TARGETS] | None = None, precompute_masks: bool = True, kv_cache_quantization_config: EasyDeLQuantizationConfig | None = None, quantization_config: EasyDeLQuantizationConfig | None = None, kv_cache_sharding_sequence_axis_name: str | tuple[str, ...] = 'sp', flash_attention_backward_pass_impl: tp.Literal['triton', 'xla'] = 'triton', attn_dtype: jnp.dtype = jnp.bfloat16, kvdtype: jnp.dtype | None = None, attn_softmax_dtype: jnp.dtype = jnp.float32, fcm_max_ratio: float = 0.0, fcm_min_ratio: float = 0.0, hardware_abstraction: bool = False, pallas_m_block_size: int = 128, pallas_k_block_size: int = 128, pallas_n_block_size: int = 128, moe_method: AVAILABLE_MOE_METHODS = 'fused_moe', moe_force_xla_gmm: bool = False, use_expert_tensor_mode: bool = False, use_ring_of_experts: bool = False, fsdp_is_ep_bound: bool = True, sp_is_ep_bound: bool = True, operation_configs: dict[str, BaseOperationConfig] | None = None, **kwargs)

Bases: PretrainedConfig

Base configuration shared across EasyDeL models.

Extends transformers.PretrainedConfig with distributed sharding metadata, attention kernel selection, quantization knobs, RoPE helpers, and hardware abstraction flags used for both training and serving.

Parameters
  • sharding_axis_dims – Parallelism sizes for (dp, fsdp, ep, tp, sp). -1 consumes all remaining devices. Defaults to (1, -1, 1, 1, 1).

  • sharding_dcn_axis_dims – Optional mesh sizes for DCN slices when running multi-host or multi-slice setups.

  • sharding_axis_names – Logical mesh axis names, defaults to ("dp", "fsdp", "ep", "tp", "sp").

  • attn_mechanism – Attention implementation to use during training/forward passes.

  • decode_attn_mechanism – Attention implementation to use during decoding (falls back to attn_mechanism if left as None).

  • blocksize_k – Key block size for attention kernels. Defaults to 128.

  • blocksize_q – Query block size for attention kernels. Defaults to 128.

  • blocksize_b – Batch/block size used by some attention backends. Defaults to 1.

  • moe_tiling_size_batch – Batch tiling used by MoE kernels. Defaults to 4.

  • moe_tiling_size_seqlen – Sequence length tiling for MoE kernels. Defaults to 128.

  • moe_tiling_size_dim – Hidden dimension tiling for MoE kernels. Defaults to 128.

  • partition_axis – PartitionAxis describing how logical axes map to the mesh.

  • use_sharded_kv_caching – Whether to shard KV cache placement instead of replicating.

  • use_sharding_constraint – Insert explicit sharding constraints during model build.

  • backend – Explicit JAX backend (falls back to jax.default_backend()).

  • platform – Platform hint for kernel selection (defaults to "triton" on GPU, otherwise "jax").

  • easy_method – Workflow context ("train", "serve", or "convert").

  • bits – Optional quantization bit width for weights.

  • scan_ring_attention – Use scanning for ring attention implementations.

  • scan_attention_layers – Apply scan to attention blocks to save memory.

  • use_scan_mlp – Apply scan to MLP blocks.

  • scan_mlp_chunk_size – Chunk size when scanning MLPs. Defaults to 1024.

  • sequence_axis_name – Name of the sequence/attention axis. Defaults to "sp".

  • gradient_checkpointing – Gradient checkpointing policy enum/string.

  • gradient_checkpointing_targets – Optional list of target names to include or exclude when using selective checkpointing policies.

  • precompute_masks – Whether to precompute and cache causal masks on the mesh.

  • kv_cache_quantization_config – Quantization config for KV cache tensors. Pass None to disable.

  • quantization_config – Quantization config for linear layers. Pass None to disable.

  • kv_cache_sharding_sequence_axis_name – Axis (or axes) used when sharding the KV cache.

  • flash_attention_backward_pass_impl – Backward kernel for flash attention ("triton" or "xla"). Defaults to "triton".

  • attn_dtype – Attention activation dtype. Defaults to jnp.bfloat16.

  • kvdtype – KV cache dtype. Defaults to attn_dtype when None.

  • attn_softmax_dtype – Softmax computation dtype. Defaults to jnp.float32.

  • fcm_max_ratio – Maximum ratio used when sampling forgetful causal masks.

  • fcm_min_ratio – Minimum ratio used when sampling forgetful causal masks.

  • hardware_abstraction – Enable EasyDeL hardware abstraction and custom kernels.

  • pallas_m_block_size – Matmul M dimension block size for Pallas kernels.

  • pallas_k_block_size – Matmul K dimension block size for Pallas kernels.

  • pallas_n_block_size – Matmul N dimension block size for Pallas kernels.

  • moe_method – Mixture-of-experts implementation to use.

  • moe_force_xla_gmm – Force XLA GMM kernels for MoE even when fused kernels exist.

  • use_ring_of_experts – Whether to dispatch experts with a ring topology.

  • use_expert_tensor_mode – Treat experts as an additional tensor-parallel axis.

  • fsdp_is_ep_bound – Fold the FSDP axis into the expert axis when building expert meshes.

  • sp_is_ep_bound – Fold the sequence-parallel axis into the expert axis when building expert meshes.

  • **kwargs – Forwarded to PretrainedConfig.

Raises

UserWarning – If KV-cache quantization is requested together with sharded KV caching.
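
The fallback rules listed above (decode_attn_mechanism defaulting to attn_mechanism, kvdtype defaulting to attn_dtype) and the KV-cache warning can be modeled with a small dataclass. This is a minimal sketch, not EasyDeL's code: AttnSettingsSketch is a hypothetical stand-in, dtypes are plain strings to keep it dependency-free, and it assumes the check runs at construction time.

```python
import warnings
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class AttnSettingsSketch:
    """Hypothetical stand-in for a few EasyDeLBaseConfig fields."""

    attn_mechanism: str = "vanilla"
    decode_attn_mechanism: Optional[str] = None
    attn_dtype: str = "bfloat16"  # strings stand in for jnp dtypes
    kvdtype: Optional[str] = None
    use_sharded_kv_caching: bool = False
    kv_cache_quantization_config: Optional[Any] = None

    def __post_init__(self) -> None:
        # decode_attn_mechanism falls back to attn_mechanism when left as None.
        if self.decode_attn_mechanism is None:
            self.decode_attn_mechanism = self.attn_mechanism
        # kvdtype falls back to attn_dtype when left as None.
        if self.kvdtype is None:
            self.kvdtype = self.attn_dtype
        # Warn (rather than fail) when KV-cache quantization meets sharded KV caching.
        if self.kv_cache_quantization_config is not None and self.use_sharded_kv_caching:
            warnings.warn(
                "KV-cache quantization requested together with sharded KV caching",
                UserWarning,
            )
```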

add_basic_configurations(sharding_axis_dims: Sequence[int] = NOT_GIVEN, sharding_dcn_axis_dims: Optional[Sequence[int]] = NOT_GIVEN, sharding_axis_names: Sequence[str] = NOT_GIVEN, attn_mechanism: Literal['auto', 'vanilla', 'flash_attn2', 'blocksparse', 'ring', 'cudnn', 'blockwise', 'sdpa', 'autoregressive_decodeattn', 'ragged_page_attention_v2', 'ragged_page_attention_v3', 'page_attention'] = NOT_GIVEN, decode_attn_mechanism: Literal['auto', 'vanilla', 'flash_attn2', 'blocksparse', 'ring', 'cudnn', 'blockwise', 'sdpa', 'autoregressive_decodeattn', 'ragged_page_attention_v2', 'ragged_page_attention_v3', 'page_attention'] = NOT_GIVEN, blocksize_k: int = NOT_GIVEN, blocksize_q: int = NOT_GIVEN, blocksize_b: int = NOT_GIVEN, moe_tiling_size_batch: int = NOT_GIVEN, moe_tiling_size_seqlen: int = NOT_GIVEN, moe_tiling_size_dim: int = NOT_GIVEN, partition_axis: PartitionAxis = NOT_GIVEN, use_sharded_kv_caching: bool = NOT_GIVEN, backend: easydel.infra.etils.EasyDeLBackends | None = NOT_GIVEN, platform: easydel.infra.etils.EasyDeLPlatforms | None = NOT_GIVEN, easy_method: Literal['train', 'serve', 'convert'] = NOT_GIVEN, bits: int | None = NOT_GIVEN, scan_ring_attention: bool = NOT_GIVEN, scan_attention_layers: bool = NOT_GIVEN, use_sharding_constraint: bool = NOT_GIVEN, use_scan_mlp: bool = NOT_GIVEN, scan_mlp_chunk_size: int = NOT_GIVEN, sequence_axis_name: str = NOT_GIVEN, gradient_checkpointing: EasyDeLGradientCheckPointers = NOT_GIVEN, gradient_checkpointing_targets: list[Literal['attn_dense', 'attn_key', 'attn_key_value', 'attn_output', 'attn_qkv', 'attn_query', 'attn_receptance', 'attn_value', 'attn_weights', 'embeddings', 'layer_output', 'lm_head_output', 'mlp_down', 'mlp_gate', 'mlp_output', 'mlp_up', 'model_output', 'moe_expert_output', 'moe_gate_logits', 'moe_output', 'moe_router_logits', 'normed_input', 'residual']] | None = NOT_GIVEN, precompute_masks: bool = NOT_GIVEN, kv_cache_quantization_config: easydel.layers.quantization.quantizers.EasyDeLQuantizationConfig | None = NOT_GIVEN, quantization_config: easydel.layers.quantization.quantizers.EasyDeLQuantizationConfig | None = NOT_GIVEN, kv_cache_sharding_sequence_axis_name: str | tuple[str, ...] = NOT_GIVEN, flash_attention_backward_pass_impl: Literal['triton', 'xla'] = NOT_GIVEN, attn_dtype: numpy.dtype = NOT_GIVEN, kvdtype: numpy.dtype | None = NOT_GIVEN, attn_softmax_dtype: numpy.dtype = NOT_GIVEN, hardware_abstraction: bool = NOT_GIVEN, pallas_m_block_size: int = NOT_GIVEN, pallas_k_block_size: int = NOT_GIVEN, pallas_n_block_size: int = NOT_GIVEN, moe_method: Literal['fused_moe', 'standard_moe', 'dense_moe'] = NOT_GIVEN, moe_force_xla_gmm: bool = NOT_GIVEN, use_ring_of_experts: bool = NOT_GIVEN, use_expert_tensor_mode: bool = NOT_GIVEN, fsdp_is_ep_bound: bool = NOT_GIVEN, sp_is_ep_bound: bool = NOT_GIVEN, **kwargs)

Populate baseline EasyDeL attributes on an existing config instance.

Each argument mirrors the constructor but is optional: passing NOT_GIVEN leaves any existing attribute untouched, while a provided value overwrites the current setting. If an attribute is missing entirely, a sensible default is applied via set_attrs_smartly. This helper is used by derived configs (and their sub_configs) to keep sharding/attention/quantization knobs in sync without re-implementing initialization logic.
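
The update rule described above can be sketched in plain Python. The sentinel NOT_GIVEN and the helper set_attrs_smartly_sketch here are hypothetical re-creations written for illustration; EasyDeL's actual set_attrs_smartly signature is not documented on this page and may differ.

```python
class _Empty:
    """Hypothetical sentinel type meaning 'argument not supplied'."""

    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN = _Empty()


def set_attrs_smartly_sketch(obj, name, default, value=NOT_GIVEN):
    """Hypothetical re-creation of the behaviour described for set_attrs_smartly."""
    if value is not NOT_GIVEN:
        # An explicit value (even None) always overwrites the current setting.
        setattr(obj, name, value)
    elif not hasattr(obj, name):
        # Attribute missing and no explicit value: apply the default.
        setattr(obj, name, default)
    # Otherwise the existing attribute is left untouched.


class Cfg:
    """Empty object standing in for a config instance."""
```

Under these assumptions, calling the helper again with NOT_GIVEN never clobbers a value set earlier, which matches the "leaves any existing attribute untouched" rule.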

Parameters
  • sharding_axis_dims – Fallback mesh sizes for (dp, fsdp, ep, tp, sp), defaulting to (1, -1, 1, 1, 1).

  • sharding_dcn_axis_dims – Optional DCN mesh sizes (default None).

  • sharding_axis_names – Mesh axis labels, default ("dp", "fsdp", "ep", "tp", "sp").

  • attn_mechanism – Attention mechanism to use (default "vanilla").

  • decode_attn_mechanism – Optional decode-time attention mechanism.

  • blocksize_k – Attention key block size, default 512 when unset.

  • blocksize_q – Attention query block size, default 512 when unset.

  • blocksize_b – Batch/block size used by attention kernels (default 1).

  • moe_tiling_size_batch – Batch tiling for MoE kernels (default 4).

  • moe_tiling_size_seqlen – Sequence tiling for MoE kernels (default 128).

  • moe_tiling_size_dim – Hidden-dim tiling for MoE kernels (default 128).

  • partition_axis – PartitionAxis describing logical mesh layout (default PartitionAxis()).

  • use_sharded_kv_caching – Whether to shard KV caches (default False).

  • backend – Backend string, default None (falls back to JAX default).

  • platform – Platform hint, default "jax".

  • easy_method – EasyDeL execution mode, default EasyMethod.TRAIN.

  • bits – Optional quantization bit width, default None.

  • scan_ring_attention – Enable scan for ring attention (default True).

  • scan_attention_layers – Enable scan for attention blocks (default True).

  • use_sharding_constraint – Insert sharding constraints (default False).

  • use_scan_mlp – Enable scan for MLPs (default False).

  • scan_mlp_chunk_size – Chunk size for scanned MLPs (default 1024).

  • sequence_axis_name – Label for the sequence/attention axis (default "sp").

  • gradient_checkpointing – Gradient checkpointing policy (default EasyDeLGradientCheckPointers.NONE).

  • gradient_checkpointing_targets – Optional list of checkpoint targets to include/exclude (default None).

  • precompute_masks – Whether to precompute and cache masks (default True).

  • kv_cache_quantization_config – KV cache quantization config (default None = no quantization).

  • quantization_config – Linear-layer quantization config (default None = no quantization).

  • kv_cache_sharding_sequence_axis_name – Axis name(s) for KV cache sharding (default "sp").

  • flash_attention_backward_pass_impl – Backward kernel for flash attention (default "triton").

  • attn_dtype – Attention activation dtype (default jnp.float32).

  • kvdtype – KV cache dtype (defaults to attn_dtype when unset).

  • attn_softmax_dtype – Softmax computation dtype (default jnp.float32).

  • hardware_abstraction – Toggle EasyDeL hardware abstraction (default DEFAULT_HARDWARE_ABSTRACTION).

  • pallas_m_block_size – Pallas matmul M block size (default DEFAULT_PALLAS_M_BLOCK_SIZE).

  • pallas_k_block_size – Pallas matmul K block size (default DEFAULT_PALLAS_K_BLOCK_SIZE).

  • pallas_n_block_size – Pallas matmul N block size (default DEFAULT_PALLAS_N_BLOCK_SIZE).

  • moe_method – MoE implementation to use (default DEFAULT_MOE_METHOD).

  • moe_force_xla_gmm – Force XLA GMM kernels for MoE (default False).

  • use_ring_of_experts – Dispatch experts with a ring topology (default RING_EXPERTS).

  • use_expert_tensor_mode – Treat experts as a tensor-parallel axis (default EXPERT_TP_MODE).

  • fsdp_is_ep_bound – Fold FSDP into the expert axis when building expert meshes.

  • sp_is_ep_bound – Fold sequence-parallel into the expert axis when building expert meshes.

  • **kwargs – Extra attributes to attach to this config and any defined sub_configs.

attach_custom_arguments(**kwargs)[source]#

Attaches custom arguments as attributes to the configuration.

Parameters

**kwargs – Arbitrary key-value pairs to attach as attributes.

property auto_expert_mesh: Mesh

Get the mesh for expert parallelism with automatic axis types.

Similar to expert_mesh, but uses jax.sharding.AxisType.Auto for all axes, allowing JAX to automatically determine the optimal sharding strategy based on the computation graph.

Returns

A mesh with auto axis types configured for expert parallelism with (dp, ep, tp) axis ordering.

Return type

jax.sharding.Mesh

static create_mesh(sharding_axis_dims: Sequence[int] = (1, -1, 1, 1, 1), sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp'), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, process_is_granule: bool = False, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = True, backend: str | None = None)

Creates a JAX device mesh for distributed model execution.

This function constructs a multi-dimensional mesh of devices that defines how model parameters and computations are distributed across hardware. The mesh axes correspond to different parallelism strategies:

  • dp (data parallel): Replicate model across data batches

  • fsdp (fully sharded data parallel): Shard parameters and optimizer states

  • ep (expert parallel): Distribute experts in MoE models

  • tp (tensor parallel): Partition individual weight matrices

  • sp (sequence parallel): Split sequence dimension across devices

Parameters
  • sharding_axis_dims – Size of each parallelism dimension. Use -1 to automatically fill remaining devices. Default: (1, -1, 1, 1, 1) means all devices go to FSDP axis.

  • sharding_axis_names – Names for each mesh axis. Must match the length of sharding_axis_dims. Default: ("dp", "fsdp", "ep", "tp", "sp").

  • sharding_dcn_axis_dims – Dimensions for data center network (DCN) sharding. Used for multi-host/multi-slice setups. Default: None.

  • process_is_granule – Deprecated parameter, not used.

  • should_sort_granules_by_key – Whether to sort device granules by key for deterministic ordering. Default: True.

  • allow_split_physical_axes – Whether to allow splitting physical device axes when mapping to logical mesh axes. Default: True.

  • backend – Backend platform to create the mesh for ('gpu', 'tpu', etc.). If None or an empty string, the default backend is used.

Returns

A JAX Mesh object configured for distributed execution with the specified parallelism dimensions.

Example

>>> # Create mesh with DP=4, TP=2 on 8 GPUs
>>> mesh = EasyDeLBaseConfig.create_mesh(
...     sharding_axis_dims=(4, 1, 1, 2, 1),
...     sharding_axis_names=("dp", "fsdp", "ep", "tp", "sp"),
... )
>>> # mesh.shape = {'dp': 4, 'fsdp': 1, 'ep': 1, 'tp': 2, 'sp': 1}
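
How a -1 entry in sharding_axis_dims expands to the remaining devices can be illustrated with a dependency-free helper. resolve_axis_dims is hypothetical (it is not part of EasyDeL or JAX), and it assumes at most one axis uses -1 and that the product of all axis sizes must equal the device count.

```python
import math
from typing import Sequence, Tuple


def resolve_axis_dims(dims: Sequence[int], num_devices: int) -> Tuple[int, ...]:
    """Hypothetical helper: replace a single -1 entry with the remaining device count."""
    if list(dims).count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    # Product of the explicitly sized axes.
    fixed = math.prod(d for d in dims if d != -1)
    if fixed <= 0 or num_devices % fixed != 0:
        raise ValueError(f"{num_devices} devices cannot be split over fixed axes of size {fixed}")
    fill = num_devices // fixed
    resolved = tuple(fill if d == -1 else d for d in dims)
    # Without a -1, the explicit sizes must already use every device.
    if math.prod(resolved) != num_devices:
        raise ValueError(f"mesh {resolved} does not use exactly {num_devices} devices")
    return resolved
```

For example, the default (1, -1, 1, 1, 1) on 8 devices resolves to (1, 8, 1, 1, 1), placing every device on the FSDP axis.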
property expert_abstract_mesh: AbstractMesh

Get the abstract mesh descriptor for expert parallelism.

Returns an abstract mesh that matches the expert_mesh axis sizes and names. Abstract meshes are lightweight representations used for sharding specification without device assignment.

Returns

An abstract mesh descriptor with the same axis configuration as expert_mesh.

Return type

jax.sharding.AbstractMesh

property expert_mesh: Mesh

Get the mesh configuration for expert parallelism.

Creates a mesh with expert-parallel axes folded according to the fsdp_is_ep_bound and sp_is_ep_bound configuration flags. This mesh is used for MoE (Mixture of Experts) models to distribute experts across devices with explicit axis types.

Returns

A mesh with explicit axis types configured for expert parallelism with (dp, ep, tp) axis ordering.

Return type

jax.sharding.Mesh
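
The folding flags can be read as multiplying axis sizes into the expert axis. The sketch below is an assumption-based illustration, not EasyDeL's implementation: expert_axis_size is a hypothetical helper, it treats folding as a product of mesh sizes, and it deliberately does not model where axes that are not bound to ep end up in the (dp, ep, tp) mesh.

```python
from typing import Dict


def expert_axis_size(dims: Dict[str, int], fsdp_is_ep_bound: bool, sp_is_ep_bound: bool) -> int:
    """Hypothetical: size of the folded expert axis, given resolved mesh sizes."""
    size = dims["ep"]
    if fsdp_is_ep_bound:
        size *= dims["fsdp"]  # FSDP devices also host experts
    if sp_is_ep_bound:
        size *= dims["sp"]  # sequence-parallel devices also host experts
    return size
```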

classmethod from_pretrained(pretrained_model_name_or_path: str | os.PathLike, cache_dir: str | os.PathLike | None = None, force_download: bool = False, local_files_only: bool = False, token: str | bool | None = None, revision: str = 'main', **kwargs) -> PretrainedConfig

Instantiate a PretrainedConfig (or a derived class) from a pretrained model configuration.

Parameters
  • pretrained_model_name_or_path (str or os.PathLike) –

    This can be either:

    • a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.

    • a path to a directory containing a configuration file saved using the PretrainedConfig.save_pretrained method, e.g., ./my_model_directory/.

    • a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.

  • cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

  • force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.

  • resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.

  • proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

  • token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).

  • revision (str, optional, defaults to "main") –

    The specific model version to use. It can be a branch name, a tag name, or a commit id: since huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.

    Tip: To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".

  • return_unused_kwargs (bool, optional, defaults to False) –

    If False, then this function returns just the final configuration object.

    If True, then this function returns a tuple (config, unused_kwargs), where unused_kwargs is a dictionary of the key/value pairs whose keys are not configuration attributes, i.e., the part of kwargs that was not used to update config and is otherwise ignored.

  • subfolder (str, optional, defaults to "") – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.

  • kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.

Returns

The configuration object instantiated from this pretrained model.

Return type

PretrainedConfig

Examples:

>>> # The base class *PretrainedConfig* is not instantiated directly here, so these
>>> # examples use a derived class: BertConfig
>>> from transformers import BertConfig
>>> config = BertConfig.from_pretrained(
...   "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...   "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...  "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...  "google-bert/bert-base-uncased",
...  output_attentions=True,
...  foo=False,
...  return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}


get_axis_dims() -> Sequence[int]

Returns the device mesh axis dimensions for parallelism.

Returns

Sequence of integers specifying the size of each parallelism axis. Typically (dp_size, fsdp_size, ep_size, tp_size, sp_size). Value of -1 means “use all remaining devices for this axis”.

Example

>>> config.sharding_axis_dims = (2, 4, 1, 1, 1)
>>> dims = config.get_axis_dims()
>>> # dims == (2, 4, 1, 1, 1): 2-way data parallel, 4-way FSDP, size 1 on ep, tp, and sp

get_axis_names() -> Sequence[str]

Returns the logical names for each device mesh axis.

Returns

Sequence of strings naming each parallelism axis. Typically ("dp", "fsdp", "ep", "tp", "sp") for data parallel, fully sharded data parallel, expert parallel, tensor parallel, and sequence parallel respectively.

Example

>>> names = config.get_axis_names()
>>> # names = ('dp', 'fsdp', 'ep', 'tp', 'sp')

get_backend() -> str

Returns the JAX backend platform being used.

Retrieves the configured backend (e.g., 'gpu', 'tpu', 'cpu'), or falls back to the default JAX backend if not explicitly set.

Returns

Backend platform string. Common values: 'gpu', 'tpu', 'cpu'.

Return type

str

Example

>>> config = EasyDeLBaseConfig(backend='gpu')
>>> config.get_backend()
'gpu'
>>>
>>> # With no backend set, the JAX default backend is returned
>>> config = EasyDeLBaseConfig()
>>> config.get_backend()  # Might return 'gpu', 'tpu', etc.

get_basic_causal_mask(*args, **kwargs)

Gets or creates the basic causal attention mask.

Creates a causal mask for the maximum position embeddings and places it on the appropriate device with sharding.

Returns

ModuleCaches containing the causal mask, or False if masks are not precomputed.
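
The causal pattern itself is a lower-triangular boolean matrix. The helper below is a hypothetical, list-based sketch for intuition only; the real method builds an array for the maximum position embeddings and places it on the mesh with sharding, which this sketch does not attempt.

```python
from typing import List


def basic_causal_mask(max_positions: int) -> List[List[bool]]:
    """Hypothetical sketch: query position q may attend to key positions 0..q."""
    return [[k <= q for k in range(max_positions)] for q in range(max_positions)]
```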

get_basic_frequencies(head_size: int | None = None, rotary_dim: int | None = None, base: float | None = None) -> ModuleCaches

Compute frequencies for rotary embeddings placed on the configured mesh.

Parameters
  • head_size – Attention head size (defaults to self.head_dim).

  • rotary_dim – Number of rotary dimensions (defaults to head_size).

  • base – Optional base for frequency computation (defaults to self.rope_theta).

Returns

ModuleCaches containing the frequencies sharded with NamedSharding.
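
Assuming the standard RoPE formulation, the frequencies come from an angle table angle[p][i] = p * base ** (-2i / rotary_dim), whose cosines and sines rotate query/key pairs. The pure-Python helper rope_cos_sin is hypothetical and ignores sharding and whatever exact layout EasyDeL stores in ModuleCaches.

```python
import math
from typing import List, Tuple


def rope_cos_sin(
    max_positions: int, rotary_dim: int, base: float = 10000.0
) -> Tuple[List[List[float]], List[List[float]]]:
    """Hypothetical sketch of a standard RoPE cos/sin table of shape (max_positions, rotary_dim // 2)."""
    # One inverse frequency per rotated pair of dimensions.
    inv_freq = [base ** (-(2 * i) / rotary_dim) for i in range(rotary_dim // 2)]
    # Rotation angle for every (position, pair) combination.
    angles = [[p * f for f in inv_freq] for p in range(max_positions)]
    cos = [[math.cos(a) for a in row] for row in angles]
    sin = [[math.sin(a) for a in row] for row in angles]
    return cos, sin
```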

get_basic_inv_frequencies(head_size: int | None = None, rotary_dim: int | None = None, base: float | None = None, partial_rotary_factor: float = 1.0) -> ModuleCaches

Compute inverse frequencies for rotary embeddings.

Parameters
  • head_size – Attention head size (defaults to self.head_dim).

  • rotary_dim – Number of rotary dimensions (defaults to head_size).

  • base – Optional base for frequency computation (defaults to self.rope_theta).

  • partial_rotary_factor – Ratio of the head dimension to apply RoPE to.

Returns

ModuleCaches wrapping the computed frequency tensor.
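
partial_rotary_factor shrinks the number of rotated dimensions, which changes both the length and the spacing of the inverse frequencies. The sketch below assumes the common convention rotary_dim = int(head_size * partial_rotary_factor) and the standard 1 / base ** (2j / rotary_dim) schedule; rope_inv_frequencies is a hypothetical helper, not EasyDeL's method.

```python
from typing import List


def rope_inv_frequencies(
    head_size: int, base: float = 10000.0, partial_rotary_factor: float = 1.0
) -> List[float]:
    """Hypothetical sketch: inverse frequencies for the rotated slice of each head."""
    rotary_dim = int(head_size * partial_rotary_factor)  # only this many dims are rotated
    # Even indices i = 2j give exponents 2j / rotary_dim.
    return [1.0 / (base ** (i / rotary_dim)) for i in range(0, rotary_dim, 2)]
```

With head_size=128, a factor of 0.5 yields 32 frequencies instead of 64, and each one after the first decays faster because the exponent is divided by the smaller rotary_dim.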

get_basic_rope(dtype: Union[Array, ndarray, bool, number], head_size: int, rotary_dim: int | None = None, is_neox_style: bool = True, base: float | None = None)

Return a rotary position embedding function configured for this model.

Parameters
  • dtype – Target dtype for the generated embeddings.

  • head_size – Attention head size used to derive the rotary dimension.

  • rotary_dim – Number of rotary dimensions (defaults to head_size).

  • is_neox_style – Whether to generate NeoX-style rotary embeddings.

  • base – Optional base used for frequency computation (defaults to self.rope_theta).

Returns

Callable from get_rope ready to be applied to query/key tensors.
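
The is_neox_style flag selects which dimensions are paired for rotation. In the common convention (assumed here, as used by vLLM-style get_rope helpers), NeoX style pairs dimension i with i + d/2, while GPT-J style pairs adjacent dimensions 2i and 2i + 1. apply_rope is a hypothetical single-vector sketch.

```python
import math
from typing import List


def apply_rope(
    x: List[float], position: int, base: float = 10000.0, is_neox_style: bool = True
) -> List[float]:
    """Hypothetical sketch: rotate one head vector x (even length) to `position`."""
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        theta = position * base ** (-(2 * i) / d)
        c, s = math.cos(theta), math.sin(theta)
        # NeoX style pairs dim i with dim i + d/2; GPT-J style pairs 2i with 2i + 1.
        a, b = (i, i + half) if is_neox_style else (2 * i, 2 * i + 1)
        out[a] = x[a] * c - x[b] * s
        out[b] = x[b] * c + x[a] * s
    return out
```

Because each pair is rotated by an orthogonal 2x2 matrix, either style preserves the vector's norm; only the pairing of dimensions differs.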

get_fcm_mask(batch_size, seq_length, deterministic: bool)

Generates a Forgetful Causal Mask (FCM) for training.

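FCM randomly masks out a fraction of past key positions during training, on top of the standard causal mask, to improve model robustness; the fraction is sampled between fcm_min_ratio and fcm_max_ratio. It is only applied in non-deterministic mode.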

Parameters
  • batch_size – Number of sequences in the batch.

  • seq_length – Length of each sequence.

  • deterministic – If True, returns None (no FCM applied).

Returns

Boolean mask array or None if deterministic or FCM not configured.
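
One common formulation of FCM (assumed here, not confirmed from EasyDeL's source) samples a drop ratio per sequence between fcm_min_ratio and fcm_max_ratio, forgets that fraction of key positions, and always keeps position 0; the result is then combined (logical AND) with the causal mask. fcm_mask_sketch is a hypothetical list-based helper; the real method returns an array whose shape is not specified on this page.

```python
import random
from typing import List, Optional


def fcm_mask_sketch(
    batch_size: int,
    seq_length: int,
    deterministic: bool,
    fcm_min_ratio: float = 0.0,
    fcm_max_ratio: float = 0.0,
    rng: Optional[random.Random] = None,
) -> Optional[List[List[bool]]]:
    """Hypothetical per-sequence key mask: True = may be attended, False = forgotten."""
    if deterministic or fcm_max_ratio <= 0.0:
        return None  # no FCM at eval time or when it is not configured
    rng = rng or random.Random()
    mask = []
    for _ in range(batch_size):
        ratio = rng.uniform(fcm_min_ratio, fcm_max_ratio)  # one drop ratio per sequence
        row = [rng.random() > ratio for _ in range(seq_length)]
        row[0] = True  # keep the first key so every causal query sees at least one key
        mask.append(row)
    return mask
```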

get_mask_details() -> dict[int, AttnMaskDetail] | None

Get attention mask details for each layer.

Retrieves layer-specific attention mask configurations, which is particularly useful for models with heterogeneous attention patterns (e.g., models using different attention types per layer like sliding window attention in some layers and full attention in others).

Returns

A dictionary mapping layer indices to their corresponding AttnMaskDetail configurations, or None if the model doesn't define layer-specific mask types.

Return type

dict[int, AttnMaskDetail] | None

get_partition_rules(*args, **kwargs)

Gets the parameter sharding partition rules for the model.

Partition rules define how model parameters should be sharded across the device mesh. Each rule maps a parameter name pattern (regex) to a PartitionSpec that specifies which mesh axes the parameter dimensions should be distributed across.

This method must be implemented by model-specific configuration classes.

Parameters
  • *args – Positional arguments (model-specific).

  • **kwargs – Keyword arguments (model-specific).

Returns

Tuple of (pattern, PartitionSpec) pairs defining how to shard parameters. For example: (("model/embed.*", PartitionSpec("tp", None)), ("model/layers/\d+/attn/.*", PartitionSpec(None, "tp")))

Raises

NotImplementedError – This base class does not provide default partition rules. Subclasses must implement this method.

Example

>>> from jax.sharding import PartitionSpec
>>> class MyModelConfig(EasyDeLBaseConfig):
...     def get_partition_rules(self, *args, **kwargs):
...         return (
...             ("embed.*", PartitionSpec("tp", None)),
...             ("attn.*", PartitionSpec(None, "tp", None)),
...             ("mlp.*", PartitionSpec(None, "tp")),
...         )
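
The rules returned by get_partition_rules are regex/spec pairs. A minimal sketch of how such rules might be applied, assuming first-match-wins with re.search semantics (EasyDeL's actual matcher may differ): plain tuples stand in for PartitionSpec, and match_rule and RULES are hypothetical. A trailing catch-all keeps unmatched parameters replicated.

```python
import re
from typing import Optional, Sequence, Tuple

# Hypothetical rule type: (regex, spec), with a plain tuple standing in for PartitionSpec.
Rule = Tuple[str, Tuple[Optional[str], ...]]

RULES: Sequence[Rule] = (
    (r"embed.*", ("tp", None)),
    (r"layers/\d+/attn/.*", (None, "tp")),
    (r".*", (None, None)),  # catch-all: replicate anything unmatched
)


def match_rule(param_path: str, rules: Sequence[Rule] = RULES) -> Tuple[Optional[str], ...]:
    """Return the spec of the first rule whose regex is found in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")
```

Ordering matters under this assumption: more specific patterns must come before broader ones, or the broader rule will shadow them.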
property granted_freq_max_position_embedding: int

Return the max position embedding allowed for frequency-based caches.

property granted_mask_max_position_embedding: int

Return the max position embedding allowed for mask precomputation.

jax_mesh()

Deprecated method for getting the JAX mesh.

Deprecated:

Use mesh property or get_mesh() method instead.

Returns

JAX device mesh.

property mesh

Gets or creates the JAX device mesh for this configuration.

This property lazily constructs a device mesh from the configuration’s sharding parameters. Once created, the mesh is cached for reuse. The mesh can be explicitly set using set_model_mesh() to override the auto-generated one.

The mesh is constructed from:

  • sharding_axis_dims: Device counts per axis

  • sharding_axis_names: Logical names for each axis

  • sharding_dcn_axis_dims: Multi-host configuration (if applicable)

  • Various granule sorting and axis splitting options

Returns

JAX Mesh object defining the device topology for distributed execution. The mesh axes correspond to parallelism strategies (dp, fsdp, ep, tp, sp).

Note

If a custom mesh was set via set_model_mesh(), that mesh is returned instead of creating a new one.

Example

>>> config = EasyDeLBaseConfig(
...     sharding_axis_dims=(2, 1, 1, 4, 1),
...     sharding_axis_names=("dp", "fsdp", "ep", "tp", "sp"),
... )
>>> mesh = config.mesh
>>> # mesh.shape = {'dp': 2, 'fsdp': 1, 'ep': 1, 'tp': 4, 'sp': 1}
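
The caching behaviour described for the mesh property (build lazily, reuse, let set_model_mesh override) reduces to a simple pattern. LazyMeshSketch is hypothetical and uses a dict in place of a real jax.sharding.Mesh so the sketch runs without devices.

```python
from typing import Any, Callable, Optional


class LazyMeshSketch:
    """Hypothetical model of the mesh property: build once, cache, allow override."""

    def __init__(self, build: Callable[[], Any]) -> None:
        self._build = build
        self._mesh: Optional[Any] = None
        self.build_calls = 0  # counts how often the builder actually ran

    @property
    def mesh(self) -> Any:
        if self._mesh is None:  # first access with no custom mesh: build and cache
            self.build_calls += 1
            self._mesh = self._build()
        return self._mesh

    def set_model_mesh(self, mesh: Any) -> None:
        self._mesh = mesh  # a custom mesh replaces the auto-generated one
```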
property partition_manager: PartitionManager

Gets the partition manager for this configuration.

The PartitionManager handles translation between logical axis names (like ‘dp’, ‘tp’) and their actual configurations in the device mesh. It provides utilities for resolving partition specifications and managing distributed execution.

Returns

PartitionManager instance configured with this config’s partition_axis. If partition_axis is None, creates a default PartitionAxis() first.

Example

>>> config = EasyDeLBaseConfig()
>>> pm = config.partition_manager
>>> # Use partition manager to resolve sharding specs
>>> spec = pm.resolve(axes=["dp", "tp"], mode="train", shape=(8, 1024))

read_basics_from_config(config: EasyDeLBaseConfig)

Reads and applies basic configuration attributes from another config instance.

Copies EasyDeL-specific attributes, such as sharding settings, the attention mechanism, and quantization settings, from the provided config.

Parameters

config – Source configuration to read attributes from.

save_pretrained(save_directory: str | os.PathLike | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath, push_to_hub: bool = False, **kwargs)

Saves the configuration object to the directory save_directory so that it can be reloaded with the PretrainedConfig.from_pretrained class method.

Parameters
  • save_directory (str, os.PathLike, GCSPath, LocalPath, or MLUtilPath) – Directory where the configuration JSON file will be saved (it is created if it does not exist).

  • push_to_hub (bool, optional, defaults to False) – Whether or not to push the configuration to the Hugging Face Hub after saving it. You can specify the repository to push to with repo_id (defaults to the name of save_directory in your namespace).

  • kwargs (Dict[str, Any], optional) – Additional keyword arguments passed along to the PushToHubMixin.push_to_hub method.

set_model_mesh(mesh: Mesh)

Sets a custom mesh for the model, overriding the auto-generated one.

Parameters

mesh – JAX device mesh to use for this model.
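The lazy construction, caching, and override behaviour described for the mesh property and set_model_mesh() follows a common Python pattern. The sketch below shows that pattern in isolation; MeshHolder and _build_mesh are hypothetical names, and a string stands in for a real JAX mesh so the example runs without any devices.

```python
class MeshHolder:
    """Hypothetical stand-in illustrating a lazily built, overridable mesh."""

    def __init__(self, axis_dims):
        self.axis_dims = axis_dims
        self._mesh = None       # cached (or user-supplied) mesh
        self.build_count = 0    # counts how often the mesh is constructed

    def _build_mesh(self):
        # Placeholder for the real device-mesh construction.
        self.build_count += 1
        return f"Mesh{self.axis_dims}"

    @property
    def mesh(self):
        # Build on first access only; later accesses reuse the cached value.
        if self._mesh is None:
            self._mesh = self._build_mesh()
        return self._mesh

    def set_model_mesh(self, mesh):
        # An explicitly provided mesh replaces whatever was cached.
        self._mesh = mesh


holder = MeshHolder((1, -1, 1, 1, 1))
first = holder.mesh
second = holder.mesh
print(first is second, holder.build_count)  # True 1
holder.set_model_mesh("CustomMesh")
print(holder.mesh, holder.build_count)      # CustomMesh 1
```

Because the override simply replaces the cached value, a custom mesh is never rebuilt from the sharding parameters afterwards, which matches the Note on the mesh property.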

to_dict() → dict[str, Any]

Serializes the configuration to a dictionary; attributes of forbidden types are temporarily hidden so that they do not appear in the output.
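One way to serialize while temporarily hiding attributes of forbidden types is a context manager that detaches them, lets serialization run, and restores them even if serialization fails. This is a hedged sketch: the actual set of forbidden types in EasyDeL is not documented on this page, so FORBIDDEN_TYPES, hide_forbidden, Unserializable and Config are all hypothetical.

```python
import json
from contextlib import contextmanager


class Unserializable:
    """Hypothetical example of an attribute type that must not be serialized."""


FORBIDDEN_TYPES = (Unserializable,)  # assumption: the real list is not documented here


@contextmanager
def hide_forbidden(obj, forbidden=FORBIDDEN_TYPES):
    # Collect, then detach, attributes whose values are instances of a forbidden type.
    hidden = {k: v for k, v in vars(obj).items() if isinstance(v, forbidden)}
    for key in hidden:
        delattr(obj, key)
    try:
        yield obj
    finally:
        # Restore the detached attributes, even if serialization raised.
        for key, value in hidden.items():
            setattr(obj, key, value)


class Config:
    def __init__(self):
        self.hidden_size = 768
        self.mesh = Unserializable()


cfg = Config()
with hide_forbidden(cfg):
    as_dict = dict(vars(cfg))
print(json.dumps(as_dict))   # {"hidden_size": 768}
print(hasattr(cfg, "mesh"))  # True
```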

to_diff_dict() → dict[str, Any]

Serializes this configuration to a Python dictionary, removing attributes whose values match the default config attributes for better readability, while always retaining the config attribute from the class.

Returns

Dictionary of all the attributes that make up this configuration instance.

Return type

dict[str, Any]
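The core idea of to_diff_dict() can be sketched as a comparison against the dictionary of a default instance. This minimal version is a simplification under stated assumptions: it ignores the retention of class-level attributes described above and any special handling of nested values, and diff_dict is a hypothetical helper.

```python
from typing import Any


def diff_dict(current: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: keep only entries that are new or differ from the defaults."""
    return {
        key: value
        for key, value in current.items()
        if key not in defaults or defaults[key] != value
    }


defaults = {"hidden_size": 768, "attn_mechanism": "vanilla", "blocksize_q": 128}
current = {"hidden_size": 1024, "attn_mechanism": "vanilla", "blocksize_q": 128, "bits": 8}
print(diff_dict(current, defaults))  # {'hidden_size': 1024, 'bits': 8}
```

The same comparison is what makes use_diff=True in to_json_file() produce a shorter file: only the overridden values are written.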

to_json_file(json_file_path: str | os.PathLike | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath, use_diff: bool = True)

Saves this instance to a JSON file.

Parameters
  • json_file_path (str, os.PathLike, GCSPath, LocalPath, or MLUtilPath) – Path to the JSON file in which this configuration instance’s parameters will be saved.

  • use_diff (bool, optional, defaults to True) – If True, only the difference between this configuration instance and the default PretrainedConfig() is serialized to the JSON file.

class easydel.infra.base_config.EasyDeLBaseConfigDict

Bases: TypedDict

Dictionary form of the base configuration shared across EasyDeL models.

Its keys correspond to the EasyDeLBaseConfig constructor arguments, which extend transformers.PretrainedConfig with distributed sharding metadata, attention kernel selection, quantization knobs, RoPE helpers, and hardware abstraction flags used for both training and serving.

Parameters
  • sharding_axis_dims – Parallelism sizes for (dp, fsdp, ep, tp, sp). A value of -1 consumes all remaining devices. Defaults to (1, -1, 1, 1, 1).

  • sharding_dcn_axis_dims – Optional mesh sizes for DCN slices when running multi-host or multi-slice setups.

  • sharding_axis_names – Logical mesh axis names. Defaults to ("dp", "fsdp", "ep", "tp", "sp").

  • attn_mechanism – Attention implementation to use during training/forward passes.

  • decode_attn_mechanism – Attention implementation to use during decoding (falls back to attn_mechanism if left as None).

  • blocksize_k – Key block size for attention kernels. Defaults to 128.

  • blocksize_q – Query block size for attention kernels. Defaults to 128.

  • blocksize_b – Batch/block size used by some attention backends. Defaults to 1.

  • moe_tiling_size_batch – Batch tiling used by MoE kernels. Defaults to 4.

  • moe_tiling_size_seqlen – Sequence length tiling for MoE kernels. Defaults to 128.

  • moe_tiling_size_dim – Hidden dimension tiling for MoE kernels. Defaults to 128.

  • partition_axis – PartitionAxis describing how logical axes map to the mesh.

  • use_sharded_kv_caching – Whether to shard the KV cache instead of replicating it.

  • use_sharding_constraint – Insert explicit sharding constraints during model build.

  • backend – Explicit JAX backend (falls back to jax.default_backend()).

  • platform – Platform hint for kernel selection (defaults to "triton" on GPU, otherwise "jax").

  • easy_method – Workflow context ("train", "serve", or "convert").

  • bits – Optional quantization bit width for weights.

  • scan_ring_attention – Use scanning for ring attention implementations.

  • scan_attention_layers – Apply scan to attention blocks to save memory.

  • use_scan_mlp – Apply scan to MLP blocks.

  • scan_mlp_chunk_size – Chunk size when scanning MLPs. Defaults to 1024.

  • sequence_axis_name – Name of the sequence/attention axis. Defaults to "sp".

  • gradient_checkpointing – Gradient checkpointing policy enum/string.

  • gradient_checkpointing_targets – Optional list of target names to include or exclude when using selective checkpointing policies.

  • precompute_masks – Whether to precompute and cache causal masks on the mesh.

  • kv_cache_quantization_config – Quantization config for KV cache tensors. Pass None to disable.

  • quantization_config – Quantization config for linear layers. Pass None to disable.

  • kv_cache_sharding_sequence_axis_name – Axis (or axes) used when sharding the KV cache.

  • flash_attention_backward_pass_impl – Backward kernel for flash attention ("triton" or "xla"). Defaults to "triton".

  • attn_dtype – Attention activation dtype. Defaults to jnp.bfloat16.

  • kvdtype – KV cache dtype. Defaults to attn_dtype when None.

  • attn_softmax_dtype – Softmax computation dtype. Defaults to jnp.float32.

  • fcm_max_ratio – Maximum ratio used when sampling forgetful causal masks.

  • fcm_min_ratio – Minimum ratio used when sampling forgetful causal masks.

  • hardware_abstraction – Enable EasyDeL hardware abstraction and custom kernels.

  • pallas_m_block_size – Matmul M dimension block size for Pallas kernels.

  • pallas_k_block_size – Matmul K dimension block size for Pallas kernels.

  • pallas_n_block_size – Matmul N dimension block size for Pallas kernels.

  • moe_method – Mixture-of-experts implementation to use.

  • moe_force_xla_gmm – Force XLA GMM kernels for MoE even when fused kernels exist.

  • use_ring_of_experts – Whether to dispatch experts with a ring topology.

  • use_expert_tensor_mode – Treat experts as an additional tensor-parallel axis.

  • fsdp_is_ep_bound – Fold the FSDP axis into the expert axis when building expert meshes.

  • sp_is_ep_bound – Fold the sequence-parallel axis into the expert axis when building expert meshes.

  • **kwargs – Forwarded to PretrainedConfig.

Raises

UserWarning – If KV-cache quantization is requested together with sharded KV caching.
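Several entries above describe fallbacks (backend, platform, decode_attn_mechanism, kvdtype) and one warning condition. The sketch below collects them in plain Python, assuming they are applied once when the configuration is built. resolve_config_defaults is a hypothetical helper, dtypes are plain strings instead of jnp dtypes, the default_backend argument stands in for jax.default_backend() so the example runs without JAX, and the warning text is invented for illustration.

```python
import warnings
from typing import Any


def resolve_config_defaults(cfg: dict[str, Any], default_backend: str = "cpu") -> dict[str, Any]:
    """Hypothetical helper applying the documented fallbacks to a config dict."""
    out = dict(cfg)
    # backend: explicit value, else the default JAX backend (simulated here).
    if out.get("backend") is None:
        out["backend"] = default_backend
    # platform: "triton" on GPU, otherwise "jax".
    if out.get("platform") is None:
        out["platform"] = "triton" if out["backend"] == "gpu" else "jax"
    # decode_attn_mechanism falls back to attn_mechanism (default "vanilla").
    if out.get("decode_attn_mechanism") is None:
        out["decode_attn_mechanism"] = out.get("attn_mechanism", "vanilla")
    # kvdtype falls back to attn_dtype (default bfloat16).
    if out.get("attn_dtype") is None:
        out["attn_dtype"] = "bfloat16"
    if out.get("kvdtype") is None:
        out["kvdtype"] = out["attn_dtype"]
    # KV-cache quantization combined with sharded KV caching triggers a UserWarning.
    if out.get("kv_cache_quantization_config") is not None and out.get("use_sharded_kv_caching"):
        warnings.warn("KV-cache quantization requested together with sharded KV caching", UserWarning)
    return out


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    resolved = resolve_config_defaults(
        {
            "attn_mechanism": "flash",                     # illustrative mechanism name
            "use_sharded_kv_caching": True,
            "kv_cache_quantization_config": {"bits": 8},   # stand-in for a real quantization config
        },
        default_backend="gpu",
    )
print(resolved["platform"], resolved["decode_attn_mechanism"], resolved["kvdtype"])  # triton flash bfloat16
print(len(caught), caught[0].category.__name__)  # 1 UserWarning
```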

class easydel.infra.base_config.EasyMethod(TRAIN: str = 'train', SERVE: str = 'serve', EVAL: str = 'serve', CONVERT: str = 'convert')

Bases: object

Constants for EasyDeL operation modes.

Defines the different modes in which EasyDeL models can operate.

TRAIN

Training mode for model optimization.

Type

str

SERVE

Serving mode for inference.

Type

str

EVAL

Evaluation mode (alias for serve).

Type

str

CONVERT

Conversion mode for model format changes.

Type

str

CONVERT: str = 'convert'
EVAL: str = 'serve'
SERVE: str = 'serve'
TRAIN: str = 'train'
classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
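EasyMethod behaves like a small record of string constants in which EVAL is an alias for SERVE. The sketch below approximates that behaviour with a frozen standard-library dataclass and mimics the to_dict/from_dict, to_json/from_json and replace helpers with asdict, json and dataclasses.replace. It is only an approximation: the real class is a PyTree whose implementation is not shown here, and EasyMethodSketch is a hypothetical name.

```python
import json
from dataclasses import asdict, dataclass, replace


@dataclass(frozen=True)
class EasyMethodSketch:
    """Hypothetical stand-in mirroring the documented EasyMethod constants."""

    TRAIN: str = "train"
    SERVE: str = "serve"
    EVAL: str = "serve"  # alias: evaluation uses the serving mode
    CONVERT: str = "convert"


modes = EasyMethodSketch()
print(modes.EVAL == modes.SERVE)  # True

# Round trip through a dict and a JSON string, analogous to to_json/from_json.
as_json = json.dumps(asdict(modes))
restored = EasyMethodSketch(**json.loads(as_json))
print(restored == modes)  # True

# replace() returns a new instance; the original is unchanged.
custom = replace(modes, CONVERT="export")
print(custom.CONVERT, modes.CONVERT)  # export convert
```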

easydel.infra.base_config.extract_commit_hash(resolved_file: str | None, commit_hash: str | None) → str | None

Extracts the git commit hash from a HuggingFace cache file path.

When models are downloaded from HuggingFace Hub, they’re cached locally with paths containing the commit hash in the format: …/snapshots/<commit_hash>/…. This function extracts that hash for tracking model versions.

Parameters
  • resolved_file – Path to the resolved cache file. If None or if commit_hash is already provided, returns the existing commit_hash immediately.

  • commit_hash – Existing commit hash if already known. If provided, this function returns it without parsing the file path.

Returns

The extracted commit hash string (40-character hex), or None if:

  • No file path is provided

  • The path doesn’t contain a snapshots directory

  • The extracted string doesn’t match git commit hash format

Return type

str | None

Example

>>> path = "/cache/snapshots/abc123def456.../model.safetensors"
>>> commit_hash = extract_commit_hash(path, None)
>>> # commit_hash = "abc123def456..." if valid
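The documented behaviour can be approximated with a regular expression. This is a sketch, not EasyDeL's implementation. It assumes that a valid hash is exactly 40 lowercase hexadecimal characters and that it is the path component directly after snapshots/. extract_commit_hash_sketch is a hypothetical name, and the cache path in the usage lines is illustrative.

```python
from __future__ import annotations

import re

# Assumption: a commit hash is exactly 40 lowercase hexadecimal characters.
COMMIT_HASH_RE = re.compile(r"^[0-9a-f]{40}$")


def extract_commit_hash_sketch(resolved_file: str | None, commit_hash: str | None) -> str | None:
    # A known hash, or no path at all, short-circuits the parsing.
    if resolved_file is None or commit_hash is not None:
        return commit_hash
    # Normalise backslashes so Windows-style paths are matched as well.
    posix_path = resolved_file.replace("\\", "/")
    match = re.search(r"snapshots/([^/]+)/", posix_path)
    if match is None:
        return None
    candidate = match.group(1)
    return candidate if COMMIT_HASH_RE.match(candidate) else None


sha = "0123456789abcdef0123456789abcdef01234567"
print(extract_commit_hash_sketch(f"/cache/models--org--name/snapshots/{sha}/model.safetensors", None) == sha)  # True
print(extract_commit_hash_sketch("/cache/snapshots/abc123/model.safetensors", None))  # None
print(extract_commit_hash_sketch(None, "known-hash"))  # known-hash
```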
easydel.infra.base_config.set_attrs_smartly(self, attr_name: str, default: Any, new_attr: Any)

Sets attributes intelligently with default values and optional updates.

This helper function provides smart attribute management:

  1. If the attribute doesn’t exist, it is set to the default value.

  2. If new_attr is provided (i.e. it is not the NOT_GIVEN sentinel), the attribute is updated to new_attr.

This pattern allows configuration classes to have default values while supporting explicit overrides through constructor parameters.

Parameters
  • self – The object to set the attribute on.

  • attr_name – Name of the attribute to set/update.

  • default – Default value to use if the attribute doesn’t exist yet.

  • new_attr – New value to set if provided. If equal to NOT_GIVEN sentinel, the existing value (or default) is preserved.

Example

>>> config = SomeConfig()
>>> set_attrs_smartly(config, "hidden_size", 768, 1024)
>>> # config.hidden_size = 1024 (updated)
>>>
>>> set_attrs_smartly(config, "num_layers", 12, NOT_GIVEN)
>>> # config.num_layers = 12 (default, since NOT_GIVEN)
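The two rules described for set_attrs_smartly translate into a few lines of Python. This is a sketch under stated assumptions: NOT_GIVEN is modelled here as a private sentinel object, and the function carries a _sketch suffix because EasyDeL's own sentinel and implementation are not shown on this page. SomeConfig is a placeholder class, as in the example above.

```python
from typing import Any


class _NotGiven:
    """Hypothetical sentinel type meaning 'no new value was supplied'."""

    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN = _NotGiven()


def set_attrs_smartly_sketch(obj: Any, attr_name: str, default: Any, new_attr: Any) -> None:
    # Rule 1: create the attribute with its default if it does not exist yet.
    if not hasattr(obj, attr_name):
        setattr(obj, attr_name, default)
    # Rule 2: an explicitly supplied value overrides the current one.
    if new_attr is not NOT_GIVEN:
        setattr(obj, attr_name, new_attr)


class SomeConfig:
    pass


config = SomeConfig()
set_attrs_smartly_sketch(config, "hidden_size", 768, 1024)
set_attrs_smartly_sketch(config, "num_layers", 12, NOT_GIVEN)
set_attrs_smartly_sketch(config, "num_layers", 24, NOT_GIVEN)  # existing value is kept
print(config.hidden_size, config.num_layers)  # 1024 12
```

The sketch compares against the sentinel with "is" rather than "==", which is the usual idiom for sentinel objects. Whether EasyDeL compares by identity or by equality is not stated here.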