# Source code for easydel.layers.operations._operation_meta

from __future__ import annotations

import dataclasses
import enum
import typing as tp
from typing import NamedTuple

import jax
import jaxtyping
from eformer import common_types
from eformer.escale import PartitionAxis, PartitionManager
from eformer.loggings import get_logger
from eformer.pytree import auto_pytree
from jax import numpy as jnp

# The project/kernel types below are imported only for static type checkers.
# At runtime the `else` branch binds the same names to simple placeholders, so
# the names exist at module level without executing those imports.
# NOTE(review): presumably this avoids import cycles or heavy imports at module
# load time — confirm against the package layout.
if tp.TYPE_CHECKING:
    from ejkernel.modules.operations.configs import BaseOperationConfig

    from easydel.infra.base_config import EasyDeLBaseConfig
    from easydel.infra.etils import EasyDeLBackends, EasyDeLPlatforms
else:
    # `enum.Enum | str` is evaluated eagerly here, so it relies on the runtime
    # supporting the `X | Y` union operator on classes.
    EasyDeLPlatforms = enum.Enum | str
    EasyDeLBackends = enum.Enum | str
    EasyDeLBaseConfig = object
    BaseOperationConfig = object

# Module-level logger obtained from eformer's logging helper.
logger = get_logger("EasyDeL-OperationOperator")

# Local aliases for names defined in `eformer.common_types`.
# `NOT_GIVEN` is used in this module as the "unset" marker for
# `OperationMetadata` fields; `RUNTIME_MODE_TYPES` annotates the `mode`
# argument of `get_shardings`; the remaining names are passed as entries of the
# `axes` list given to `PartitionManager.resolve`.
NOT_GIVEN = common_types.NOT_GIVEN
RUNTIME_MODE_TYPES = common_types.RUNTIME_MODE_TYPES
BATCH = common_types.BATCH
QUERY_LENGTH = common_types.QUERY_LENGTH
KV_LENGTH = common_types.KV_LENGTH
HEAD = common_types.HEAD
KV_HEAD = common_types.KV_HEAD
HEAD_DIM = common_types.HEAD_DIM
KV_HEAD_DIM = common_types.KV_HEAD_DIM
BIAS_HEAD_SEQ = common_types.BIAS_HEAD_SEQ
BIAS_KV_SEQ = common_types.BIAS_KV_SEQ
EMPTY = common_types.EMPTY


class AttnShardingRules(NamedTuple):
    """
    Named tuple containing JAX PartitionSpecs for all attention tensors.

    Field order is part of the interface: instances can be unpacked or
    indexed positionally.

    Attributes:
        query3d: Sharding for a 3d query tensor which is [b, h, d].
        query: Sharding for query tensor.
        key: Sharding for key tensor.
        value: Sharding for value tensor.
        bias: Sharding for attention bias tensor.
        mask: Sharding for attention mask tensor.
        output: Sharding for attention output tensor.
        q_segment_ids: Sharding for query segment IDs (for packed sequences).
        kv_segment_ids: Sharding for key/value segment IDs (for packed sequences).
        softmax_aux: Optional sharding for softmax auxiliary outputs
            (e.g., LSE, max); ``None`` when no such sharding was requested.
    """

    query3d: jax.sharding.PartitionSpec
    query: jax.sharding.PartitionSpec
    key: jax.sharding.PartitionSpec
    value: jax.sharding.PartitionSpec
    bias: jax.sharding.PartitionSpec
    mask: jax.sharding.PartitionSpec
    output: jax.sharding.PartitionSpec
    q_segment_ids: jax.sharding.PartitionSpec
    kv_segment_ids: jax.sharding.PartitionSpec
    softmax_aux: jax.sharding.PartitionSpec | None
