easydel.layers.operations._operation_impl

Core attention operator implementation framework for EasyDeL.

This module provides the foundational classes and abstractions for implementing various attention mechanisms in JAX. It includes:

  • OperationOutput: Container for attention computation results

  • OperationMetadata: Configuration and runtime metadata for attention operations

  • OperationImpl: Abstract base class for specific attention implementations

  • OperationRegistry: Plugin system for discovering and managing attention implementations

The module supports multiple attention backends (TPU, GPU, CPU) and provides common utilities for mask handling, head repetition (for GQA/MQA), and sharding specifications for distributed computation.

Key Design Principles:

  1. Backend-agnostic interface with backend-specific optimizations

  2. Support for various attention patterns (vanilla, flash, ring, etc.)

  3. Efficient handling of different tensor layouts (BTHD vs BHTD)

  4. Integration with JAX’s sharding and parallelism features

  5. Flexible metadata system for runtime configuration
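The layout principle in the list above distinguishes BTHD (batch, seq_len, num_heads, head_dim) from BHTD (batch, num_heads, seq_len, head_dim). As a minimal sketch in plain jax.numpy (the helper names are hypothetical, not EasyDeL functions), converting between the two is a swap of the sequence and head axes:

```python
import jax.numpy as jnp


def bthd_to_bhtd(x):
    """(batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim)."""
    return jnp.swapaxes(x, 1, 2)


def bhtd_to_bthd(x):
    """(batch, num_heads, seq_len, head_dim) -> (batch, seq_len, num_heads, head_dim)."""
    # Swapping the same two axes again restores the original layout.
    return jnp.swapaxes(x, 1, 2)


q_bthd = jnp.zeros((2, 16, 8, 64))
q_bhtd = bthd_to_bhtd(q_bthd)
print(q_bhtd.shape)  # (2, 8, 16, 64)
```

Each conversion is a transpose, so a backend that prefers one layout can accept the other at the cost of a data movement.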

Example

>>> import math
>>> import jax.numpy as jnp
>>> from easydel.layers.attention_operator import OperationMetadata, OperationRegistry
>>>
>>> head_dim = 64  # illustrative per-head dimension
>>>
>>> # Create metadata for attention configuration
>>> metadata = OperationMetadata(
...     runtime_dtype=jnp.float16,
...     softmax_scale=1.0 / math.sqrt(head_dim),
...     dropout_prob=0.1
... )
>>>
>>> # Get and instantiate a specific attention implementation
>>> attn_impl = OperationRegistry.create("flash", metadata)
>>>
>>> # Use the attention implementation
>>> output = attn_impl(query, key, value, mask=attention_mask)
class easydel.layers.operations._operation_impl.OperationImpl(metadata: OperationMetadata)

Bases: BaseOperation

Abstract Base Class for specific attention implementations.

Inherits from BaseOperation to leverage backend-specific dispatching. Subclasses must implement the core attention logic (forward_native) and may provide optimized versions for TPU (forward_tpu), GPU (forward_gpu), and other backends. They must also declare their name and associated metadata.
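The forward_native / forward_tpu / forward_gpu split can be illustrated with a small stand-in. This is a hypothetical sketch, not EasyDeL's BaseOperation: it assumes dispatch is keyed on jax.default_backend() and falls back to forward_native when no backend-specific method exists. The subclass implements plain softmax attention on BTHD tensors as its native path.

```python
import math

import jax
import jax.numpy as jnp


class DispatchSketch:
    """Hypothetical stand-in for a BaseOperation-style dispatcher (not EasyDeL's class)."""

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        # jax.default_backend() returns the platform name, e.g. "cpu", "gpu" or "tpu".
        backend = jax.default_backend()
        # Use forward_<backend> if the subclass defines it, else fall back to forward_native.
        impl = getattr(self, f"forward_{backend}", self.forward_native)
        return impl(*args, **kwargs)


class VanillaAttentionSketch(DispatchSketch):
    """Plain softmax attention on BTHD tensors; defines only the native path."""

    def forward_native(self, q, k, v):
        scale = 1.0 / math.sqrt(q.shape[-1])
        # (batch, q_len, heads, dim) x (batch, kv_len, heads, dim) -> (batch, heads, q_len, kv_len)
        logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
        weights = jax.nn.softmax(logits, axis=-1)
        # Weighted sum of values, back to (batch, q_len, heads, dim)
        return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


attn = VanillaAttentionSketch()
q = jnp.zeros((1, 2, 1, 4))
k = jnp.ones((1, 3, 1, 4))
v = jnp.arange(12.0).reshape(1, 3, 1, 4)
out = attn(q, k, v)
print(out.shape)  # (1, 2, 1, 4)
```

Adding an optimized path in this sketch would mean defining, for example, forward_tpu on the subclass; callers keep using attn(q, k, v) unchanged.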

Provides common helper methods for attention processing like mask manipulation, head repeating (for GQA/MQA), and determining runtime mode.
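As an illustration of the kind of mask manipulation such helpers perform, the sketch below builds a right-aligned causal mask, combines it with a padding mask, and turns the result into an additive bias. The function names and the True-means-attend convention are assumptions for this example, not EasyDeL's actual helpers.

```python
import jax.numpy as jnp


def causal_mask(q_len, kv_len):
    """Boolean (q_len, kv_len) mask; True means the query may attend to the key.

    Queries are right-aligned with the keys, so query i sits at absolute position
    i + (kv_len - q_len) and may see every key at or before that position.
    """
    return jnp.tril(jnp.ones((q_len, kv_len), dtype=bool), k=kv_len - q_len)


def mask_to_bias(mask, dtype=jnp.float32):
    """Boolean mask -> additive bias: 0 where allowed, the most negative finite value elsewhere."""
    return jnp.where(mask, jnp.zeros((), dtype), jnp.finfo(dtype).min)


# Combine a causal mask with a per-batch padding mask (True = real token).
padding = jnp.array([[True, True, True], [True, True, False]])  # (batch, kv_len)
mask = causal_mask(3, 3)[None, :, :] & padding[:, None, :]  # (batch, q_len, kv_len)
bias = mask_to_bias(mask)
```

With a single decoding query (q_len=1), the right alignment lets that query see every cached key.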

create_stable_sharding(state_ps: jax.sharding.PartitionSpec | None = None, preserved_indices: list[int] | None = None, clone_ps: jax.sharding.PartitionSpec | None = None, dep: jax.sharding.PartitionSpec | bool | None = True, tensor: jaxtyping.Float[Array, '...'] | None = None) → jax.sharding.PartitionSpec | None

Helper to create a PartitionSpec, potentially preserving only certain axes.

This is useful for giving intermediate tensors or states a compatible sharding: axes not listed in preserved_indices are replicated.

Parameters
  • state_ps – The base PartitionSpec to modify.

  • preserved_indices – A list of dimension indices whose partitioning should be kept from state_ps (or clone_ps if provided). Other dimensions will be set to None (replicated). If None, state_ps is returned.

  • clone_ps – An optional PartitionSpec to copy axis names from for the preserved indices, instead of using state_ps.

  • dep – A dependency flag or PartitionSpec. If dep is None, the function returns None. Defaults to True. Its exact role is context-specific (possibly control flow).

  • tensor – Optional tensor for which to compute a corrected sharding.

Returns

A new PartitionSpec in which only the preserved axes are partitioned, or None if dep is None. If preserved_indices is None, state_ps is returned unchanged.
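The documented behaviour can be mirrored in a few lines. This is a hypothetical re-implementation for illustration, not EasyDeL's code: it ignores the tensor argument, does not model dep being a PartitionSpec, and assumes clone_ps has the same rank as state_ps.

```python
from jax.sharding import PartitionSpec


def preserve_axes_sketch(state_ps, preserved_indices, clone_ps=None, dep=True):
    """Hypothetical mirror of the documented create_stable_sharding rules."""
    if dep is None:
        return None
    if preserved_indices is None:
        return state_ps
    # Axis names for preserved dims come from clone_ps when given (assumed same rank).
    source = clone_ps if clone_ps is not None else state_ps
    # Preserved dims keep their mesh axis; every other dim becomes None (replicated).
    return PartitionSpec(
        *(source[i] if i in preserved_indices else None for i in range(len(state_ps)))
    )


ps = PartitionSpec("dp", "fsdp", "tp", None)
stable = preserve_axes_sketch(ps, [0, 2])  # keeps axes 0 and 2, replicates 1 and 3
```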

get_mode(query: Float[Array, 'batch ... num_heads head_dim'], BTHD: bool = True) → Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']

Determines the runtime mode from the shape of the query.

If the query's sequence-length dimension is 1, the call is treated as autoregressive generation ('__autoregressive__').

Parameters
  • query – The query tensor.

  • BTHD – Boolean indicating tensor layout (True for B, T, H, D; False for B, H, T, D).
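The layout flag decides which axis holds the sequence length. The hypothetical helpers below sketch only the rule stated above (a length of 1 signals a single-token generation step); how '__prefill__', '__train__' and '__insert__' are distinguished is not described here, so the sketch does not attempt it.

```python
import jax.numpy as jnp


def query_seq_len(query, BTHD=True):
    """Sequence length of a rank-4 query: axis 1 for BTHD, axis 2 for BHTD."""
    return query.shape[1] if BTHD else query.shape[2]


def is_single_token_step(query, BTHD=True):
    """True when the query holds one position, i.e. the autoregressive decode case."""
    return query_seq_len(query, BTHD) == 1


decode_q = jnp.zeros((2, 1, 8, 64))   # BTHD, one new token per sequence
prompt_q = jnp.zeros((2, 16, 8, 64))  # BTHD, a 16-token chunk
```

Passing the wrong BTHD value silently reads the head axis instead, which is why the layout must be stated explicitly.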

static repeat_kv_heads(k: Float[Array, 'batch seq_len num_kv_heads head_dim'], v: Float[Array, 'batch seq_len num_kv_heads head_dim'], num_reps: int) → tuple[jaxtyping.Float[Array, 'batch seq_len num_q_heads head_dim'], jaxtyping.Float[Array, 'batch seq_len num_q_heads head_dim']]

Repeats Key and Value heads for Grouped-Query Attention (GQA) or Multi-Query Attention (MQA).

Expands the head dimension of K and V tensors to match the number of query heads.

Parameters
  • k – Key tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).

  • v – Value tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).

  • num_reps – The number of times to repeat each KV head (num_q_heads // num_kv_heads).

Returns

A tuple (k_repeated, v_repeated) with shapes (batch, seq_len, num_q_heads, head_dim).
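A minimal sketch of the repetition with jnp.repeat, assuming the BTHD shapes listed above. It is a hypothetical stand-in, not the library's implementation; it places the copies of each KV head next to each other, so query head i reads KV head i // num_reps.

```python
import jax.numpy as jnp


def repeat_kv_heads_sketch(k, v, num_reps):
    """Expand (batch, seq_len, num_kv_heads, head_dim) to num_kv_heads * num_reps heads."""
    # jnp.repeat duplicates each head in place: [h0, h1] with num_reps=2 -> [h0, h0, h1, h1].
    return jnp.repeat(k, num_reps, axis=2), jnp.repeat(v, num_reps, axis=2)


num_q_heads, num_kv_heads = 8, 2
k = jnp.ones((1, 5, num_kv_heads, 16))
v = jnp.ones((1, 5, num_kv_heads, 16))
k_rep, v_rep = repeat_kv_heads_sketch(k, v, num_q_heads // num_kv_heads)
print(k_rep.shape)  # (1, 5, 8, 16)
```

MQA is the special case num_kv_heads = 1, where every query head reads the single KV head.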

class easydel.layers.operations._operation_impl.OperationOutput

Bases: object

This dataclass encapsulates the results of an attention computation.

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
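To show what being a PyTree dataclass buys, here is a hypothetical stand-in with replace and to_dict. The field names attention_outputs and attention_weights are assumptions for this example (the fields of OperationOutput are not listed here), and registration uses jax.tree_util.register_pytree_node rather than whatever mechanism EasyDeL actually uses.

```python
import dataclasses
from typing import Optional

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class OutputSketch:
    attention_outputs: jax.Array  # hypothetical field name
    attention_weights: Optional[jax.Array] = None  # hypothetical field name

    def replace(self, **kwargs):
        # Frozen dataclass: return a new instance instead of mutating.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self):
        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}


def _flatten(o):
    # Children are the array fields; no static auxiliary data is needed.
    return (o.attention_outputs, o.attention_weights), None


def _unflatten(aux, children):
    return OutputSketch(*children)


jax.tree_util.register_pytree_node(OutputSketch, _flatten, _unflatten)

out = OutputSketch(jnp.ones((2, 3)))
doubled = jax.tree_util.tree_map(lambda x: x * 2, out)  # None fields are skipped
with_weights = out.replace(attention_weights=jnp.zeros((2, 2)))
```

Because the container is a PyTree, it can be returned from jax.jit-compiled functions and transformed with tree utilities without unpacking its fields by hand.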