easydel.infra.base_config
Base configuration classes for EasyDeL models.
This module provides the foundational configuration system for all EasyDeL models, extending HuggingFace’s PretrainedConfig with EasyDeL-specific features like attention mechanisms, quantization, gradient checkpointing, and hardware optimization.
- Classes:
EasyDeLBaseConfig: Main configuration class with all EasyDeL features
EasyDeLBaseConfigDict: Simplified dictionary-based configuration
- Key Features:
Multiple attention mechanism support (flash, ring, etc.)
Quantization configuration
Gradient checkpointing policies
Hardware abstraction and optimization
RoPE (Rotary Position Embedding) configuration
Custom kernel support
Example
>>> from easydel.infra.base_config import EasyDeLBaseConfig
>>> from easydel.infra.etils import EasyDeLGradientCheckPointers
>>> config = EasyDeLBaseConfig(
...     hidden_size=768,
...     num_attention_heads=12,
...     attn_mechanism="flash_attn2",
...     gradient_checkpointing=EasyDeLGradientCheckPointers.NONE,
...     hardware_abstraction=True,
... )
- class easydel.infra.base_config.EasyDeLBaseConfig(sharding_axis_dims: tp.Sequence[int] = (1, -1, 1, 1, 1), sharding_dcn_axis_dims: tp.Sequence[int] | None = None, sharding_axis_names: tp.Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp'), attn_mechanism: AVAILABLE_ATTENTION_MECHANISMS = 'vanilla', decode_attn_mechanism: AVAILABLE_ATTENTION_MECHANISMS = None, blocksize_k: int = 128, blocksize_q: int = 128, blocksize_b: int = 1, moe_tiling_size_batch: int = 4, moe_tiling_size_seqlen: int = 128, moe_tiling_size_dim: int = 128, partition_axis: PartitionAxis = PartitionAxis(data_parallel_axis='dp', fully_sharded_data_parallel_axis='fsdp', tensor_parallel_axis='tp', sequence_parallel_axis='sp', expert_parallel_axis='ep', batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', kv_head_axis='tp', key_sequence_axis='sp', hidden_state_axis='tp', mlp_intermediate_axis='tp', vocab_axis='tp', expert_axis='ep', expert_gate_axis=None, attention_dim_axis=None, attention_kv_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, decode_batch_axis=('fsdp', 'dp'), decode_query_sequence_axis=None, decode_head_axis='tp', decode_kv_head_axis='tp', decode_key_sequence_axis='sp', decode_attention_dim_axis=None, decode_attention_kv_dim_axis=None), use_sharded_kv_caching: bool = False, use_sharding_constraint: bool = False, backend: EasyDeLBackends | None = None, platform: EasyDeLPlatforms | None = None, easy_method: tp.Literal['train', 'serve', 'convert'] = 'train', bits: int | None = None, scan_ring_attention: bool = True, scan_attention_layers: bool = False, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, sequence_axis_name: str = 'sp', gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, gradient_checkpointing_targets: list[AVAILABLE_GRADIENT_CHECKPOINT_TARGETS] | None = None, precompute_masks: bool = True, kv_cache_quantization_config: EasyDeLQuantizationConfig | None = None, 
quantization_config: EasyDeLQuantizationConfig | None = None, kv_cache_sharding_sequence_axis_name: str | tuple[str, ...] = 'sp', flash_attention_backward_pass_impl: tp.Literal['triton', 'xla'] = 'triton', attn_dtype: jnp.dtype = jnp.bfloat16, kvdtype: jnp.dtype | None = None, attn_softmax_dtype: jnp.dtype = jnp.float32, fcm_max_ratio: float = 0.0, fcm_min_ratio: float = 0.0, hardware_abstraction: bool = False, pallas_m_block_size: int = 128, pallas_k_block_size: int = 128, pallas_n_block_size: int = 128, moe_method: AVAILABLE_MOE_METHODS = 'fused_moe', moe_force_xla_gmm: bool = False, use_expert_tensor_mode: bool = False, use_ring_of_experts: bool = False, fsdp_is_ep_bound: bool = True, sp_is_ep_bound: bool = True, operation_configs: dict[str, BaseOperationConfig] | None = None, **kwargs)
Bases: PretrainedConfig
Base configuration shared across EasyDeL models.
Extends transformers.PretrainedConfig with distributed sharding metadata, attention kernel selection, quantization knobs, RoPE helpers, and hardware abstraction flags used for both training and serving.
- Parameters
sharding_axis_dims – Parallelism sizes for (dp, fsdp, ep, tp, sp); -1 consumes all remaining devices. Defaults to (1, -1, 1, 1, 1).
sharding_dcn_axis_dims – Optional mesh sizes for DCN slices when running multi-host or multi-slice setups.
sharding_axis_names – Logical mesh axis names. Defaults to ("dp", "fsdp", "ep", "tp", "sp").
attn_mechanism – Attention implementation to use during training/forward passes.
decode_attn_mechanism – Attention implementation to use during decoding (falls back to attn_mechanism if left as None).
blocksize_k – Key block size for attention kernels. Defaults to 128.
blocksize_q – Query block size for attention kernels. Defaults to 128.
blocksize_b – Batch/block size used by some attention backends. Defaults to 1.
moe_tiling_size_batch – Batch tiling used by MoE kernels. Defaults to 4.
moe_tiling_size_seqlen – Sequence-length tiling for MoE kernels. Defaults to 128.
moe_tiling_size_dim – Hidden-dimension tiling for MoE kernels. Defaults to 128.
partition_axis – PartitionAxis describing how logical axes map to the mesh.
use_sharded_kv_caching – Whether to shard KV cache placement instead of replicating it.
use_sharding_constraint – Insert explicit sharding constraints during model build.
backend – Explicit JAX backend (falls back to jax.default_backend()).
platform – Platform hint for kernel selection (defaults to "triton" on GPU, otherwise "jax").
easy_method – Workflow context ("train", "serve", or "convert").
bits – Optional quantization bit width for weights.
scan_ring_attention – Use scanning for ring attention implementations.
scan_attention_layers – Apply scan to attention blocks to save memory.
use_scan_mlp – Apply scan to MLP blocks.
scan_mlp_chunk_size – Chunk size when scanning MLPs. Defaults to 1024.
sequence_axis_name – Name of the sequence/attention axis. Defaults to "sp".
gradient_checkpointing – Gradient checkpointing policy (enum or string).
gradient_checkpointing_targets – Optional list of target names to include or exclude when using selective checkpointing policies.
precompute_masks – Whether to precompute and cache causal masks on the mesh.
kv_cache_quantization_config – Quantization config for KV cache tensors. Pass None to disable.
quantization_config – Quantization config for linear layers. Pass None to disable.
kv_cache_sharding_sequence_axis_name – Axis (or axes) used when sharding the KV cache.
flash_attention_backward_pass_impl – Backward kernel for flash attention ("triton" or "xla"). Defaults to "triton".
attn_dtype – Attention activation dtype. Defaults to jnp.bfloat16.
kvdtype – KV cache dtype. Defaults to attn_dtype when None.
attn_softmax_dtype – Softmax computation dtype. Defaults to jnp.float32.
fcm_max_ratio – Maximum ratio used when sampling forgetful causal masks.
fcm_min_ratio – Minimum ratio used when sampling forgetful causal masks.
hardware_abstraction – Enable EasyDeL hardware abstraction and custom kernels.
pallas_m_block_size – Matmul M-dimension block size for Pallas kernels.
pallas_k_block_size – Matmul K-dimension block size for Pallas kernels.
pallas_n_block_size – Matmul N-dimension block size for Pallas kernels.
moe_method – Mixture-of-experts implementation to use.
moe_force_xla_gmm – Force XLA GMM kernels for MoE even when fused kernels exist.
use_ring_of_experts – Whether to dispatch experts with a ring topology.
use_expert_tensor_mode – Treat experts as an additional tensor-parallel axis.
fsdp_is_ep_bound – Fold the FSDP axis into the expert axis when building expert meshes.
sp_is_ep_bound – Fold the sequence-parallel axis into the expert axis when building expert meshes.
**kwargs – Forwarded to PretrainedConfig.
- Raises
UserWarning – If KV-cache quantization is requested together with sharded KV caching.
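The fallback rules and the KV-cache warning described above can be illustrated with a small stand-in class. `ConfigSketch` is hypothetical (it is not part of EasyDeL), dtypes are plain strings to keep it dependency-free, and it resolves the fallbacks eagerly in `__post_init__`; the real class may resolve them elsewhere, and its warning text will differ.

```python
import warnings
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ConfigSketch:
    """Hypothetical stand-in mirroring a few EasyDeLBaseConfig fields."""

    attn_mechanism: str = "vanilla"
    decode_attn_mechanism: Optional[str] = None
    attn_dtype: str = "bfloat16"  # dtypes as strings to avoid a jax dependency
    kvdtype: Optional[str] = None
    use_sharded_kv_caching: bool = False
    kv_cache_quantization_config: Optional[Any] = None

    def __post_init__(self) -> None:
        # decode_attn_mechanism falls back to attn_mechanism when left as None.
        if self.decode_attn_mechanism is None:
            self.decode_attn_mechanism = self.attn_mechanism
        # kvdtype falls back to attn_dtype when left as None.
        if self.kvdtype is None:
            self.kvdtype = self.attn_dtype
        # KV-cache quantization together with sharded KV caching only warns.
        if self.kv_cache_quantization_config is not None and self.use_sharded_kv_caching:
            warnings.warn(
                "KV-cache quantization combined with sharded KV caching (sketch message).",
                UserWarning,
                stacklevel=2,
            )
```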
- add_basic_configurations(sharding_axis_dims: typing.Sequence[int] = NOT_GIVEN, sharding_dcn_axis_dims: typing.Optional[typing.Sequence[int]] = NOT_GIVEN, sharding_axis_names: typing.Sequence[str] = NOT_GIVEN, attn_mechanism: typing.Literal['auto', 'vanilla', 'flash_attn2', 'blocksparse', 'ring', 'cudnn', 'blockwise', 'sdpa', 'autoregressive_decodeattn', 'ragged_page_attention_v2', 'ragged_page_attention_v3', 'page_attention'] = NOT_GIVEN, decode_attn_mechanism: typing.Literal['auto', 'vanilla', 'flash_attn2', 'blocksparse', 'ring', 'cudnn', 'blockwise', 'sdpa', 'autoregressive_decodeattn', 'ragged_page_attention_v2', 'ragged_page_attention_v3', 'page_attention'] = NOT_GIVEN, blocksize_k: int = NOT_GIVEN, blocksize_q: int = NOT_GIVEN, blocksize_b: int = NOT_GIVEN, moe_tiling_size_batch: int = NOT_GIVEN, moe_tiling_size_seqlen: int = NOT_GIVEN, moe_tiling_size_dim: int = NOT_GIVEN, partition_axis: eformer.escale.partition.manager.PartitionAxis = NOT_GIVEN, use_sharded_kv_caching: bool = NOT_GIVEN, backend: easydel.infra.etils.EasyDeLBackends | None = NOT_GIVEN, platform: easydel.infra.etils.EasyDeLPlatforms | None = NOT_GIVEN, easy_method: typing.Literal['train', 'serve', 'convert'] = NOT_GIVEN, bits: int | None = NOT_GIVEN, scan_ring_attention: bool = NOT_GIVEN, scan_attention_layers: bool = NOT_GIVEN, use_sharding_constraint: bool = NOT_GIVEN, use_scan_mlp: bool = NOT_GIVEN, scan_mlp_chunk_size: int = NOT_GIVEN,
sequence_axis_name: str = NOT_GIVEN, gradient_checkpointing: easydel.infra.etils.EasyDeLGradientCheckPointers = NOT_GIVEN, gradient_checkpointing_targets: list[typing.Literal['attn_dense', 'attn_key', 'attn_key_value', 'attn_output', 'attn_qkv', 'attn_query', 'attn_receptance', 'attn_value', 'attn_weights', 'embeddings', 'layer_output', 'lm_head_output', 'mlp_down', 'mlp_gate', 'mlp_output', 'mlp_up', 'model_output', 'moe_expert_output', 'moe_gate_logits', 'moe_output', 'moe_router_logits', 'normed_input', 'residual']] | None = NOT_GIVEN, precompute_masks: bool = NOT_GIVEN, kv_cache_quantization_config: easydel.layers.quantization.quantizers.EasyDeLQuantizationConfig | None = NOT_GIVEN, quantization_config: easydel.layers.quantization.quantizers.EasyDeLQuantizationConfig | None = NOT_GIVEN, kv_cache_sharding_sequence_axis_name: str | tuple[str, ...] = NOT_GIVEN,
flash_attention_backward_pass_impl: typing.Literal['triton', 'xla'] = NOT_GIVEN, attn_dtype: numpy.dtype = NOT_GIVEN, kvdtype: numpy.dtype | None = NOT_GIVEN, attn_softmax_dtype: numpy.dtype = NOT_GIVEN, hardware_abstraction: bool = NOT_GIVEN, pallas_m_block_size: int = NOT_GIVEN, pallas_k_block_size: int = NOT_GIVEN, pallas_n_block_size: int = NOT_GIVEN, moe_method: typing.Literal['fused_moe', 'standard_moe', 'dense_moe'] = NOT_GIVEN, moe_force_xla_gmm: bool = NOT_GIVEN, use_ring_of_experts: bool = NOT_GIVEN, use_expert_tensor_mode: bool = NOT_GIVEN, fsdp_is_ep_bound: bool = NOT_GIVEN, sp_is_ep_bound: bool = NOT_GIVEN, **kwargs)
Populate baseline EasyDeL attributes on an existing config instance.
Each argument mirrors the constructor but is optional: passing NOT_GIVEN leaves any existing attribute untouched, while a provided value overwrites the current setting. If an attribute is missing entirely, a sensible default is applied via set_attrs_smartly. This helper is used by derived configs (and their sub_configs) to keep sharding/attention/quantization knobs in sync without re-implementing initialization logic.
- Parameters
sharding_axis_dims – Fallback mesh sizes for (dp, fsdp, ep, tp, sp) (default (1, -1, 1, 1, 1)).
sharding_dcn_axis_dims – Optional DCN mesh sizes (default None).
sharding_axis_names – Mesh axis labels (default ("dp", "fsdp", "ep", "tp", "sp")).
attn_mechanism – Attention mechanism to use (default "vanilla").
decode_attn_mechanism – Optional decode-time attention mechanism.
blocksize_k – Attention key block size (default 512 when unset).
blocksize_q – Attention query block size (default 512 when unset).
blocksize_b – Batch/block size used by attention kernels (default 1).
moe_tiling_size_batch – Batch tiling for MoE kernels (default 4).
moe_tiling_size_seqlen – Sequence tiling for MoE kernels (default 128).
moe_tiling_size_dim – Hidden-dim tiling for MoE kernels (default 128).
partition_axis – PartitionAxis describing the logical mesh layout (default PartitionAxis()).
use_sharded_kv_caching – Whether to shard KV caches (default False).
backend – Backend string (default None, which falls back to the JAX default).
platform – Platform hint (default "jax").
easy_method – EasyDeL execution mode (default EasyMethod.TRAIN).
bits – Optional quantization bit width (default None).
scan_ring_attention – Enable scan for ring attention (default True).
scan_attention_layers – Enable scan for attention blocks (default True).
use_sharding_constraint – Insert sharding constraints (default False).
use_scan_mlp – Enable scan for MLPs (default False).
scan_mlp_chunk_size – Chunk size for scanned MLPs (default 1024).
sequence_axis_name – Label for the sequence/attention axis (default "sp").
gradient_checkpointing – Gradient checkpointing policy (default EasyDeLGradientCheckPointers.NONE).
gradient_checkpointing_targets – Optional list of checkpoint targets to include/exclude (default None).
precompute_masks – Whether to precompute and cache masks (default True).
kv_cache_quantization_config – KV cache quantization config (default None, meaning no quantization).
quantization_config – Linear-layer quantization config (default None, meaning no quantization).
kv_cache_sharding_sequence_axis_name – Axis name(s) for KV cache sharding (default "sp").
flash_attention_backward_pass_impl – Backward kernel for flash attention (default "triton").
attn_dtype – Attention activation dtype (default jnp.float32).
kvdtype – KV cache dtype (defaults to attn_dtype when unset).
attn_softmax_dtype – Softmax computation dtype (default jnp.float32).
hardware_abstraction – Toggle EasyDeL hardware abstraction (default DEFAULT_HARDWARE_ABSTRACTION).
pallas_m_block_size – Pallas matmul M block size (default DEFAULT_PALLAS_M_BLOCK_SIZE).
pallas_k_block_size – Pallas matmul K block size (default DEFAULT_PALLAS_K_BLOCK_SIZE).
pallas_n_block_size – Pallas matmul N block size (default DEFAULT_PALLAS_N_BLOCK_SIZE).
moe_method – MoE implementation to use (default DEFAULT_MOE_METHOD).
moe_force_xla_gmm – Force XLA GMM kernels for MoE (default False).
use_ring_of_experts – Dispatch experts with a ring topology (default RING_EXPERTS).
use_expert_tensor_mode – Treat experts as a tensor-parallel axis (default EXPERT_TP_MODE).
fsdp_is_ep_bound – Fold FSDP into the expert axis when building expert meshes.
sp_is_ep_bound – Fold sequence-parallel into the expert axis when building expert meshes.
**kwargs – Extra attributes to attach to this config and any defined sub_configs.
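The NOT_GIVEN semantics (explicit values overwrite, NOT_GIVEN keeps existing attributes, missing attributes get a default) can be sketched with a sentinel object. `_NotGivenSketch`, `NOT_GIVEN_SKETCH`, `FALLBACKS`, and `add_basic_configurations_sketch` are hypothetical names standing in for EasyDeL's NOT_GIVEN and set_attrs_smartly, and only three of the documented fallbacks are modeled.

```python
from typing import Any


class _NotGivenSketch:
    """Hypothetical sentinel playing the role of NOT_GIVEN."""

    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN_SKETCH = _NotGivenSketch()

# A few documented "when unset" defaults, applied only if the attribute is missing.
FALLBACKS = {"attn_mechanism": "vanilla", "blocksize_q": 512, "blocksize_k": 512}


def add_basic_configurations_sketch(config: Any, **overrides: Any) -> Any:
    # Overrides for names outside FALLBACKS are ignored in this sketch.
    for name, default in FALLBACKS.items():
        value = overrides.get(name, NOT_GIVEN_SKETCH)
        if value is not NOT_GIVEN_SKETCH:
            setattr(config, name, value)  # an explicit value always wins
        elif not hasattr(config, name):
            setattr(config, name, default)  # a missing attribute gets the fallback
        # otherwise NOT_GIVEN leaves the existing attribute untouched
    return config
```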
- attach_custom_arguments(**kwargs)
Attaches custom arguments as attributes to the configuration.
- Parameters
**kwargs – Arbitrary key-value pairs to attach as attributes.
- property auto_expert_mesh: Mesh
Get the mesh for expert parallelism with automatic axis types.
Similar to expert_mesh, but uses jax.sharding.AxisType.Auto for all axes, allowing JAX to automatically determine the optimal sharding strategy based on the computation graph.
- Returns
A mesh with auto axis types configured for expert parallelism with (dp, ep, tp) axis ordering.
- Return type
jax.sharding.Mesh
- static create_mesh(sharding_axis_dims: Sequence[int] = (1, -1, 1, 1, 1), sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp'), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, process_is_granule: bool = False, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = True, backend: str | None = None)
Creates a JAX device mesh for distributed model execution.
This function constructs a multi-dimensional mesh of devices that defines how model parameters and computations are distributed across hardware. The mesh axes correspond to different parallelism strategies:
dp (data parallel): Replicate model across data batches
fsdp (fully sharded data parallel): Shard parameters and optimizer states
ep (expert parallel): Distribute experts in MoE models
tp (tensor parallel): Partition individual weight matrices
sp (sequence parallel): Split sequence dimension across devices
- Parameters
sharding_axis_dims – Size of each parallelism dimension. Use -1 to automatically fill the remaining devices. Default: (1, -1, 1, 1, 1), which places all devices on the FSDP axis.
sharding_axis_names – Names for each mesh axis. Must match the length of sharding_axis_dims. Default: (“dp”, “fsdp”, “ep”, “tp”, “sp”).
sharding_dcn_axis_dims – Dimensions for data center network (DCN) sharding. Used for multi-host/multi-slice setups. Default: None.
process_is_granule – Deprecated parameter, not used.
should_sort_granules_by_key – Whether to sort device granules by key for deterministic ordering. Default: True.
allow_split_physical_axes – Whether to allow splitting physical device axes when mapping to logical mesh axes. Default: True.
backend – Backend platform to create mesh for (‘gpu’, ‘tpu’, etc.). If None or empty string, uses default backend.
- Returns
A JAX Mesh object configured for distributed execution with the specified parallelism dimensions.
Example
>>> # Create a mesh with DP=4, TP=2 on 8 GPUs
>>> mesh = EasyDeLBaseConfig.create_mesh(
...     sharding_axis_dims=(4, 1, 1, 2, 1),
...     sharding_axis_names=("dp", "fsdp", "ep", "tp", "sp"),
... )
>>> # mesh.shape = {'dp': 4, 'fsdp': 1, 'ep': 1, 'tp': 2, 'sp': 1}
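The -1 rule for sharding_axis_dims can be checked without any accelerators. `resolve_axis_dims` is a hypothetical helper, not an EasyDeL or JAX function: it only performs the size arithmetic and ignores DCN axes and physical device placement.

```python
import math
from typing import Sequence


def resolve_axis_dims(dims: Sequence[int], num_devices: int) -> tuple[int, ...]:
    """Replace a single -1 entry with however many devices remain (sketch)."""
    if any(d == 0 or d < -1 for d in dims):
        raise ValueError("axis sizes must be positive or -1")
    if list(dims).count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    fixed = math.prod(d for d in dims if d != -1)  # devices claimed by explicit axes
    if num_devices % fixed != 0:
        raise ValueError(f"{num_devices} devices cannot be split by fixed axes of size {fixed}")
    resolved = tuple(num_devices // fixed if d == -1 else d for d in dims)
    if math.prod(resolved) != num_devices:
        raise ValueError(f"mesh {resolved} does not use exactly {num_devices} devices")
    return resolved
```

For example, the default (1, -1, 1, 1, 1) on 8 devices resolves to (1, 8, 1, 1, 1), putting every device on the FSDP axis.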
- property expert_abstract_mesh: AbstractMesh
Get the abstract mesh descriptor for expert parallelism.
Returns an abstract mesh that matches the expert_mesh axis sizes and names. Abstract meshes are lightweight representations used for sharding specification without device assignment.
- Returns
An abstract mesh descriptor with the same axis configuration as expert_mesh.
- Return type
jax.sharding.AbstractMesh
- property expert_mesh: Mesh
Get the mesh configuration for expert parallelism.
Creates a mesh with expert-parallel axes folded according to the fsdp_is_ep_bound and sp_is_ep_bound configuration flags. This mesh is used for MoE (Mixture of Experts) models to distribute experts across devices with explicit axis types.
- Returns
A mesh with explicit axis types configured for expert parallelism with (dp, ep, tp) axis ordering.
- Return type
jax.sharding.Mesh
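Reading fsdp_is_ep_bound and sp_is_ep_bound literally, the expert axis grows by the size of each axis folded into it. `expert_axis_size` is a hypothetical helper that only does this arithmetic on already-resolved (dp, fsdp, ep, tp, sp) sizes; it does not model where unbound fsdp or sp devices end up in the (dp, ep, tp) mesh, which this reference does not specify.

```python
from typing import Sequence


def expert_axis_size(
    sharding_axis_dims: Sequence[int],
    fsdp_is_ep_bound: bool = True,
    sp_is_ep_bound: bool = True,
) -> int:
    """Size of the folded expert axis, given resolved (dp, fsdp, ep, tp, sp) sizes."""
    _dp, fsdp, ep, _tp, sp = sharding_axis_dims  # expects no -1 entries
    size = ep
    if fsdp_is_ep_bound:
        size *= fsdp  # fold the FSDP axis into the expert axis
    if sp_is_ep_bound:
        size *= sp  # fold the sequence-parallel axis into the expert axis
    return size
```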
- classmethod from_pretrained(pretrained_model_name_or_path: str | os.PathLike, cache_dir: str | os.PathLike | None = None, force_download: bool = False, local_files_only: bool = False, token: str | bool | None = None, revision: str = 'main', **kwargs) → PretrainedConfig
Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.
- Parameters
pretrained_model_name_or_path (str or os.PathLike) –
This can be either:
a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.
a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.
a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.
cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
force_download (bool, optional, defaults to False) – Whether to force (re-)downloading the configuration files, overriding the cached versions if they exist.
resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {‘http’: ‘foo.bar:3128’, ‘http://hostname’: ‘foo.bar:4012’}. The proxies are used on each request.
token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
revision (str, optional, defaults to “main”) –
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
Tip: To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".
return_unused_kwargs (bool, optional, defaults to False) –
If False, then this function returns just the final configuration object.
If True, this function returns a tuple (config, unused_kwargs), where unused_kwargs is a dictionary of the key/value pairs whose keys are not configuration attributes, i.e., the part of kwargs that was not used to update config and is otherwise ignored.
subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.
kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.
- Returns
The configuration object instantiated from this pretrained model.
- Return type
[PretrainedConfig]
Examples:
>>> from transformers import BertConfig
>>> # PretrainedConfig is not instantiated directly, so these examples use a derived class: BertConfig.
>>> # Download the configuration from huggingface.co and cache it.
>>> config = BertConfig.from_pretrained("google-bert/bert-base-uncased")
>>> # Load a config (or model) saved with save_pretrained("./test/saved_model/").
>>> config = BertConfig.from_pretrained("./test/saved_model/")
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased",
...     output_attentions=True,
...     foo=False,
...     return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}
- get_axis_dims() → Sequence[int]
Returns the device mesh axis dimensions for parallelism.
- Returns
Sequence of integers specifying the size of each parallelism axis. Typically (dp_size, fsdp_size, ep_size, tp_size, sp_size). Value of -1 means “use all remaining devices for this axis”.
Example
>>> config.sharding_axis_dims = (2, 4, 1, 1, 1)
>>> dims = config.get_axis_dims()
>>> # dims = (2, 4, 1, 1, 1): 2-way data parallel, 4-way FSDP, size-1 ep/tp/sp axes
- get_axis_names() → Sequence[str]
Returns the logical names for each device mesh axis.
- Returns
Sequence of strings naming each parallelism axis. Typically (“dp”, “fsdp”, “ep”, “tp”, “sp”) for data parallel, fully sharded data parallel, expert parallel, tensor parallel, and sequence parallel respectively.
Example
>>> names = config.get_axis_names()
>>> # names = ('dp', 'fsdp', 'ep', 'tp', 'sp')
- get_backend() → str
Returns the JAX backend platform being used.
Retrieves the configured backend (e.g., ‘gpu’, ‘tpu’, ‘cpu’), or falls back to the default JAX backend if not explicitly set.
- Returns
Backend platform string. Common values: ‘gpu’, ‘tpu’, ‘cpu’.
- Return type
str
Example
>>> config = EasyDeLBaseConfig(backend='gpu')
>>> config.get_backend()
'gpu'
>>> # With no backend set, returns the JAX default
>>> config = EasyDeLBaseConfig(backend='')
>>> config.get_backend()  # Might return 'gpu', 'tpu', etc.
- get_basic_causal_mask(*args, **kwargs)
Gets or creates the basic causal attention mask.
Creates a causal mask for the maximum position embeddings and places it on the appropriate device with sharding.
- Returns
ModuleCaches containing the causal mask, or False if masks are not precomputed.
- get_basic_frequencies(head_size: int | None = None, rotary_dim: int | None = None, base: float | None = None) → ModuleCaches
Compute frequencies for rotary embeddings placed on the configured mesh.
- Parameters
head_size – Attention head size (defaults to self.head_dim).
rotary_dim – Number of rotary dimensions (defaults to head_size).
base – Optional base for frequency computation (defaults to self.rope_theta).
- Returns
ModuleCaches containing the frequencies sharded with NamedSharding.
- get_basic_inv_frequencies(head_size: int | None = None, rotary_dim: int | None = None, base: float | None = None, partial_rotary_factor: float = 1.0) → ModuleCaches
Compute inverse frequencies for rotary embeddings.
- Parameters
head_size – Attention head size (defaults to self.head_dim).
rotary_dim – Number of rotary dimensions (defaults to head_size).
base – Optional base for frequency computation (defaults to self.rope_theta).
partial_rotary_factor – Ratio of the head dimension to apply RoPE to.
- Returns
ModuleCaches wrapping the computed frequency tensor.
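The standard RoPE inverse-frequency formula is inv_freq[i] = 1 / base^(2i / rotary_dim), one value per pair of rotary dimensions. The sketch below is plain Python rather than EasyDeL's implementation; it assumes base defaults to 10000.0 (the method itself uses self.rope_theta) and that partial_rotary_factor scales the rotary dimension before the frequencies are built.

```python
from typing import Optional


def inv_frequencies_sketch(
    head_size: int,
    rotary_dim: Optional[int] = None,
    base: float = 10000.0,
    partial_rotary_factor: float = 1.0,
) -> list[float]:
    # Assumption: partial_rotary_factor shrinks the rotary dimension.
    dim = int((rotary_dim if rotary_dim is not None else head_size) * partial_rotary_factor)
    # One frequency per pair of dimensions: base ** (-2i / dim) for 2i in 0, 2, 4, ...
    return [1.0 / base ** (i / dim) for i in range(0, dim, 2)]
```

With head_size=8 and partial_rotary_factor=0.5, only 4 dimensions rotate, giving two frequencies: 1.0 and 0.01.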
- get_basic_rope(dtype: Union[Array, ndarray, bool, number], head_size: int, rotary_dim: int | None = None, is_neox_style: bool = True, base: float | None = None)
Return a rotary position embedding function configured for this model.
- Parameters
dtype – Target dtype for the generated embeddings.
head_size – Attention head size used to derive the rotary dimension.
rotary_dim – Number of rotary dimensions (defaults to head_size).
is_neox_style – Whether to generate NeoX-style rotary embeddings.
base – Optional base used for frequency computation (defaults to self.rope_theta).
- Returns
Callable from get_rope ready to be applied to query/key tensors.
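To show what such a callable does to a query or key vector, here is a plain-Python sketch of the rotation for one head vector at one position. It is not EasyDeL's get_rope: `apply_rope_sketch` is hypothetical, assumes base 10000.0, covers only unscaled RoPE, and contrasts NeoX-style pairing (dimension i with i + rotary_dim/2) with GPT-J-style pairing (adjacent dimensions).

```python
import math
from typing import Optional, Sequence


def apply_rope_sketch(
    x: Sequence[float],
    position: int,
    rotary_dim: Optional[int] = None,
    base: float = 10000.0,
    is_neox_style: bool = True,
) -> list[float]:
    rotary_dim = rotary_dim or len(x)
    half = rotary_dim // 2
    out = list(x)  # dimensions beyond rotary_dim pass through unchanged
    for i in range(half):
        theta = position / base ** (2 * i / rotary_dim)
        cos, sin = math.cos(theta), math.sin(theta)
        # NeoX style pairs dim i with dim i + half; GPT-J style pairs adjacent dims.
        a, b = (i, i + half) if is_neox_style else (2 * i, 2 * i + 1)
        out[a] = x[a] * cos - x[b] * sin
        out[b] = x[b] * cos + x[a] * sin
    return out
```

Because each pair is rotated by an angle proportional to its position, the query-key dot product depends only on the distance between the two positions, which is the property RoPE is designed to provide.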
- get_fcm_mask(batch_size, seq_length, deterministic: bool)
Generates a Forgetful Causal Mask (FCM) for training.
FCM randomly masks out a fraction of past tokens from attention during training to improve model robustness. It is only applied in non-deterministic mode.
- Parameters
batch_size – Number of sequences in the batch.
seq_length – Length of each sequence.
deterministic – If True, returns None (no FCM applied).
- Returns
Boolean mask array or None if deterministic or FCM not configured.
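A sketch of the forgetful-causal-masking idea, assuming it follows the common recipe of drawing one drop ratio per sequence from [fcm_min_ratio, fcm_max_ratio] and always keeping the first key position. `fcm_mask_sketch` and `combine_with_causal` are hypothetical, use nested lists instead of arrays, and treat fcm_max_ratio <= 0 as "FCM not configured".

```python
import random
from typing import Optional

Mask = list[list[list[bool]]]  # indexed as [batch][query][key]


def fcm_mask_sketch(
    batch_size: int,
    seq_length: int,
    deterministic: bool,
    fcm_min_ratio: float,
    fcm_max_ratio: float,
    rng: Optional[random.Random] = None,
) -> Optional[Mask]:
    if deterministic or fcm_max_ratio <= 0.0:
        return None  # no FCM: the plain causal mask is used unchanged
    rng = rng or random.Random()
    masks = []
    for _ in range(batch_size):
        # One drop ratio per sequence, drawn from [fcm_min_ratio, fcm_max_ratio].
        ratio = rng.uniform(fcm_min_ratio, fcm_max_ratio)
        rows = []
        for _q in range(seq_length):
            row = [rng.random() > ratio for _k in range(seq_length)]  # keep with prob ~(1 - ratio)
            row[0] = True  # always keep the first token so no query row is empty
            rows.append(row)
        masks.append(rows)
    return masks


def combine_with_causal(fcm: Mask) -> Mask:
    """Logical AND of the FCM mask with a lower-triangular causal mask."""
    return [[[keep and k <= q for k, keep in enumerate(row)] for q, row in enumerate(seq)] for seq in fcm]
```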
- get_mask_details() → dict[int, AttnMaskDetail] | None
Get attention mask details for each layer.
Retrieves layer-specific attention mask configurations, which is particularly useful for models with heterogeneous attention patterns (e.g., models using different attention types per layer like sliding window attention in some layers and full attention in others).
- Returns
A dictionary mapping layer indices to their corresponding AttnMaskDetail configurations, or None if the model doesn’t define layer-specific mask types.
- Return type
dict[int, AttnMaskDetail] | None
- get_partition_rules(*args, **kwargs)
Gets the parameter sharding partition rules for the model.
Partition rules define how model parameters should be sharded across the device mesh. Each rule maps a parameter name pattern (regex) to a PartitionSpec that specifies which mesh axes the parameter dimensions should be distributed across.
This method must be implemented by model-specific configuration classes.
- Parameters
*args – Positional arguments (model-specific).
**kwargs – Keyword arguments (model-specific).
- Returns
Tuple of (pattern, PartitionSpec) pairs defining how to shard parameters. For example:
(("model/embed.*", PartitionSpec("tp", None)),
("model/layers/\d+/attn/.*", PartitionSpec(None, "tp")))
- Raises
NotImplementedError – This base class does not provide default partition rules. Subclasses must implement this method.
Example
>>> from jax.sharding import PartitionSpec
>>> class MyModelConfig(EasyDeLBaseConfig):
...     def get_partition_rules(self, *args, **kwargs):
...         return (
...             ("embed.*", PartitionSpec("tp", None)),
...             ("attn.*", PartitionSpec(None, "tp", None)),
...             ("mlp.*", PartitionSpec(None, "tp")),
...         )
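Regex partition rules like these are commonly resolved by scanning them in order and using the first pattern that matches a parameter's path. The sketch below assumes that first-match, re.search-based behavior; `match_partition_rule` is hypothetical and uses plain tuples as stand-ins for PartitionSpec, with a trailing catch-all rule so every parameter receives a spec.

```python
import re
from typing import Optional, Sequence

# Plain tuples stand in for jax.sharding.PartitionSpec in this sketch.
Rule = tuple[str, tuple[Optional[str], ...]]


def match_partition_rule(name: str, rules: Sequence[Rule]) -> tuple[Optional[str], ...]:
    """Return the spec of the first rule whose regex matches `name` (assumed semantics)."""
    for pattern, spec in rules:
        if re.search(pattern, name):
            return spec
    raise ValueError(f"no partition rule matches {name!r}")
```

Rule order matters under first-match semantics: a broad pattern placed early would shadow the more specific ones after it.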
- property granted_freq_max_position_embedding: int
Return the max position embedding allowed for frequency-based caches.
- property granted_mask_max_position_embedding: int
Return the max position embedding allowed for mask precomputation.
- jax_mesh()
Deprecated method for getting the JAX mesh.
- Deprecated:
Use mesh property or get_mesh() method instead.
- Returns
JAX device mesh.
- property mesh
Gets or creates the JAX device mesh for this configuration.
This property lazily constructs a device mesh from the configuration’s sharding parameters. Once created, the mesh is cached for reuse. The mesh can be explicitly set using set_model_mesh() to override the auto-generated one.
The mesh is constructed from:
- sharding_axis_dims: Device counts per axis
- sharding_axis_names: Logical names for each axis
- sharding_dcn_axis_dims: Multi-host configuration (if applicable)
- Various granule sorting and axis splitting options
- Returns
JAX Mesh object defining the device topology for distributed execution. The mesh axes correspond to parallelism strategies (dp, fsdp, ep, tp, sp).
Note
If a custom mesh was set via set_model_mesh(), that mesh is returned instead of creating a new one.
Example
>>> config = EasyDeLBaseConfig(
...     sharding_axis_dims=(2, 1, 1, 4, 1),
...     sharding_axis_names=("dp", "fsdp", "ep", "tp", "sp"),
... )
>>> mesh = config.mesh
>>> # mesh.shape = {'dp': 2, 'fsdp': 1, 'ep': 1, 'tp': 4, 'sp': 1}
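The build-once, overridable behavior of mesh and set_model_mesh() can be modeled with a cached property. `LazyMeshSketch` is a hypothetical stand-in: the builder callable replaces the real mesh construction, and any object can play the role of a mesh.

```python
from typing import Any, Callable, Optional


class LazyMeshSketch:
    """Hypothetical stand-in for the mesh property's build-once cache with override."""

    def __init__(self, build_mesh: Callable[[], Any]) -> None:
        self._build_mesh = build_mesh
        self._mesh: Optional[Any] = None

    @property
    def mesh(self) -> Any:
        if self._mesh is None:
            self._mesh = self._build_mesh()  # built lazily on first access, then cached
        return self._mesh

    def set_model_mesh(self, mesh: Any) -> None:
        self._mesh = mesh  # an explicit mesh replaces the auto-generated one
```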
- property partition_manager: PartitionManager
Gets the partition manager for this configuration.
The PartitionManager handles translation between logical axis names (like ‘dp’, ‘tp’) and their actual configurations in the device mesh. It provides utilities for resolving partition specifications and managing distributed execution.
- Returns
PartitionManager instance configured with this config’s partition_axis. If partition_axis is None, creates a default PartitionAxis() first.
Example
>>> config = EasyDeLBaseConfig()
>>> pm = config.partition_manager
>>> # Use the partition manager to resolve sharding specs
>>> spec = pm.resolve(axes=["dp", "tp"], mode="train", shape=(8, 1024))
- read_basics_from_config(config: EasyDeLBaseConfig)
Reads and applies basic configuration attributes from another config instance.
Copies EasyDeL-specific attributes like sharding, attention mechanism, quantization settings, etc. from the provided config.
- Parameters
config – Source configuration to read attributes from.
- save_pretrained(save_directory: str | os.PathLike | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath, push_to_hub: bool = False, **kwargs)
Save a configuration object to the directory save_directory, so that it can be re-loaded using the [~PretrainedConfig.from_pretrained] class method.
- Parameters
save_directory (str or os.PathLike) – Directory where the configuration JSON file will be saved (will be created if it does not exist).
push_to_hub (bool, optional, defaults to False) – Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the repository you want to push to with repo_id (will default to the name of save_directory in your namespace).
kwargs (Dict[str, Any], optional) – Additional key word arguments passed along to the [~utils.PushToHubMixin.push_to_hub] method.
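HuggingFace's PretrainedConfig.save_pretrained writes the configuration as config.json inside save_directory. Assuming EasyDeL keeps that filename, the local (non-Hub) part of the behaviour can be sketched with the standard library. save_config_dict, load_config_dict, and roundtrip_demo are hypothetical helpers that operate on plain dicts. The sketch ignores remote paths such as GCSPath and the push_to_hub step.

```python
import json
import os
import tempfile
from typing import Any, Dict

CONFIG_NAME = "config.json"  # assumed filename, as in HuggingFace


def save_config_dict(config: Dict[str, Any], save_directory: str) -> str:
    """Write config to <save_directory>/config.json, creating the directory if needed."""
    os.makedirs(save_directory, exist_ok=True)
    path = os.path.join(save_directory, CONFIG_NAME)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, sort_keys=True)
    return path


def load_config_dict(save_directory: str) -> Dict[str, Any]:
    """Read the saved dict back (the first half of a from_pretrained-style load)."""
    with open(os.path.join(save_directory, CONFIG_NAME), encoding="utf-8") as f:
        return json.load(f)


def roundtrip_demo(config: Dict[str, Any]) -> Dict[str, Any]:
    """Save into a not-yet-existing nested directory, then load it back."""
    with tempfile.TemporaryDirectory() as tmp:
        target = os.path.join(tmp, "nested", "ckpt")
        save_config_dict(config, target)
        return load_config_dict(target)
```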
- set_model_mesh(mesh: Mesh)
Sets a custom mesh for the model, overriding the auto-generated one.
- Parameters
mesh – JAX device mesh to use for this model.
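The override follows a simple pattern: once a user-supplied mesh is set, it takes precedence over the auto-built one, as the mesh property notes. The sketch below shows that pattern with a hypothetical MeshHolder class and a plain callable standing in for mesh construction. In EasyDeL the stored object would be a jax.sharding.Mesh.

```python
from typing import Any, Callable, Optional


class MeshHolder:
    """Hypothetical stand-in for the mesh/set_model_mesh pair on a config."""

    def __init__(self, build_mesh: Callable[[], Any]) -> None:
        self._build_mesh = build_mesh  # builds a mesh from the sharding settings
        self._custom_mesh: Optional[Any] = None

    def set_model_mesh(self, mesh: Any) -> None:
        self._custom_mesh = mesh

    @property
    def mesh(self) -> Any:
        # A user-supplied mesh wins; otherwise fall back to the auto-built one.
        if self._custom_mesh is not None:
            return self._custom_mesh
        return self._build_mesh()
```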
- to_dict() → dict[str, Any]
Serialize config to a dictionary while temporarily hiding forbidden types.
- to_diff_dict() → dict[str, Any]
Removes all attributes from the configuration that correspond to the default config attributes for better readability, while always retaining the config attribute from the class. Serializes to a Python dictionary.
- Returns
Dictionary of all the attributes that make up this configuration instance.
- Return type
dict[str, Any]
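On plain dicts, the filtering can be sketched as a comprehension. A key is kept if it has no default, if its value differs from the default, or if it belongs to a set of always-retained keys. diff_dict is a hypothetical helper. Treating model_type as the always-kept class attribute is an assumption modeled on HuggingFace configs.

```python
from typing import Any, Dict, Iterable


def diff_dict(
    current: Dict[str, Any],
    defaults: Dict[str, Any],
    always_keep: Iterable[str] = ("model_type",),
) -> Dict[str, Any]:
    keep = set(always_keep)
    return {
        key: value
        for key, value in current.items()
        # Keep always-retained keys, keys without a default, and changed values.
        if key in keep or key not in defaults or defaults[key] != value
    }
```

This is also what use_diff=True in to_json_file controls: only the reduced dict is written.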
- to_json_file(json_file_path: str | os.PathLike | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath, use_diff: bool = True)
Save this instance to a JSON file.
- Parameters
json_file_path (str, os.PathLike, or an eformer.paths path object) – Path to the JSON file in which this configuration instance’s parameters will be saved.
use_diff (bool, optional, defaults to True) – If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
- class easydel.infra.base_config.EasyDeLBaseConfigDict
Bases: TypedDict
Base configuration shared across EasyDeL models.
Extends transformers.PretrainedConfig with distributed sharding metadata, attention kernel selection, quantization knobs, RoPE helpers, and hardware abstraction flags used for both training and serving.
- Parameters
sharding_axis_dims – Parallelism sizes for (dp, fsdp, ep, tp, sp). -1 consumes all remaining devices. Defaults to (1, -1, 1, 1, 1).
sharding_dcn_axis_dims – Optional mesh sizes for DCN slices when running multi-host or multi-slice setups.
sharding_axis_names – Logical mesh axis names. Defaults to ("dp", "fsdp", "ep", "tp", "sp").
attn_mechanism – Attention implementation to use during training/forward passes.
decode_attn_mechanism – Attention implementation to use during decoding (falls back to attn_mechanism if left as None).
blocksize_k – Key block size for attention kernels. Defaults to 128.
blocksize_q – Query block size for attention kernels. Defaults to 128.
blocksize_b – Batch/block size used by some attention backends. Defaults to 1.
moe_tiling_size_batch – Batch tiling used by MoE kernels. Defaults to 4.
moe_tiling_size_seqlen – Sequence length tiling for MoE kernels. Defaults to 128.
moe_tiling_size_dim – Hidden dimension tiling for MoE kernels. Defaults to 128.
partition_axis – PartitionAxis describing how logical axes map to the mesh.
use_sharded_kv_caching – Whether to shard the KV cache instead of replicating it.
use_sharding_constraint – Insert explicit sharding constraints during model build.
backend – Explicit JAX backend (falls back to jax.default_backend()).
platform – Platform hint for kernel selection (defaults to "triton" on GPU, otherwise "jax").
easy_method – Workflow context ("train", "serve", or "convert").
bits – Optional quantization bit width for weights.
scan_ring_attention – Use scanning for ring attention implementations.
scan_attention_layers – Apply scan to attention blocks to save memory.
use_scan_mlp – Apply scan to MLP blocks.
scan_mlp_chunk_size – Chunk size when scanning MLPs. Defaults to 1024.
sequence_axis_name – Name of the sequence/attention axis. Defaults to "sp".
gradient_checkpointing – Gradient checkpointing policy enum/string.
gradient_checkpointing_targets – Optional list of target names to include or exclude when using selective checkpointing policies.
precompute_masks – Whether to precompute and cache causal masks on the mesh.
kv_cache_quantization_config – Quantization config for KV cache tensors. Pass None to disable.
quantization_config – Quantization config for linear layers. Pass None to disable.
kv_cache_sharding_sequence_axis_name – Axis (or axes) used when sharding the KV cache.
flash_attention_backward_pass_impl – Backward kernel for flash attention ("triton" or "xla"). Defaults to "triton".
attn_dtype – Attention activation dtype. Defaults to jnp.bfloat16.
kvdtype – KV cache dtype. Defaults to attn_dtype when None.
attn_softmax_dtype – Softmax computation dtype. Defaults to jnp.float32.
fcm_max_ratio – Maximum ratio used when sampling forgetful causal masks.
fcm_min_ratio – Minimum ratio used when sampling forgetful causal masks.
hardware_abstraction – Enable EasyDeL hardware abstraction and custom kernels.
pallas_m_block_size – Matmul M dimension block size for Pallas kernels.
pallas_k_block_size – Matmul K dimension block size for Pallas kernels.
pallas_n_block_size – Matmul N dimension block size for Pallas kernels.
moe_method – Mixture-of-experts implementation to use.
moe_force_xla_gmm – Force XLA GMM kernels for MoE even when fused kernels exist.
use_ring_of_experts – Whether to dispatch experts with a ring topology.
use_expert_tensor_mode – Treat experts as an additional tensor-parallel axis.
fsdp_is_ep_bound – Fold the FSDP axis into the expert axis when building expert meshes.
sp_is_ep_bound – Fold the sequence-parallel axis into the expert axis when building expert meshes.
**kwargs – Forwarded to PretrainedConfig.
- Raises
UserWarning – If KV-cache quantization is requested together with sharded KV caching.
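Two of the documented defaults are fallbacks rather than constants: decode_attn_mechanism falls back to attn_mechanism, and kvdtype falls back to attn_dtype. One combination of options also triggers a UserWarning. The sketch below models those rules on a partial TypedDict. PartialConfigKwargs and resolve_fallbacks are hypothetical names. Dtypes are plain strings instead of jnp dtypes to keep it dependency-free, and the real resolution happens inside EasyDeLBaseConfig.

```python
import warnings
from typing import Any, Dict, Optional, TypedDict


class PartialConfigKwargs(TypedDict, total=False):
    """Hypothetical subset of the keys documented for EasyDeLBaseConfigDict."""

    attn_mechanism: str
    decode_attn_mechanism: Optional[str]
    attn_dtype: str
    kvdtype: Optional[str]
    use_sharded_kv_caching: bool
    kv_cache_quantization_config: Any


def resolve_fallbacks(kwargs: PartialConfigKwargs) -> Dict[str, Any]:
    # Start from the documented defaults (dtype as a string stand-in), then apply overrides.
    out: Dict[str, Any] = {"attn_mechanism": "vanilla", "attn_dtype": "bfloat16", **kwargs}
    if out.get("decode_attn_mechanism") is None:
        out["decode_attn_mechanism"] = out["attn_mechanism"]
    if out.get("kvdtype") is None:
        out["kvdtype"] = out["attn_dtype"]
    if out.get("kv_cache_quantization_config") is not None and out.get("use_sharded_kv_caching", False):
        warnings.warn(
            "KV-cache quantization requested together with sharded KV caching",
            UserWarning,
            stacklevel=2,
        )
    return out
```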
- class easydel.infra.base_config.EasyMethod(TRAIN: str = 'train', SERVE: str = 'serve', EVAL: str = 'serve', CONVERT: str = 'convert')
Bases: object
Constants for EasyDeL operation modes.
Defines the different modes in which EasyDeL models can operate.
- TRAIN
Training mode for model optimization.
- Type
str
- SERVE
Serving mode for inference.
- Type
str
- EVAL
Evaluation mode (alias for serve).
- Type
str
- CONVERT
Conversion mode for model format changes.
- Type
str
- CONVERT: str = 'convert'
- EVAL: str = 'serve'
- SERVE: str = 'serve'
- TRAIN: str = 'train'
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
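EasyMethod is a small constants container in which EVAL is an alias for 'serve'. The stand-in below mirrors those values with a frozen stdlib dataclass. It also approximates the to_dict and replace helpers using dataclasses.asdict and dataclasses.replace. EasyMethodSketch is hypothetical and is not built on eformer's PyTree base class.

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class EasyMethodSketch:
    """Hypothetical stand-in mirroring the EasyMethod constants."""

    TRAIN: str = "train"
    SERVE: str = "serve"
    EVAL: str = "serve"  # alias: evaluation reuses the serving mode
    CONVERT: str = "convert"


MODES = EasyMethodSketch()

# Distinct mode strings; these match the Literal accepted by easy_method.
VALID_EASY_METHODS = set(dataclasses.asdict(MODES).values())

# replace() returns a new frozen instance instead of mutating MODES.
renamed = dataclasses.replace(MODES, CONVERT="export")
```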
- easydel.infra.base_config.extract_commit_hash(resolved_file: str | None, commit_hash: str | None) → str | None
Extracts the git commit hash from a HuggingFace cache file path.
When models are downloaded from HuggingFace Hub, they’re cached locally with paths containing the commit hash in the format: …/snapshots/<commit_hash>/…. This function extracts that hash for tracking model versions.
- Parameters
resolved_file – Path to the resolved cache file. If None or if commit_hash is already provided, returns the existing commit_hash immediately.
commit_hash – Existing commit hash if already known. If provided, this function returns it without parsing the file path.
- Returns
The extracted commit hash string (40-character hex), or None if no file path is provided, the path doesn’t contain a snapshots directory, or the extracted string doesn’t match the git commit hash format.
- Return type
str | None
Example
>>> path = "/cache/snapshots/abc123def456.../model.safetensors"
>>> commit_hash = extract_commit_hash(path, None)
>>> # commit_hash = "abc123def456..." if valid
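The parsing steps can be sketched with a regular expression modeled on the equivalent helper in HuggingFace transformers. The sketch takes the path segment after snapshots/ and accepts it only if it is 40 lowercase hex characters. extract_commit_hash_sketch is a hypothetical name, and the EasyDeL function may differ in detail.

```python
import re
from pathlib import Path
from typing import Optional

_COMMIT_HASH_RE = re.compile(r"^[0-9a-f]{40}$")


def extract_commit_hash_sketch(resolved_file: Optional[str], commit_hash: Optional[str]) -> Optional[str]:
    # Nothing to parse, or the caller already knows the hash.
    if resolved_file is None or commit_hash is not None:
        return commit_hash
    posix = Path(resolved_file).as_posix()
    match = re.search(r"snapshots/([^/]+)/", posix)
    if match is None:
        return None  # not a HuggingFace cache snapshot path
    candidate = match.group(1)
    return candidate if _COMMIT_HASH_RE.match(candidate) else None
```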
- easydel.infra.base_config.set_attrs_smartly(self, attr_name: str, default: Any, new_attr: Any)
Sets attributes intelligently with default values and optional updates.
This helper function provides smart attribute management:
1. If the attribute doesn’t exist, sets it to the default value.
2. If new_attr is provided (not the NOT_GIVEN sentinel), updates the attribute.
This pattern allows configuration classes to have default values while supporting explicit overrides through constructor parameters.
- Parameters
self – The object to set the attribute on.
attr_name – Name of the attribute to set/update.
default – Default value to use if the attribute doesn’t exist yet.
new_attr – New value to set if provided. If equal to NOT_GIVEN sentinel, the existing value (or default) is preserved.
Example
>>> config = SomeConfig()
>>> set_attrs_smartly(config, "hidden_size", 768, 1024)
>>> # config.hidden_size = 1024 (updated)
>>>
>>> set_attrs_smartly(config, "num_layers", 12, NOT_GIVEN)
>>> # config.num_layers = 12 (default, since NOT_GIVEN)
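A minimal sketch of the two-step rule follows. A hypothetical sentinel object stands in for EasyDeL's NOT_GIVEN, and the sketch compares it by identity, which is an assumption. set_attrs_smartly_sketch is likewise a hypothetical name.

```python
from types import SimpleNamespace
from typing import Any


class _NotGiven:
    """Hypothetical sentinel type standing in for EasyDeL's NOT_GIVEN."""

    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN = _NotGiven()


def set_attrs_smartly_sketch(obj: Any, attr_name: str, default: Any, new_attr: Any) -> None:
    # Step 1: only fill in the default when the attribute is missing.
    if not hasattr(obj, attr_name):
        setattr(obj, attr_name, default)
    # Step 2: an explicitly provided value always overrides.
    if new_attr is not NOT_GIVEN:
        setattr(obj, attr_name, new_attr)


cfg = SimpleNamespace()
set_attrs_smartly_sketch(cfg, "hidden_size", 768, 1024)   # hidden_size -> 1024
set_attrs_smartly_sketch(cfg, "num_layers", 12, NOT_GIVEN)  # num_layers -> 12
set_attrs_smartly_sketch(cfg, "num_layers", 99, NOT_GIVEN)  # existing 12 is kept
```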