easydel.layers.operations._operation_meta
- class easydel.layers.operations._operation_meta.AttnShardingRules(query3d: jax.sharding.PartitionSpec, query: jax.sharding.PartitionSpec, key: jax.sharding.PartitionSpec, value: jax.sharding.PartitionSpec, bias: jax.sharding.PartitionSpec, mask: jax.sharding.PartitionSpec, output: jax.sharding.PartitionSpec, q_segment_ids: jax.sharding.PartitionSpec, kv_segment_ids: jax.sharding.PartitionSpec, softmax_aux: jax.sharding.PartitionSpec | None)
Bases: NamedTuple
Named tuple containing JAX PartitionSpecs for all attention tensors.
- query3d
Sharding for a 3D query tensor of shape [b, h, d] (batch, heads, head_dim).
- query
Sharding for query tensor.
- key
Sharding for key tensor.
- value
Sharding for value tensor.
- bias
Sharding for attention bias tensor.
- mask
Sharding for attention mask tensor.
- output
Sharding for attention output tensor.
- q_segment_ids
Sharding for query segment IDs (for packed sequences).
- kv_segment_ids
Sharding for key/value segment IDs (for packed sequences).
- softmax_aux
Optional sharding for 2D softmax auxiliary outputs (e.g., LSE, max).
- Type
jax.sharding.PartitionSpec | None
- bias: PartitionSpec
Alias for field number 4
- key: PartitionSpec
Alias for field number 2
- kv_segment_ids: PartitionSpec
Alias for field number 8
- mask: PartitionSpec
Alias for field number 5
- output: PartitionSpec
Alias for field number 6
- q_segment_ids: PartitionSpec
Alias for field number 7
- query: PartitionSpec
Alias for field number 1
- query3d: PartitionSpec
Alias for field number 0
- softmax_aux: jax.sharding.PartitionSpec | None
Alias for field number 9
- value: PartitionSpec
Alias for field number 3
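Because AttnShardingRules is a NamedTuple, each field can be read either by name or by the field number listed above. The sketch below uses a hypothetical stand-in, `ShardingRulesSketch`, with the same field order; plain tuples of assumed mesh-axis names ("dp", "sp", "tp") stand in for `jax.sharding.PartitionSpec` so it runs without JAX. It is an illustration of NamedTuple behavior, not the library class.

```python
from typing import NamedTuple, Optional, Tuple

# Stand-in for jax.sharding.PartitionSpec: a tuple of mesh-axis names (or None).
Spec = Tuple[Optional[str], ...]


class ShardingRulesSketch(NamedTuple):
    """Hypothetical mirror of AttnShardingRules' field order (not the real class)."""

    query3d: Spec
    query: Spec
    key: Spec
    value: Spec
    bias: Spec
    mask: Spec
    output: Spec
    q_segment_ids: Spec
    kv_segment_ids: Spec
    softmax_aux: Optional[Spec]


qkv = ("dp", "sp", "tp", None)  # assumed [batch, seq, heads, head_dim] axes
rules = ShardingRulesSketch(
    query3d=("dp", "tp", None),
    query=qkv,
    key=qkv,
    value=qkv,
    bias=("dp", "tp", "sp", None),
    mask=("dp", None, "sp", None),
    output=qkv,
    q_segment_ids=("dp", "sp"),
    kv_segment_ids=("dp", "sp"),
    softmax_aux=None,
)

# Field names alias tuple positions ("Alias for field number N").
bias_by_index = rules[4]

# NamedTuples are immutable; _replace returns a modified copy,
# e.g. attaching a [batch, heads] spec for softmax auxiliary outputs.
with_aux = rules._replace(softmax_aux=("dp", "tp"))
```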
- class easydel.layers.operations._operation_meta.OperationMetadata(runtime_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType], runtime_softmax_dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = None, sequence_axis_name: str = <eformer.common_types._Empty object>, platform: enum.Enum | str = <eformer.common_types._Empty object>, backend: enum.Enum | str = <eformer.common_types._Empty object>, partition_axis: eformer.escale.partition.manager.PartitionAxis = <eformer.common_types._Empty object>, partition_manager: eformer.escale.partition.manager.PartitionManager = <eformer.common_types._Empty object>, base_config: object | None = None, operation_configs: dict[str, object] | None = None, _stored_mesh: jax._src.mesh.Mesh | None = <eformer.common_types._Empty object>)
Bases: object
Holds configuration, context, and metadata for attention operations.
This class centralizes the parameters that the different attention implementations need, so that they are configured consistently. It resolves default values and can be initialized from an EasyDeLBaseConfig.
- runtime_dtype
The primary JAX dtype for computations (e.g., q, k, v).
- Type
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]
- runtime_softmax_dtype
Optional JAX dtype for the softmax computation, allowing for higher precision if needed (e.g., float32).
- Type
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]]
- sequence_axis_name
The name used for the sequence axis in JAX parallelism (sharding_axis_names for pjit).
- Type
str
- mesh
The JAX device mesh for distributed computation. Exposed as a property: it is read from base_config when available; otherwise the stored mesh is returned.
- platform
The target hardware platform (e.g., TPU, GPU).
- Type
enum.Enum | str
- backend
The specific JAX backend being used (e.g., TPU, CUDA, ROCM).
- Type
enum.Enum | str
- partition_axis
Configuration for partitioning axes in distributed settings (an eformer.escale PartitionAxis).
- Type
eformer.escale.partition.manager.PartitionAxis
- base_config
An optional reference to the base model configuration object for sourcing default values.
- Type
object | None
- scan_ring_attention
Boolean flag indicating whether to use ring attention via jax.lax.scan.
- softmax_scale
The scaling factor applied before the softmax operation. Often 1 / sqrt(head_dim).
- dropout_prob
The dropout probability applied to attention weights.
- blocksize_q
Block size for the query sequence dimension in blockwise attention.
- blocksize_k
Block size for the key/value sequence dimension in blockwise attention.
- blocksize_b
Block size for the batch dimension in blockwise attention (often 1).
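The softmax_scale and runtime_softmax_dtype attributes both feed the attention softmax. As a plain-Python illustration (not EasyDeL's kernel code), the sketch below computes the conventional default scale 1 / sqrt(head_dim) and applies it to one query's logits before a numerically stable softmax. The helper names `default_softmax_scale` and `scaled_softmax` are hypothetical, and Python floats (float64) stand in for whatever precision runtime_softmax_dtype selects.

```python
import math
from typing import List, Optional


def default_softmax_scale(head_dim: int, softmax_scale: Optional[float] = None) -> float:
    """Return an explicit softmax_scale if given, else the usual 1 / sqrt(head_dim)."""
    return softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)


def scaled_softmax(q: List[float], keys: List[List[float]], scale: float) -> List[float]:
    """Attention weights for one query: softmax(scale * q . k) over all keys."""
    logits = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    peak = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


scale = default_softmax_scale(head_dim=64)
weights = scaled_softmax([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], scale=1.0)
```

For head_dim = 64 the default scale is 1/8 = 0.125.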
- backend: enum.Enum | str = <eformer.common_types._Empty object>
- classmethod from_config(config: object) → OperationMetadata
Factory method to create OperationMetadata from an EasyDeLBaseConfig.
- Parameters
config – The base configuration object (e.g., model config).
- Returns
An initialized OperationMetadata instance.
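A factory like from_config typically reads each field from the config object and falls back to a default when the attribute is absent. The sketch below shows that pattern on a hypothetical, trimmed-down `MetadataSketch` dataclass; the config attribute names it reads (`attn_dtype`, `attn_softmax_dtype`, `sequence_axis_name`) and the fallback values are assumptions for illustration, not the fields the real method is guaranteed to consult.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MetadataSketch:
    """Hypothetical, simplified stand-in for OperationMetadata (not the real class)."""

    runtime_dtype: str
    runtime_softmax_dtype: Optional[str] = None
    sequence_axis_name: str = "sp"
    base_config: Optional[Any] = None

    @classmethod
    def from_config(cls, config: Any) -> "MetadataSketch":
        # Read each field from the config when present, else use a default.
        # The attribute names below are assumptions for this sketch.
        return cls(
            runtime_dtype=getattr(config, "attn_dtype", "float32"),
            runtime_softmax_dtype=getattr(config, "attn_softmax_dtype", None),
            sequence_axis_name=getattr(config, "sequence_axis_name", "sp"),
            base_config=config,
        )


class ToyConfig:
    """Hypothetical model config exposing two of the assumed attributes."""

    attn_dtype = "bfloat16"
    attn_softmax_dtype = "float32"


meta = MetadataSketch.from_config(ToyConfig())
```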
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- get_operation_config(impl_name: str) → object | None
Get the ejkernel config for a specific operation by its registered name.
- Parameters
impl_name – The operation implementation name (must match OperationRegistry). Valid names:
  - "flash_attn2": Flash attention 2 implementation
  - "ring": Ring attention
  - "blocksparse": Block sparse attention
  - "ragged_page_attention_v2": Ragged page attention v2
  - "ragged_page_attention_v3": Ragged page attention v3
  - "sdpa": Scaled dot product attention
  - "vanilla": Vanilla attention
- Returns
The operation config if set, otherwise None (which enables ejkernel autotune).
Example
>>> cfg = metadata.get_operation_config("flash_attn2")
>>> if cfg is not None:
...     # Use explicit config
...     flash_attention(..., cfg=cfg)
... else:
...     # Use autotune
...     flash_attention(..., cfg=None)
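The lookup contract described above (an explicit config when one is registered, otherwise None so ejkernel can autotune) can be sketched as a dictionary lookup over the operation_configs mapping that the constructor accepts. `lookup_operation_config` is a hypothetical helper, and rejecting unknown names with KeyError is an assumption of this sketch; the documented method may handle unknown names differently.

```python
from typing import Dict, Optional

# Registry names listed as valid for impl_name.
VALID_IMPLS = frozenset({
    "flash_attn2",
    "ring",
    "blocksparse",
    "ragged_page_attention_v2",
    "ragged_page_attention_v3",
    "sdpa",
    "vanilla",
})


def lookup_operation_config(
    operation_configs: Optional[Dict[str, object]], impl_name: str
) -> Optional[object]:
    """Return the explicit config for impl_name, or None to signal autotuning."""
    if impl_name not in VALID_IMPLS:
        # Assumption: this sketch rejects names that are not registered.
        raise KeyError(f"unknown operation: {impl_name!r}")
    if not operation_configs:
        return None  # nothing configured at all: autotune everything
    return operation_configs.get(impl_name)  # None when this op is not configured


configs = {"flash_attn2": {"block_q": 128}}  # hypothetical config payload
explicit = lookup_operation_config(configs, "flash_attn2")
autotune = lookup_operation_config(configs, "ring")
```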
- get_shardings(mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], layout: Literal['bthd', 'bhtd', 'thd'] = 'bthd', qkv_mni_sharding: bool = False, softmax_aux: jax.Array | None = None) → AttnShardingRules
Generates JAX PartitionSpecs for attention tensors based on the runtime mode.
- Parameters
mode – Runtime mode used for partition resolution: '__train__', '__prefill__', '__autoregressive__', or '__insert__'.
layout – Tensor layout format: "bthd" (batch, time, heads, dim), "bhtd" (batch, heads, time, dim), or "thd" (time, heads, dim).
qkv_mni_sharding – If True, shard K/V with HEAD/HEAD_DIM instead of KV_HEAD/KV_HEAD_DIM. Useful when K/V have as many heads as Q (multi-head attention, MHA) rather than fewer (grouped-query or multi-query attention, GQA/MQA).
softmax_aux – Optional softmax auxiliary array. If provided, a sharding is created for 2D softmax auxiliary outputs (e.g., log-sum-exp, max values) with shape [batch, num_heads].
- Returns
- Named tuple containing PartitionSpecs for all attention tensors:
query3d: 3D query tensor [b, h, d]
query, key, value: Main attention tensors
bias: Attention bias tensor
mask: Attention mask tensor
output: Attention output tensor
q_segment_ids: Query segment IDs (for packed sequences)
kv_segment_ids: Key/value segment IDs (for packed sequences)
softmax_aux: Optional 2D softmax auxiliary output sharding
- Return type
AttnShardingRules
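One way to picture the layout argument is as an ordering of logical axes inside each PartitionSpec. The sketch below assumes a hypothetical mapping from logical axes (b, t, h, d) to mesh axis names ("dp", "sp", "tp") and returns a plain tuple instead of a jax.sharding.PartitionSpec so it runs without JAX. The real get_shardings also takes mode and the partition manager into account, which this sketch ignores.

```python
from typing import Dict, Optional, Tuple

# Hypothetical logical-axis -> mesh-axis mapping; real names come from PartitionAxis.
AXIS_TO_MESH: Dict[str, Optional[str]] = {
    "b": "dp",  # batch
    "t": "sp",  # time / sequence
    "h": "tp",  # heads
    "d": None,  # head_dim, left unsharded in this sketch
}


def spec_for_layout(layout: str) -> Tuple[Optional[str], ...]:
    """Order mesh axes to match a layout string: 'bthd', 'bhtd' or 'thd'."""
    if layout not in ("bthd", "bhtd", "thd"):
        raise ValueError(f"unsupported layout: {layout!r}")
    return tuple(AXIS_TO_MESH[axis] for axis in layout)


bthd_spec = spec_for_layout("bthd")
bhtd_spec = spec_for_layout("bhtd")
thd_spec = spec_for_layout("thd")
```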
- property mesh: jax._src.mesh.Mesh | None
Get current mesh from base_config if available, otherwise return stored mesh.
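The fallback order of the mesh property (base_config first, then the stored mesh) can be sketched as below. `MeshHolderSketch` is hypothetical; strings stand in for jax.sharding.Mesh objects, and None stands in for the _Empty sentinel the real class uses as its default.

```python
from typing import Any, Optional


class MeshHolderSketch:
    """Hypothetical illustration of a base_config-first mesh lookup."""

    def __init__(self, base_config: Optional[Any] = None, stored_mesh: Optional[Any] = None) -> None:
        self.base_config = base_config
        self._stored_mesh = stored_mesh

    @property
    def mesh(self) -> Optional[Any]:
        # Prefer a mesh attached to base_config; otherwise use the stored one.
        config_mesh = getattr(self.base_config, "mesh", None)
        return config_mesh if config_mesh is not None else self._stored_mesh


class _ConfigWithMesh:
    mesh = "config-mesh"  # stand-in for a jax.sharding.Mesh


holder = MeshHolderSketch(base_config=_ConfigWithMesh(), stored_mesh="stored-mesh")
fallback = MeshHolderSketch(stored_mesh="stored-mesh")
```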
- partition_axis: PartitionAxis = <eformer.common_types._Empty object>
- partition_manager: PartitionManager = <eformer.common_types._Empty object>
- platform: enum.Enum | str = <eformer.common_types._Empty object>
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- sequence_axis_name: str = <eformer.common_types._Empty object>
- set_attrs_carefully(attr_name: str, default: Any | None, pickup_name: str | None = None, use_base_config: bool = True) → None
Internal helper that sets an attribute only if it is not already set (or currently holds Ellipsis).
Optionally retrieves the value from self.base_config using pickup_name (or attr_name if pickup_name is None).
- Parameters
attr_name – The name of the attribute to set on self.
default – The default value to use if not found in base_config or if use_base_config is False.
pickup_name – The name of the attribute to look for in base_config. Defaults to attr_name.
use_base_config – Whether to attempt retrieving the value from base_config.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
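Together, replace, to_dict/from_dict and to_json/from_json give an immutable-update and serialization round trip. The sketch below reproduces that round trip with the standard library on a hypothetical `SerializableSketch` dataclass that holds only JSON-friendly fields; real OperationMetadata fields such as dtypes or meshes may need custom handling that this sketch does not attempt.

```python
import dataclasses
import json
from typing import Any, Dict, Optional


@dataclasses.dataclass(frozen=True)
class SerializableSketch:
    """Hypothetical stand-in showing replace/to_dict/to_json round trips."""

    runtime_dtype: str
    sequence_axis_name: str = "sp"
    softmax_scale: Optional[float] = None

    def replace(self, **kwargs: Any) -> "SerializableSketch":
        # Returns a new instance; the original is left unchanged.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SerializableSketch":
        return cls(**data)

    def to_json(self, **kwargs: Any) -> str:
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "SerializableSketch":
        return cls.from_dict(json.loads(json_str))


base = SerializableSketch(runtime_dtype="bfloat16")
tuned = base.replace(softmax_scale=0.125)
restored = SerializableSketch.from_json(tuned.to_json())
```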