@auto_pytree
class OperationMetadata:
    """
    Holds configuration, context, and metadata for attention operations.

    This class centralizes the parameters needed by different attention
    implementations. Fields left as ``NOT_GIVEN`` are filled in by
    ``__post_init__``, optionally sourcing values from an
    ``EasyDeLBaseConfig``.

    Attributes:
        runtime_dtype: The primary JAX dtype for computations. If passed as
            ``NOT_GIVEN`` it is read from ``base_config.attn_dtype`` and
            otherwise defaults to ``jnp.float32``.
        runtime_softmax_dtype: Optional JAX dtype for the softmax computation.
        sequence_axis_name: Name of the sequence axis. Defaults to ``"sp"``
            and is deliberately never read from ``base_config``.
        platform: The target hardware platform. Read from
            ``base_config.platform`` when unset; there is no other default,
            so without a ``base_config`` it must be passed explicitly or
            initialization fails with ``ValueError``.
        backend: The JAX backend. Defaults to ``base_config.backend`` or
            ``jax.default_backend()``; a ``None`` value is resolved to the
            matching ``EasyDeLBackends`` member.
        partition_axis: Partition axis configuration. Defaults to
            ``PartitionAxis()`` or ``base_config.partition_axis``.
        partition_manager: Manager used to resolve logical axes into
            ``PartitionSpec`` objects. Defaults to
            ``PartitionManager(self.partition_axis)`` or
            ``base_config.partition_manager``.
        base_config: Optional base model configuration used to source
            defaults and, when present, the mesh.
        operation_configs: Optional mapping from operation implementation
            name to its config (see ``get_operation_config``).
        _stored_mesh: Mesh used when ``base_config`` is not available. Access
            it through the ``mesh`` property.
    """

    runtime_dtype: jax.typing.DTypeLike
    runtime_softmax_dtype: jax.typing.DTypeLike | None = None
    sequence_axis_name: str = NOT_GIVEN
    platform: EasyDeLPlatforms = NOT_GIVEN
    backend: EasyDeLBackends = NOT_GIVEN
    partition_axis: PartitionAxis = NOT_GIVEN
    partition_manager: PartitionManager = NOT_GIVEN
    base_config: EasyDeLBaseConfig | None = None
    operation_configs: dict[str, BaseOperationConfig] | None = None
    _stored_mesh: jax.sharding.Mesh | None = NOT_GIVEN

    def __post_init__(self) -> None:
        """
        Fill in unset fields and validate the result.

        Every field still equal to ``NOT_GIVEN`` receives a default, taken
        from ``base_config`` when one is available. Without a ``base_config``
        and without an explicit mesh, the mesh is taken from the active JAX
        mesh context. Finally a ``None`` backend is resolved to an
        ``EasyDeLBackends`` member.

        Raises:
            AssertionError: If no mesh was provided, there is no
                ``base_config``, and no mesh context is active.
            ValueError: If any field is still ``NOT_GIVEN`` after defaults
                were applied.
            AttributeError: If a ``None`` backend cannot be mapped to an
                ``EasyDeLBackends`` member.
        """
        from easydel.infra.etils import EasyDeLBackends

        # fmt:off
        self.set_attrs_carefully("runtime_dtype", jnp.float32, "attn_dtype")
        self.set_attrs_carefully("runtime_softmax_dtype", jnp.float32, "attn_softmax_dtype")
        self.set_attrs_carefully("partition_axis", PartitionAxis())
        self.set_attrs_carefully("partition_manager", PartitionManager(self.partition_axis))
        # DON'T READ FROM CONFIG
        self.set_attrs_carefully("sequence_axis_name", "sp", "sequence_axis_name", use_base_config=False)
        self.set_attrs_carefully("backend", jax.default_backend(), "backend")
        self.set_attrs_carefully("platform", NOT_GIVEN, "platform")
        self.set_attrs_carefully("_stored_mesh", NOT_GIVEN, "mesh")
        self.set_attrs_carefully("operation_configs", None, "operation_configs")
        # fmt:on

        if self._stored_mesh is NOT_GIVEN and self.base_config is None:
            # No explicit mesh and no config to ask: use the mesh context.
            mesh: jax.sharding.Mesh = jax.interpreters.pxla.thread_resources.env.physical_mesh
            assert not mesh.empty, (
                "You should pass 'mesh' to `OperationMetadata` or at least create that under mesh context manager"
            )
            self._stored_mesh = mesh

        self._safety_check()

        if self.backend is None:
            current_backend: str = jax.default_backend()
            # Try the name as reported by JAX first and only then its
            # upper-case form. The fallback is looked up lazily so that a
            # missing upper-case member cannot mask a successful first lookup.
            try:
                backend_enum = getattr(EasyDeLBackends, current_backend)
            except AttributeError:
                backend_enum = getattr(EasyDeLBackends, current_backend.upper())
            self.backend = backend_enum

    def _safety_check(self) -> None:
        """Ensures no essential attributes are left uninitialized (as NOT_GIVEN).

        Raises:
            ValueError: If any dataclass field of this instance is still
                ``NOT_GIVEN``.
        """
        for field in dataclasses.fields(self):
            if getattr(self, field.name) is NOT_GIVEN:
                raise ValueError(f"`{field.name}` shouldn't be ellipsis")

    @classmethod
    def from_config(cls, config: EasyDeLBaseConfig) -> OperationMetadata:
        """
        Factory method to create OperationMetadata from an EasyDeLBaseConfig.

        Args:
            config: The base configuration object (e.g., model config). Its
                ``attn_dtype``, ``attn_softmax_dtype``, ``sequence_axis_name``,
                ``platform``, ``backend`` and ``partition_axis`` attributes
                are read directly; ``operation_configs`` is optional and
                defaults to ``None`` when the attribute is missing.

        Returns:
            An initialized OperationMetadata instance that keeps ``config``
            as its ``base_config``.
        """
        return cls(
            runtime_dtype=config.attn_dtype,
            runtime_softmax_dtype=config.attn_softmax_dtype,
            sequence_axis_name=config.sequence_axis_name,
            platform=config.platform,
            backend=config.backend,
            partition_axis=config.partition_axis,
            base_config=config,
            operation_configs=getattr(config, "operation_configs", None),
        )

    @property
    def mesh(self) -> jax.sharding.Mesh | None:
        """Get current mesh from base_config if available, otherwise return stored mesh."""
        if self.base_config is not None:
            return self.base_config.mesh
        return self._stored_mesh

    @mesh.setter
    def mesh(self, value: jax.sharding.Mesh | None):
        """Set mesh value for cases where base_config is not available.

        Note that while ``base_config`` is set, the getter keeps returning
        ``base_config.mesh`` and the value stored here is not observed.
        """
        self._stored_mesh = value

    def get_shardings(
        self,
        mode: RUNTIME_MODE_TYPES,  # type:ignore
        layout: tp.Literal["bthd", "bhtd", "thd"] = "bthd",
        qkv_mni_sharding: bool = False,
        softmax_aux: jaxtyping.Array | None = None,
    ) -> AttnShardingRules:
        """
        Generates JAX PartitionSpecs for attention tensors based on runtime mode.

        Args:
            mode: Runtime mode forwarded to ``PartitionManager.resolve``.
            layout: Tensor layout format - "bthd" (batch, time, heads, dim) or
                "bhtd" (batch, heads, time, dim). "thd" is accepted by the
                annotation but is not implemented.
            qkv_mni_sharding: If True, use HEAD/HEAD_DIM for K/V instead of
                KV_HEAD/KV_HEAD_DIM.
            softmax_aux: Optional softmax auxiliary array (e.g., LSE, max).
                When given, a sharding is produced for it: axes
                ``[EMPTY, KV_HEAD]`` if it is 2D, ``[HEAD]`` otherwise. When
                ``None``, the ``softmax_aux`` rule is ``None``.

        Returns:
            AttnShardingRules: Named tuple containing PartitionSpecs for all
            attention tensors:
                - query3d: 3D query tensor ([BATCH, HEAD, HEAD_DIM])
                - query, key, value: Main attention tensors
                - bias: Attention bias tensor
                - mask: Attention mask tensor
                - output: Attention output tensor (same spec as ``query``)
                - q_segment_ids: Query segment IDs (for packed sequences)
                - kv_segment_ids: Key/value segment IDs (for packed sequences)
                - softmax_aux: Optional softmax auxiliary output sharding

        Raises:
            NotImplementedError: If ``layout`` is neither "bthd" nor "bhtd".
        """
        kv_head_axis = HEAD if qkv_mni_sharding else KV_HEAD
        kv_dim_axis = HEAD_DIM if qkv_mni_sharding else KV_HEAD_DIM

        # Pick the axis order first so an unsupported layout fails before any
        # partition resolution happens.
        if layout == "bthd":
            query_axes = [BATCH, QUERY_LENGTH, HEAD, HEAD_DIM]
            kv_axes = [BATCH, KV_LENGTH, kv_head_axis, kv_dim_axis]
        elif layout == "bhtd":
            query_axes = [BATCH, HEAD, QUERY_LENGTH, HEAD_DIM]
            kv_axes = [BATCH, kv_head_axis, KV_LENGTH, kv_dim_axis]
        else:
            raise NotImplementedError(f"Layout '{layout}' is not implemented")

        manager: PartitionManager = self.partition_manager

        def _resolve(axes: list) -> jax.sharding.PartitionSpec:
            # Pass a fresh list each time so a shared axes list is never
            # handed to the manager twice.
            return manager.resolve(axes=list(axes), mode=mode)

        q_sharding = _resolve(query_axes)
        k_sharding = _resolve(kv_axes)
        v_sharding = _resolve(kv_axes)

        # Segment IDs have no head/dim axes, so they are layout independent.
        q_segment_ids_sharding = _resolve([BATCH, QUERY_LENGTH])
        kv_segment_ids_sharding = _resolve([BATCH, KV_LENGTH])

        # Bias and mask share their trailing (query, key/value) axes.
        b_sharding = _resolve([BATCH, BIAS_HEAD_SEQ, QUERY_LENGTH, BIAS_KV_SEQ])
        m_sharding = _resolve([BATCH, None, QUERY_LENGTH, BIAS_KV_SEQ])

        softmax_aux_sharding: jax.sharding.PartitionSpec | None = None
        if softmax_aux is not None:
            if softmax_aux.ndim == 2:
                softmax_aux_sharding = _resolve([EMPTY, KV_HEAD])
            else:
                softmax_aux_sharding = _resolve([HEAD])

        query3d_sharding = _resolve([BATCH, HEAD, HEAD_DIM])

        return AttnShardingRules(
            query3d=query3d_sharding,
            query=q_sharding,
            key=k_sharding,
            value=v_sharding,
            bias=b_sharding,
            mask=m_sharding,
            output=q_sharding,
            q_segment_ids=q_segment_ids_sharding,
            kv_segment_ids=kv_segment_ids_sharding,
            softmax_aux=softmax_aux_sharding,
        )

    def set_attrs_carefully(
        self,
        attr_name: str,
        default: tp.Any | None,
        pickup_name: str | None = None,
        use_base_config: bool = True,
    ) -> None:
        """
        Internal helper to set an attribute only if it is missing or ``NOT_GIVEN``.

        Optionally retrieves the value from `self.base_config` using
        `pickup_name` (or `attr_name` if `pickup_name` is None).

        Args:
            attr_name: The name of the attribute to set on `self`.
            default: The value to use if `base_config` is None, if
                `use_base_config` is False, or if `base_config` has no
                matching attribute.
            pickup_name: The name of the attribute to look for in
                `base_config`. Defaults to `attr_name`.
            use_base_config: Whether to attempt retrieving the value from
                `base_config`.
        """
        # A missing attribute reads as NOT_GIVEN, so one check covers both the
        # "absent" and the "explicitly unset" cases.
        if getattr(self, attr_name, NOT_GIVEN) is not NOT_GIVEN:
            return

        if self.base_config is None or not use_base_config:
            new_value = default
        else:
            source_name = attr_name if pickup_name is None else pickup_name
            new_value = getattr(self.base_config, source_name, default)
        setattr(self, attr_name, new_value)

    def get_operation_config(self, impl_name: str) -> "BaseOperationConfig | None":
        """Get ejkernel config for a specific operation by its registered name.

        Args:
            impl_name: The operation implementation name (must match OperationRegistry).
                Valid names:
                - "flash_attn2": Flash attention 2 implementation
                - "ring": Ring attention
                - "blocksparse": Block sparse attention
                - "ragged_page_attention_v2": Ragged page attention v2
                - "ragged_page_attention_v3": Ragged page attention v3
                - "sdpa": Scaled dot product attention
                - "vanilla": Vanilla attention

        Returns:
            The operation config if set, otherwise None (which enables ejkernel autotune).

        Example:
            >>> cfg = metadata.get_operation_config("flash_attn2")
            >>> if cfg is not None:
            ...     # Use explicit config
            ...     flash_attention(..., cfg=cfg)
            >>> else:
            ...     # Use autotune
            ...     flash_attention(..., cfg=None)
        """
        if self.operation_configs is None:
            return None
        return self.operation_configs.get(impl_name)