easydel.layers.attention_operator.__init__


class easydel.layers.attention_operator.__init__.AttentionImpl(metadata: AttentionMetadata)

Bases: BaseOperation

Abstract Base Class for specific attention implementations.

Inherits from BaseOperation to leverage backend-specific dispatching. Subclasses must implement the core attention logic (forward_native) and may provide optimized versions for TPU (forward_tpu), GPU (forward_gpu), and other backends. They must also declare their registry name (get_impl_name) and associated metadata (get_impl_metadata).

Provides common helper methods for attention processing like mask manipulation, head repeating (for GQA/MQA), and determining runtime mode.

create_stable_sharding(state_ps: Optional[PartitionSpec] = None, preserved_indices: List[int] = None, clone_ps: Optional[PartitionSpec] = None, dep: Optional[Union[PartitionSpec, bool]] = True, tensor: Optional[Array] = None) → Optional[PartitionSpec]

Helper to create a PartitionSpec, potentially preserving only certain axes.

This can be used to give intermediate tensors or states compatible sharding; axes not listed in preserved_indices are replicated.

Parameters
  • state_ps – The base PartitionSpec to modify.

  • preserved_indices – A list of dimension indices whose partitioning should be kept from state_ps (or clone_ps if provided). Other dimensions will be set to None (replicated). If None, state_ps is returned.

  • clone_ps – An optional PartitionSpec to copy axis names from for the preserved indices, instead of using state_ps.

  • dep – A dependency flag or PartitionSpec. If None, returns None. Defaults to True. (The exact purpose might be context-specific, potentially for control flow).

Returns

A new PartitionSpec in which only the specified axes are partitioned, or None if dep is None. Returns state_ps unchanged if preserved_indices is None.
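The axis-preserving rule described above can be sketched as follows. This is a simplified, hypothetical re-implementation (stable_sharding_sketch is an illustrative name, not EasyDeL's source); it ignores the tensor argument and assumes preserved_indices are positions within the PartitionSpec.

```python
from typing import List, Optional, Union

from jax.sharding import PartitionSpec


def stable_sharding_sketch(
    state_ps: Optional[PartitionSpec] = None,
    preserved_indices: Optional[List[int]] = None,
    clone_ps: Optional[PartitionSpec] = None,
    dep: Optional[Union[PartitionSpec, bool]] = True,
) -> Optional[PartitionSpec]:
    # A None dependency short-circuits to None, as documented.
    if dep is None:
        return None
    # Without preserved indices the input spec is returned unchanged.
    if preserved_indices is None:
        return state_ps
    # Axis names come from clone_ps when given, otherwise from state_ps.
    source = clone_ps if clone_ps is not None else state_ps
    axes = [None] * len(state_ps)  # every axis starts out replicated
    for i in preserved_indices:
        axes[i] = source[i]
    return PartitionSpec(*axes)
```

For example, preserving indices [0, 2] of PartitionSpec("dp", "fsdp", "tp", None) yields a spec whose axes are ("dp", None, "tp", None).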

current_backend() → Literal['tpu', 'gpu', 'cpu']

Returns the current JAX default backend as a lowercase string literal.

Returns

“tpu”, “gpu”, or “cpu”.
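A minimal sketch of the same behavior, assuming it wraps jax.default_backend() (current_backend_sketch is an illustrative name, not the library method):

```python
import jax


def current_backend_sketch() -> str:
    # jax.default_backend() reports the platform of the default device,
    # for example "cpu", "gpu", or "tpu"; lower() normalizes the case.
    return jax.default_backend().lower()
```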

abstract get_impl_metadata() → AttentionMetadata

Returns the AttentionMetadata associated with this implementation instance.

Returns

The AttentionMetadata instance passed during initialization.

abstract classmethod get_impl_name() → Union[str, Tuple[str, ...]]

Returns the unique name(s) identifying this attention implementation.

Used by the AttentionRegistry. Can return a single string or a tuple/list of strings if the implementation has multiple aliases.

Returns

A string or tuple/list of strings representing the implementation name(s).

get_mode(q: Array, BTHD: bool = True) → Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']

Determines the runtime mode from the query shape; the result is one of the mode literals in the signature.

Assumes generation (“__autoregressive__”) mode if the query sequence length dimension is 1.

Parameters
  • q – The query tensor.

  • BTHD – Boolean indicating tensor layout (True for B, T, H, D; False for B, H, T, D).
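The shape check above can be sketched as follows. This hypothetical helper only covers the documented rule (sequence length 1 means generation) and assumes every other shape maps to “__train__”; the real method may also return “__prefill__” or “__insert__”.

```python
import jax.numpy as jnp


def get_mode_sketch(q, BTHD: bool = True) -> str:
    # The query sequence axis is 1 in (B, T, H, D) and 2 in (B, H, T, D).
    seq_len = q.shape[1] if BTHD else q.shape[2]
    # Assumption: one query token means an autoregressive decoding step and
    # everything else is treated as a full-sequence ("__train__") pass.
    return "__autoregressive__" if seq_len == 1 else "__train__"


q_decode = jnp.zeros((2, 1, 8, 64))  # batch of 2, one new token each
mode = get_mode_sketch(q_decode)
```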

static repeat_kv_heads(k: Array, v: Array, num_reps: int) → Tuple[Array, Array]

Repeats Key and Value heads for Grouped Query Attention (GQA) or Multi-Query Attention (MQA).

Expands the head dimension of K and V tensors to match the number of query heads.

Parameters
  • k – Key tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).

  • v – Value tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).

  • num_reps – The number of times to repeat each KV head (num_q_heads // num_kv_heads).

Returns

A tuple (k_repeated, v_repeated) with shapes (batch, seq_len, num_q_heads, head_dim).
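A minimal sketch of the head expansion, assuming the usual GQA grouping in which query head h attends with KV head h // num_reps. The helper name is illustrative; EasyDeL may use a different but equivalent operation.

```python
import jax.numpy as jnp


def repeat_kv_heads_sketch(k, v, num_reps: int):
    # k, v: (batch, seq_len, num_kv_heads, head_dim).
    # jnp.repeat keeps copies of the same KV head adjacent on axis 2, so
    # the output has num_kv_heads * num_reps heads and head h comes from
    # input head h // num_reps.
    return jnp.repeat(k, num_reps, axis=2), jnp.repeat(v, num_reps, axis=2)
```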

class easydel.layers.attention_operator.__init__.AttentionMetadata(runtime_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType], runtime_softmax_dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = None, sequence_axis_name: str = <eformer.common_types._Empty object>, mesh: typing.Optional[jax._src.mesh.Mesh] = <eformer.common_types._Empty object>, platform: easydel.infra.etils.EasyDeLPlatforms = <eformer.common_types._Empty object>, backend: easydel.infra.etils.EasyDeLBackends = <eformer.common_types._Empty object>, partition_axis: eformer.escale.partition.manager.PartitionAxis = <eformer.common_types._Empty object>, partition_manager: eformer.escale.partition.manager.PartitionManager = <eformer.common_types._Empty object>, base_config: typing.Optional[easydel.infra.base_config.EasyDeLBaseConfig] = None, scan_ring_attention: bool = <eformer.common_types._Empty object>, softmax_scale: float = <eformer.common_types._Empty object>, dropout_prob: float = <eformer.common_types._Empty object>, blocksize_q: int = <eformer.common_types._Empty object>, blocksize_k: int = <eformer.common_types._Empty object>, blocksize_b: int = <eformer.common_types._Empty object>)

Bases: object

Holds configuration, context, and metadata for attention operations.

This class centralizes various parameters needed by different attention implementations, facilitating consistent behavior and configuration. It handles default values and can be initialized from an EasyDeLBaseConfig.

runtime_dtype

The primary JAX dtype for computations (e.g., q, k, v).

Type

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]

runtime_softmax_dtype

Optional JAX dtype for the softmax computation, allowing for higher precision if needed (e.g., float32).

Type

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]]

sequence_axis_name

The name used for the sequence axis in JAX parallelism (axis_names for pjit).

Type

str

mesh

The JAX device mesh for distributed computation. Must be provided or inferred from context.

Type

Optional[jax._src.mesh.Mesh]

platform

The target hardware platform (e.g., TPU, GPU).

Type

easydel.infra.etils.EasyDeLPlatforms

backend

The specific JAX backend being used (e.g., TPU, CUDA, ROCM).

Type

easydel.infra.etils.EasyDeLBackends

partition_axis

Configuration for partitioning axes in distributed settings (from eformer.escale).

Type

eformer.escale.partition.manager.PartitionAxis

base_config

An optional reference to the base model configuration object for sourcing default values.

Type

Optional[easydel.infra.base_config.EasyDeLBaseConfig]

scan_ring_attention

Boolean flag indicating whether to use ring attention via jax.lax.scan.

Type

bool

softmax_scale

The scaling factor applied before the softmax operation. Often 1 / sqrt(head_dim).

Type

float
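To make the roles of softmax_scale and runtime_softmax_dtype concrete, here is a minimal single-head sketch (scaled_softmax_weights is an illustrative name, not an EasyDeL function): scores are scaled by 1 / sqrt(head_dim) and the softmax runs in float32 before the weights are cast back to the runtime dtype. For head_dim = 128 the scale is 128 ** -0.5 ≈ 0.0884.

```python
import jax
import jax.numpy as jnp


def scaled_softmax_weights(q, k, softmax_dtype=jnp.float32):
    # q: (T, D) and k: (S, D) for a single attention head.
    head_dim = q.shape[-1]
    softmax_scale = head_dim ** -0.5  # the common 1 / sqrt(head_dim)
    scores = (q @ k.T) * softmax_scale
    # Upcast for a numerically safer softmax, then return to the input dtype.
    weights = jax.nn.softmax(scores.astype(softmax_dtype), axis=-1)
    return weights.astype(q.dtype)
```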

dropout_prob

The dropout probability applied to attention weights.

Type

float

blocksize_q

Block size for the query sequence dimension in blockwise attention.

Type

int

blocksize_k

Block size for the key/value sequence dimension in blockwise attention.

Type

int

blocksize_b

Block size for the batch dimension in blockwise attention (often 1).

Type

int

backend: EasyDeLBackends = <eformer.common_types._Empty object>
base_config: Optional[EasyDeLBaseConfig] = None
blocksize_b: int = <eformer.common_types._Empty object>
blocksize_k: int = <eformer.common_types._Empty object>
blocksize_q: int = <eformer.common_types._Empty object>
dropout_prob: float = <eformer.common_types._Empty object>

classmethod from_config(config: EasyDeLBaseConfig, softmax_scale: float, dropout_prob: float = 0.0) → AttentionMetadata

Factory method to create AttentionMetadata from an EasyDeLBaseConfig.

Parameters
  • config – The base configuration object (e.g., model config).

  • softmax_scale – The attention softmax scaling factor. Usually calculated based on head dimension.

  • dropout_prob – The attention dropout probability. Defaults to 0.0.

Returns

An initialized AttentionMetadata instance.

classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

get_shardings(mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], BTHD: bool = True, qkv_mni_sharding: bool = False)

Generates JAX PartitionSpecs for attention tensors based on runtime mode.

Parameters
  • mode – The current runtime mode (one of the mode literals in the signature).

  • BTHD – Boolean indicating tensor layout. True for (Batch, Time, Head, Dim), False for (Batch, Head, Time, Dim).

Returns

A tuple of PartitionSpecs for (query, key, value, bias, mask, attention_output).

mesh: Optional[Mesh] = <eformer.common_types._Empty object>
partition_axis: PartitionAxis = <eformer.common_types._Empty object>
partition_manager: PartitionManager = <eformer.common_types._Empty object>
platform: EasyDeLPlatforms = <eformer.common_types._Empty object>
replace(**kwargs)

Creates a new instance with specified fields replaced.

runtime_dtype: Union[str, type[Any], dtype, SupportsDType]
runtime_softmax_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None
scan_ring_attention: bool = <eformer.common_types._Empty object>
sequence_axis_name: str = <eformer.common_types._Empty object>
set_attrs_carefully(attr_name: str, default: Optional[Any], pickup_name: Optional[str] = None, use_base_config: bool = True)

Internal helper to set an attribute if it’s not already set (or is Ellipsis).

Optionally retrieves the value from self.base_config using pickup_name (or attr_name if pickup_name is None).

Parameters
  • attr_name – The name of the attribute to set on self.

  • default – The default value to use if not found in base_config or if use_base_config is False.

  • pickup_name – The name of the attribute to look for in base_config. Defaults to attr_name.

  • use_base_config – Whether to attempt retrieving the value from base_config.
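The fallback order described above (keep an explicitly set value, else read base_config, else use the default) can be sketched with a small hypothetical dataclass. _UNSET, ConfigSketch, and MetadataSketch are illustrative names, not EasyDeL APIs; the real class uses its own sentinel, shown as <eformer.common_types._Empty object> in the signature, and also treats Ellipsis as unset.

```python
from dataclasses import dataclass
from typing import Any, Optional

_UNSET = object()  # hypothetical stand-in for the "not set" sentinel


@dataclass
class ConfigSketch:
    # Hypothetical base config whose field name differs from the attribute.
    attn_dropout: float = 0.1


@dataclass
class MetadataSketch:
    dropout_prob: Any = _UNSET
    base_config: Optional[ConfigSketch] = None

    def set_attrs_carefully(self, attr_name, default, pickup_name=None, use_base_config=True):
        current = getattr(self, attr_name, _UNSET)
        # An explicitly provided value (not the sentinel, not Ellipsis) wins.
        if current is not _UNSET and current is not Ellipsis:
            return
        value = default
        if use_base_config and self.base_config is not None:
            name = pickup_name if pickup_name is not None else attr_name
            value = getattr(self.base_config, name, default)
        setattr(self, attr_name, value)
```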

softmax_scale: float = <eformer.common_types._Empty object>
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.layers.attention_operator.__init__.AttentionOutput(attention_weights: Optional[Array] = None, attention_outputs: Optional[Array] = None)

Bases: object

Container for the outputs of an attention operation.

attention_weights

The attention probabilities, typically of shape (batch, num_heads, query_seq_len, key_value_seq_len). Optional.

Type

Optional[jax.Array]

attention_outputs

The final weighted sum of values, typically of shape (batch, query_seq_len, num_heads, head_dim) or (batch, num_heads, query_seq_len, head_dim). Optional.

Type

Optional[jax.Array]

attention_outputs: Optional[Array] = None
attention_weights: Optional[Array] = None
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.layers.attention_operator.__init__.AttentionRegistry

Bases: object

Registry for discovering and managing different AttentionImpl classes.

Allows registering implementations using a decorator and retrieving or instantiating them by name.

classmethod create(impl_name: str, metadata: AttentionMetadata) → AttentionImpl

Creates an instance of an attention implementation by name.

Retrieves the class associated with impl_name and initializes it with the provided metadata.

Parameters
  • impl_name – The name of the implementation to instantiate.

  • metadata – The AttentionMetadata to pass to the implementation’s constructor.

Returns

An initialized instance of the requested AttentionImpl subclass.

Raises

ValueError – If no implementation is registered with impl_name.

classmethod get(impl_name: str) → Type[AttentionImpl]

Retrieves an attention implementation class by its registered name.

Parameters

impl_name – The name of the implementation to retrieve.

Returns

The AttentionImpl subclass registered under the given name.

Raises

ValueError – If no implementation is registered with that name.

classmethod list_implementations() → List[str]

Returns a list of names of all registered attention implementations.

Returns

A list of strings, where each string is a registered implementation name.

classmethod register(impl_cls: Type[ICa]) → Type[ICa]

Class method decorator to register an AttentionImpl subclass.

The implementation is registered under the name(s) returned by its get_impl_name() class method.

Example:

```python
@AttentionRegistry.register
class FlashAttentionImpl(AttentionImpl):
    @classmethod
    def get_impl_name(cls) -> str:
        return "flash"

    # ... implementation ...
```

Parameters

impl_cls – The AttentionImpl subclass to register.

Returns

The registered class itself.
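The registry pattern documented above (register by name or aliases, then get, create, or list) can be sketched in a self-contained way. RegistrySketch, ImplSketch, and SdpaSketch are hypothetical names that mirror the described behavior; this is not EasyDeL's source.

```python
from typing import Dict, List, Tuple, Type, Union


class ImplSketch:
    """Hypothetical stand-in for AttentionImpl."""

    def __init__(self, metadata):
        self.metadata = metadata

    @classmethod
    def get_impl_name(cls) -> Union[str, Tuple[str, ...]]:
        raise NotImplementedError


class RegistrySketch:
    """Hypothetical mini registry mirroring register/get/create/list_implementations."""

    _registry: Dict[str, Type[ImplSketch]] = {}

    @classmethod
    def register(cls, impl_cls: Type[ImplSketch]) -> Type[ImplSketch]:
        names = impl_cls.get_impl_name()
        if isinstance(names, str):  # a single name becomes a 1-tuple
            names = (names,)
        for name in names:  # every alias points at the same class
            cls._registry[name] = impl_cls
        return impl_cls  # returned unchanged, so it works as a decorator

    @classmethod
    def get(cls, impl_name: str) -> Type[ImplSketch]:
        if impl_name not in cls._registry:
            raise ValueError(f"No attention implementation registered as {impl_name!r}")
        return cls._registry[impl_name]

    @classmethod
    def create(cls, impl_name: str, metadata) -> ImplSketch:
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> List[str]:
        return list(cls._registry)


@RegistrySketch.register
class SdpaSketch(ImplSketch):
    @classmethod
    def get_impl_name(cls):
        return ("sdpa", "cudnn")
```

Returning the class unchanged from register is what lets it be stacked as a decorator without altering the decorated class.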

class easydel.layers.attention_operator.__init__.AutoRegressiveDecodeAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

Attention implementation tailored for the autoregressive decoding step.

This class handles the attention mechanism when generating tokens one by one, attending to the previously generated sequence stored in a cache. It utilizes shard_map for distributed computation and supports different backends, including a potential Pallas-optimized version for TPUs. It assumes the query sequence length is 1.

metadata

Configuration metadata for the attention mechanism.

Type

AttentionMetadata

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass for autoregressive decoding attention.

Delegates to the native JAX/XLA implementation (forward_native).

forward_cuda(*args, **kwargs) → AttentionOutput

CUDA GPU forward pass for autoregressive decoding attention.

Delegates to the native JAX/XLA implementation (forward_native). Future optimizations might add CUDA-specific kernels here.

forward_gpu(*args, **kwargs) → AttentionOutput

GPU forward pass for autoregressive decoding attention.

Currently delegates to forward_cuda.

forward_native(q: Array, k: Array, v: Array, cache_view: TransformerCacheView, **ignores) → AttentionOutput

Performs the native JAX/XLA forward pass for autoregressive decoding attention.

This implementation uses shard_map to distribute the computation and relies on standard JAX operations (einsum, softmax). It calculates attention weights between the single query token and the keys in the cache, applies masking based on the valid range defined in cache_view, computes the softmax, and finally computes the weighted sum of values.

Parameters
  • q (Array) – Query tensor of shape (batch_size, 1, num_query_heads, head_dim).

  • k (Array) – Key tensor (from cache) of shape (batch_size, kv_sequence_length, num_kv_heads, head_dim).

  • v (Array) – Value tensor (from cache) of shape (batch_size, kv_sequence_length, num_kv_heads, head_dim).

  • cache_view (TransformerCacheView) – Contains metadata about the cache, specifically starts (start index for valid keys/values) and index (current index or length).

  • **ignores – Ignored keyword arguments.

Returns

An object containing the attention outputs (attention_outputs) of shape (batch_size, 1, num_query_heads, head_dim). Attention weights are not returned.

Return type

AttentionOutput
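The single-token decode computation described above can be sketched without shard_map. This is a hypothetical illustration (decode_attention_sketch is not an EasyDeL function); it assumes the valid cache slots for each sequence are those with starts <= position < index, which is one reading of the loosely described cache_view fields, and it scales scores by 1 / sqrt(head_dim).

```python
import jax
import jax.numpy as jnp


def decode_attention_sketch(q, k, v, starts, index):
    # q: (B, 1, Hq, D); k, v: (B, S, Hkv, D); starts, index: (B,) int arrays.
    num_reps = q.shape[2] // k.shape[2]
    k = jnp.repeat(k, num_reps, axis=2)  # expand KV heads to match queries
    v = jnp.repeat(v, num_reps, axis=2)
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale  # (B, H, 1, S)
    # Assumed validity rule: cache slots with starts <= pos < index are live.
    pos = jnp.arange(k.shape[1])
    valid = (pos[None, :] >= starts[:, None]) & (pos[None, :] < index[:, None])
    scores = jnp.where(valid[:, None, None, :], scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)  # (B, 1, Hq, D)
```

Because masked slots receive zero weight, whatever stale data sits outside the valid range of the cache does not affect the output.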

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass for autoregressive decoding attention.

Delegates to the native JAX/XLA implementation (forward_native). Future optimizations might add ROCm-specific kernels here.

forward_tpu(*args, **kwargs) → AttentionOutput

TPU forward pass for autoregressive decoding attention.

Currently delegates to the native JAX/XLA implementation (forward_native). Consider using _forward_tpu for potential Pallas optimization if available/enabled.

get_impl_metadata() → AttentionMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The AttentionMetadata provided during initialization.

classmethod get_impl_name() → Union[str, Tuple[str]]

Returns the registered name of this attention implementation.

Returns

The string “autoregressive_decodeattn”.

class easydel.layers.attention_operator.__init__.FlashAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

An implementation of Flash Attention V2 using specialized kernels (Pallas on TPU, Triton on GPU).

This class leverages jax.experimental.pallas.ops.tpu.flash_attention for TPUs and a Triton kernel (triton_flash_attention) for GPUs (CUDA). It is registered under the name “flash_attn2”. CPU execution is not supported and will raise an error.

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass. Delegates to forward_native, which raises an error.

Raises

NotImplementedError – Via forward_native.

forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput

Performs Flash Attention V2 on GPU (CUDA) using a Triton kernel.

Handles optional mask/bias, KV head repetition (either inside the kernel or beforehand), and sharding based on metadata. triton_flash_attention is assumed to repeat KV heads internally when needed or to accept broadcastable KV. The Triton kernel uses the BTHD layout for inputs and outputs.

Parameters
  • q – Query tensor, expected shape (batch, q_seq_len, num_q_heads, head_dim).

  • k – Key tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).

  • v – Value tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).

  • mask – Optional boolean attention mask. Used by the kernel if bias is not provided. Shape typically (batch, 1, q_seq_len, kv_seq_len) or broadcastable.

  • bias – Optional attention bias tensor. Added by the kernel. Shape typically (batch, num_heads, q_seq_len, kv_seq_len) or broadcastable. Takes precedence over mask if both are passed.

  • init_bias – Optional callable function to initialize bias if mask and bias are None.

  • causal – If True, instructs the kernel to apply a causal mask.

  • **ignore – Ignored keyword arguments.

Returns

An AttentionOutput object containing the attention outputs. Attention weights are typically not computed or returned by Flash Attention implementations.

forward_gpu(*args, **kwargs) → AttentionOutput

GPU forward pass. Delegates to the CUDA-specific implementation.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput

Native (CPU) forward pass for Flash Attention. Not implemented.

Raises

NotImplementedError – Flash Attention is not supported on CPU via this implementation.

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass. Not implemented, falls back to native (error).

Raises

NotImplementedError – Via forward_native. ROCm support requires a specific kernel.

forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput

Performs Flash Attention V2 on TPU using jax.experimental.pallas.ops.tpu.flash_attention.

Handles optional mask/bias, KV head repetition, and sharding based on metadata. Note: inputs arrive in BTHD format, while the Pallas kernel expects BHTD, so the tensors are transposed before the kernel call.

Parameters
  • q – Query tensor, expected shape (batch, q_seq_len, num_q_heads, head_dim).

  • k – Key tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).

  • v – Value tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).

  • mask – Optional boolean attention mask. Bias will be generated from this if provided. Shape typically (batch, 1, q_seq_len, kv_seq_len) or broadcastable.

  • bias – Optional attention bias tensor. Added directly to attention scores. Shape typically (batch, num_heads, q_seq_len, kv_seq_len) or broadcastable. Takes precedence over mask.

  • init_bias – Optional callable function to initialize bias if mask and bias are None.

  • causal – If True, applies a causal mask. Ignored if q_seq_len is 1 (generation).

  • **ignore – Ignored keyword arguments.

Returns

An AttentionOutput object containing the attention outputs. Attention weights are typically not computed or returned by Flash Attention implementations.
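Two preprocessing steps mentioned above, turning a boolean mask into an additive bias and moving from BTHD to the BHTD layout the Pallas kernel expects, can be sketched as follows. The helper names are hypothetical, and using the dtype's most negative finite value for blocked positions is an assumption; implementations differ in the exact constant.

```python
import jax.numpy as jnp


def mask_to_bias(mask, dtype=jnp.float32):
    # True means "may attend" (bias 0); False gets a large negative bias so
    # the position receives essentially zero probability after the softmax.
    return jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)


def bthd_to_bhtd(x):
    # (batch, seq, heads, dim) -> (batch, heads, seq, dim).
    return jnp.transpose(x, (0, 2, 1, 3))
```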

get_impl_metadata() → AttentionMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The AttentionMetadata provided during initialization.

classmethod get_impl_name() → Union[str, Tuple[str]]

Returns the registered name of this attention implementation.

Returns

The string “flash_attn2”.

class easydel.layers.attention_operator.__init__.PagedAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

Attention implementation using the Paged Attention mechanism with TPU Pallas kernels.

This class provides an attention mechanism suitable for scenarios where the Key-Value cache is managed in non-contiguous pages (Paged KV Cache). It leverages specialized kernels for efficient execution on TPUs, handling prefill and decode phases separately or in a mixed mode.

metadata

Configuration metadata for the attention mechanism. While this class uses AttentionMetadata, it primarily relies on the additional PagedAttentionMetadata passed during the forward call for paged-specific information.

Type

AttentionMetadata

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass.

Raises

NotImplementedError – Paged Attention currently relies on Pallas kernels for TPUs and does not have a CPU implementation.

forward_cuda(*args, **kwargs) → AttentionOutput

CUDA GPU forward pass.

Raises

NotImplementedError – Paged Attention currently relies on Pallas for TPUs and does not have a specific CUDA implementation. (Future work might add this.)

forward_gpu(*args, **kwargs) → AttentionOutput

Generic GPU forward pass.

Raises

NotImplementedError – Paged Attention relies on specific kernels (currently Pallas for TPU) and does not have a generic GPU implementation.

forward_native(*args, **kwargs) → AttentionOutput

Native (CPU) forward pass.

Raises

NotImplementedError – Paged Attention requires specialized kernels and does not have a native CPU implementation.

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass. Not implemented for Paged Attention.

forward_tpu(q: Array, k: Array, v: Array, cache_view: PagedAttentionCacheView, cache_metadata: PagedAttentionMetadata, **ignore) → AttentionOutput

TPU forward pass for Paged Attention.

Determines the execution mode (prefill, decode, or mixed) based on the provided cache_metadata and dispatches the computation to the corresponding internal TPU method (_prefill_tpu, _decode_tpu, _mixed_tpu).

Parameters
  • q (Array) – Query tensor. Shape depends on mode (prefill/decode/mixed).

  • k (Array) – Key tensor (ignored).

  • v (Array) – Value tensor (ignored).

  • cache_view (PagedAttentionCacheView) – Contains the paged KV cache tensors.

  • cache_metadata (PagedAttentionMetadata) – Contains metadata describing the state and mode (prefill/decode/mixed) of the current batch.

  • **ignore – Ignored keyword arguments.

Returns

An object containing the computed attention outputs.

Attention weights are typically not computed or returned in paged attention.

Return type

AttentionOutput
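As a conceptual illustration of a paged KV cache: the tokens of one sequence live in fixed-size pages scattered across a shared pool, and a per-sequence block table lists which pages it owns, in order. The sketch below reassembles a contiguous view for one sequence. gather_paged_kv and the array layout are hypothetical; the actual Pallas kernels read pages in place rather than materializing a copy, and PagedAttentionCacheView may be laid out differently.

```python
import jax.numpy as jnp


def gather_paged_kv(pages, block_table, seq_len: int):
    # pages: (num_pages, page_size, num_kv_heads, head_dim), a shared pool.
    # block_table: (pages_per_seq,) indices of the pages owned by one sequence.
    gathered = pages[block_table]  # (pages_per_seq, page_size, H, D)
    # Flatten pages into one token axis: token t sits in page t // page_size.
    contiguous = gathered.reshape(-1, *pages.shape[2:])
    return contiguous[:seq_len]  # drop the unused tail of the last page
```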

get_impl_metadata() → AttentionMetadata

Retrieves the metadata associated with this attention implementation instance.

Returns

The metadata object provided during initialization.

Return type

AttentionMetadata

classmethod get_impl_name() → Union[str, Tuple[str]]

Returns the registered name for this attention implementation.

Returns

The name “paged_attention”.

Return type

Union[str, Tuple[str]]

class easydel.layers.attention_operator.__init__.RingAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

Attention implementation using a ring-passing algorithm or a blockwise scan.

This implementation supports:
  • Native (scan-based) blockwise attention via blockwise_attn.

  • TPU-specific ring attention using the pallas_ring_attention kernel.

It is registered under the name “ring”.

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass. Delegates to forward_native (scan-based).

forward_cuda(*args, **kwargs) → AttentionOutput

CUDA GPU forward pass. Currently delegates to forward_native (scan-based).

forward_gpu(*args, **kwargs) → AttentionOutput

GPU forward pass. Currently delegates to forward_native (scan-based).

forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) → AttentionOutput

Computes attention using the scan-based blockwise_attn function.

Handles optional mask/bias, KV head repetition, and sharding constraints.

Parameters
  • q – Query tensor (B, T, H, D).

  • k – Key tensor (B, S, H_kv, D).

  • v – Value tensor (B, S, H_kv, D).

  • mask – Optional boolean attention mask (broadcastable to B, 1, T, S).

  • bias – Optional attention bias (broadcastable to B, H, T, S).

  • init_bias – Optional callable to initialize bias if mask/bias are None.

  • deterministic – If False, enables dropout. Requires dropout_rng.

  • dropout_rng – JAX PRNG key for dropout if deterministic is False.

  • causal – Apply causal mask if True.

  • **ignore – Ignored keyword arguments.

Returns

AttentionOutput containing the attention result.
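The core idea behind the scan-based path, processing keys and values block by block while carrying a running maximum and normalizer (the online-softmax trick), can be sketched for a single head without masking, dropout, or sharding. This illustrates the technique and is not EasyDeL's blockwise_attn; block_k plays the role of blocksize_k, and S is assumed divisible by block_k.

```python
import jax
import jax.numpy as jnp


def blockwise_attention_sketch(q, k, v, block_k: int):
    # Single head, no mask: q is (T, D); k and v are (S, D), S % block_k == 0.
    S, D = k.shape
    T = q.shape[0]
    scale = D ** -0.5
    # Split keys/values into S // block_k blocks that lax.scan walks over.
    k_blocks = k.reshape(S // block_k, block_k, D)
    v_blocks = v.reshape(S // block_k, block_k, D)

    def step(carry, kv_block):
        acc, row_max, row_sum = carry
        k_blk, v_blk = kv_block
        s = (q @ k_blk.T) * scale  # (T, block_k) partial scores
        new_max = jnp.maximum(row_max, s.max(axis=-1))
        # Rescale earlier partial results to the new running maximum.
        correction = jnp.exp(row_max - new_max)
        p = jnp.exp(s - new_max[:, None])
        acc = acc * correction[:, None] + p @ v_blk
        row_sum = row_sum * correction + p.sum(axis=-1)
        return (acc, new_max, row_sum), None

    init = (
        jnp.zeros((T, D), q.dtype),
        jnp.full((T,), -jnp.inf, q.dtype),
        jnp.zeros((T,), q.dtype),
    )
    (acc, _, row_sum), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / row_sum[:, None]  # normalize once at the end
```

The result matches full softmax attention up to floating-point error, while only one key/value block is materialized per step; ring attention applies the same accumulation to blocks that arrive from neighboring devices.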

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass. Currently delegates to forward_native (scan-based).

forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) → AttentionOutput

Computes Ring Attention on TPU using the pallas_ring_attention kernel.

Handles optional mask/bias, sharding, and passes configuration to the kernel.

Parameters
  • q – Query tensor (B, T, H, D).

  • k – Key tensor (B, S, H_kv, D).

  • v – Value tensor (B, S, H_kv, D).

  • mask – Optional boolean attention mask (broadcastable to B, 1, T, S).

  • bias – Optional attention bias (broadcastable to B, H, T, S).

  • init_bias – Optional callable to initialize bias if mask/bias are None.

  • deterministic – If False, potentially enables dropout within the kernel (if supported).

  • dropout_rng – JAX PRNG key (may be used by the kernel if dropout is enabled).

  • causal – Apply causal mask if True. Passed to the kernel.

  • **ignore – Ignored keyword arguments.

Returns

AttentionOutput containing the attention result.

get_impl_metadata() → AttentionMetadata

Returns the metadata associated with this instance.

classmethod get_impl_name() → Union[str, Tuple[str]]

Returns the registered name: “ring”.

class easydel.layers.attention_operator.__init__.ScaledDotProductAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

An attention implementation that leverages jax.nn.dot_product_attention.

This class utilizes JAX’s optimized SDPA primitive, which can dispatch to different backend implementations (like XLA, cuDNN, or potentially Flash Attention emulation on CUDA depending on JAX version and hardware).

It handles sharding using shard_map and manages backend-specific dispatch (primarily distinguishing between CUDA/GPU and other backends like TPU/CPU).

Registered under the names “sdpa”, “cudnn”, and “cuda_flash_attn2”.

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass. Delegates to forward_native (XLA implementation).

forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput

Computes attention using jax.nn.dot_product_attention with the “cudnn” implementation.

This is optimized for NVIDIA GPUs using cuDNN. It applies sharding via shard_map. Note: The cuDNN implementation might have specific requirements (e.g., dtype). Causal masking is disabled during generation mode (q_len=1) as it’s unnecessary.

Parameters
  • q – Query tensor (B, T, H, D).

  • k – Key tensor (B, S, H_kv, D).

  • v – Value tensor (B, S, H_kv, D_v).

  • mask – Optional boolean attention mask (broadcastable to B, 1, T, S).

  • bias – Optional attention bias tensor (broadcastable to B, H, T, S).

  • init_bias – Optional callable to initialize bias if mask/bias are None.

  • causal – If True, applies causal masking within the primitive, unless in generation mode.

  • **ignore – Ignored keyword arguments.

Returns

An AttentionOutput object. Weights are not returned.

forward_gpu(*args, **kwargs) → AttentionOutput

GPU forward pass. Delegates to the CUDA-specific implementation.

forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput

Computes attention using jax.nn.dot_product_attention with the “xla” implementation.

This is typically used for CPU and TPU backends. It applies sharding via shard_map.

Parameters
  • q – Query tensor (B, T, H, D).

  • k – Key tensor (B, S, H_kv, D).

  • v – Value tensor (B, S, H_kv, D_v).

  • mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Passed directly to the primitive.

  • bias – Optional attention bias tensor (broadcastable to B, H, T, S). Passed directly to the primitive. If bias is provided, causal is forced to False.

  • init_bias – Optional callable to initialize bias if mask/bias are None.

  • causal – If True and bias is None, applies causal masking within the primitive.

  • **ignore – Ignored keyword arguments.

Returns

An AttentionOutput object. Note that jax.nn.dot_product_attention typically does not return attention weights.

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass. Currently delegates to forward_native.

forward_tpu(*args, **kwargs) → AttentionOutput

TPU forward pass. Delegates to forward_native (XLA implementation).

get_impl_metadata() → AttentionMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The AttentionMetadata provided during initialization.

classmethod get_impl_name() → Union[str, Tuple[str]]

Returns the registered name(s) for this implementation.

Returns

A tuple of strings: (“sdpa”, “cudnn”, “cuda_flash_attn2”).

class easydel.layers.attention_operator.__init__.SplashAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

An attention implementation using the Pallas Splash Attention kernel for TPUs.

Splash Attention is an optimized attention mechanism designed for TPUs. This implementation provides a wrapper around the make_splash_mqa_single_device primitive.

Note

  • This implementation is primarily intended for TPUs.

  • It falls back to VanillaAttn under certain conditions:
    • Query sequence length is 1 (generation mode).

    • causal is False.

    • Query sequence length is not divisible by 128 (kernel constraint).

  • Non-TPU forward methods (forward_native, forward_gpu, etc.) are not implemented and will raise NotImplementedError.

Registered under the name “splash”.
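The following sketch encodes the three documented fallback conditions as a single predicate. `should_use_splash` and its `block_q` parameter are hypothetical names; only the conditions themselves (query length 1, non-causal attention, query length not divisible by 128) come from the note above.

```python
def should_use_splash(q_len: int, causal: bool, block_q: int = 128) -> bool:
    """Hypothetical helper mirroring the documented Splash fallback rules."""
    if q_len == 1:  # single-token generation step -> VanillaAttn
        return False
    if not causal:  # non-causal attention -> VanillaAttn
        return False
    if q_len % block_q != 0:  # kernel constraint on the query length
        return False
    return True
```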

forward_cpu(*args, **kwargs) AttentionOutput

CPU forward pass. Not implemented for Splash Attention.

forward_cuda(*args, **kwargs) AttentionOutput

CUDA GPU forward pass. Not implemented for Splash Attention.

forward_gpu(*args, **kwargs) AttentionOutput

GPU forward pass. Not implemented for Splash Attention.

forward_native(*args, **kwargs) AttentionOutput

Native (CPU) forward pass. Not implemented for Splash Attention.

forward_rocm(*args, **kwargs) AttentionOutput

ROCm GPU forward pass. Not implemented for Splash Attention.

forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, causal: bool = True, cache_view: Optional[TransformerCacheView] = None, **ignore) AttentionOutput

Performs Splash Attention on TPU using the Pallas kernel.

Handles fallback logic, mask processing, block-size configuration, and sharding via shard_map. Expects inputs in BTHD layout (batch, sequence, heads, head dim) and transposes them to BHTD for the kernel.

Parameters
  • q – Query tensor (B, T, Hq, D).

  • k – Key tensor (B, S, Hkv, D).

  • v – Value tensor (B, S, Hkv, Dv).

  • mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Used to generate segment IDs if provided.

  • causal – If True, applies causal masking via the kernel’s mask configuration. If False, falls back to VanillaAttn.

  • **ignore – Ignored keyword arguments.

Returns

An AttentionOutput object containing the attention outputs. Attention weights are not computed or returned by Splash Attention.
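The sketch below illustrates the BTHD → BHTD transpose around a per-example kernel. `per_example_attention` is a hypothetical plain-JAX stand-in for the Pallas kernel, not the Splash kernel itself. Mapping a single-device kernel over the batch with jax.vmap, and giving q and k/v the same head count, are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def per_example_attention(q, k, v):
    """Hypothetical stand-in for a single-device causal kernel; q, k, v are (H, T, D)."""
    scores = jnp.einsum("htd,hsd->hts", q, k) / jnp.sqrt(q.shape[-1])
    t, s = scores.shape[-2:]
    causal = jnp.tril(jnp.ones((t, s), dtype=bool))
    scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)
    return jnp.einsum("hts,hsd->htd", jax.nn.softmax(scores, axis=-1), v)


def run_bhtd_kernel(q, k, v):
    # Inputs arrive as (B, T, H, D); the kernel expects (B, H, T, D).
    q, k, v = (jnp.transpose(x, (0, 2, 1, 3)) for x in (q, k, v))
    out = jax.vmap(per_example_attention)(q, k, v)  # map the kernel over the batch
    return jnp.transpose(out, (0, 2, 1, 3))  # back to (B, T, H, D)
```

Because the permutation (0, 2, 1, 3) swaps axes 1 and 2, it is its own inverse. The same permutation therefore converts the inputs to BHTD and restores the output to BTHD.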

get_impl_metadata() AttentionMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The AttentionMetadata provided during initialization.

classmethod get_impl_name() Union[str, Tuple[str]]

Returns the registered name of this attention implementation.

Returns

The string “splash”.

class easydel.layers.attention_operator.__init__.VanillaAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

A standard, non-optimized implementation of multi-head attention.

This implementation uses basic JAX operations like jnp.einsum and standard softmax. It serves as a reference implementation and a fallback for platforms where optimized kernels (like Flash Attention) are not available or desired. It supports features like attention bias, masking, dropout, and Grouped Query Attention (GQA)/Multi-Query Attention (MQA) via reshaping.

Registered under the name “vanilla”.

forward_cpu(*args, **kwargs) AttentionOutput

CPU forward pass. Delegates to forward_native.

forward_cuda(*args, **kwargs) AttentionOutput

CUDA GPU forward pass. Delegates to forward_native.

forward_gpu(*args, **kwargs) AttentionOutput

GPU forward pass. Delegates to forward_native.

forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = True, dropout_rng: Optional[PRNGKey] = None, **ignore) AttentionOutput

Computes multi-head attention using standard JAX operations.

Supports GQA/MQA by reshaping the query tensor to match the number of key/value heads. Applies scaling, optional bias/mask, softmax (potentially in float32), and optional dropout.

Parameters
  • q – Query tensor (B, T, H_q, D).

  • k – Key tensor (B, S, H_kv, D).

  • v – Value tensor (B, S, H_kv, D_v).

  • mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Used if bias is not provided.

  • bias – Optional attention bias tensor (broadcastable to B, H_q, T, S). Takes precedence over mask.

  • init_bias – Optional callable to initialize bias if mask/bias are None.

  • deterministic – If True, disables dropout.

  • dropout_rng – JAX PRNG key for dropout. Required if deterministic is False and dropout_prob > 0.

  • **ignore – Ignored keyword arguments.

Returns

An AttentionOutput object containing the attention weights (if computed) and the final attention outputs.

Raises

NotImplementedError – If the bias head dimension cannot be reshaped correctly to match the query head structure for GQA/MQA.
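The following is a reference sketch of the steps listed for forward_native: grouping query heads for GQA/MQA, scaling, applying the bias (which takes precedence) or the mask, computing the softmax in float32, and applying optional dropout. It is not EasyDeL's code. The function name `vanilla_attention` and its `dropout_prob` argument are assumptions, and init_bias is omitted. Bias handling is limited to head dimensions of H_q or 1, and any other head dimension raises NotImplementedError.

```python
import jax
import jax.numpy as jnp


def vanilla_attention(q, k, v, mask=None, bias=None, dropout_prob=0.0,
                      deterministic=True, dropout_rng=None):
    """Hypothetical reference attention. Returns (outputs (B, T, Hq, Dv), weights (B, Hq, T, S))."""
    b, t, hq, d = q.shape
    s, hkv = k.shape[1], k.shape[2]
    g = hq // hkv  # query heads sharing one key/value head
    # Group query heads under their key/value head, pre-scaled: (B, T, Hkv, G, D).
    qg = q.reshape(b, t, hkv, g, d) * (d ** -0.5)
    scores = jnp.einsum("bthgd,bshd->bhgts", qg, k)  # (B, Hkv, G, T, S)

    if bias is not None:  # bias takes precedence over mask
        if bias.shape[1] == hq:
            bias = bias.reshape(bias.shape[0], hkv, g, *bias.shape[2:])
        elif bias.shape[1] == 1:
            bias = bias[:, :, None]  # (B, 1, 1, T, S) broadcasts over groups
        else:
            raise NotImplementedError(
                f"cannot map {bias.shape[1]} bias heads onto {hq} query heads")
        scores = scores + bias
    elif mask is not None:
        # Assumes a 4-D boolean mask (B or 1, 1, T, S).
        scores = jnp.where(mask[:, :, None], scores, jnp.finfo(scores.dtype).min)

    # Softmax in float32 for numerical stability, then cast back.
    weights = jax.nn.softmax(scores.astype(jnp.float32), axis=-1).astype(q.dtype)

    if not deterministic and dropout_prob > 0.0:
        keep = jax.random.bernoulli(dropout_rng, 1.0 - dropout_prob, weights.shape)
        weights = jnp.where(keep, weights / (1.0 - dropout_prob), 0.0)

    out = jnp.einsum("bhgts,bshe->bthge", weights, v)  # (B, T, Hkv, G, Dv)
    return out.reshape(b, t, hq, v.shape[-1]), weights.reshape(b, hq, t, s)
```

Reshaping the query to (B, T, H_kv, G, D) lets each group of G query heads attend to its shared key/value head without materialising repeated copies of k and v.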

forward_rocm(*args, **kwargs) AttentionOutput

ROCm GPU forward pass. Delegates to forward_native.

forward_tpu(*args, **kwargs) AttentionOutput

TPU forward pass. Delegates to forward_native.

get_impl_metadata() AttentionMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The AttentionMetadata provided during initialization.

classmethod get_impl_name() Union[str, Tuple[str]]

Returns the registered name of this attention implementation.

Returns

The string “vanilla”.