easydel.layers.attention_operator.modules.paged_attn

class easydel.layers.attention_operator.modules.paged_attn.PagedAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

Attention implementation using the Paged Attention mechanism with TPU Pallas kernels.

This class provides an attention mechanism suitable for scenarios where the Key-Value cache is managed in non-contiguous pages (Paged KV Cache). It leverages specialized kernels for efficient execution on TPUs, handling prefill and decode phases separately or in a mixed mode.
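To make the idea of a non-contiguous, paged KV cache concrete, here is a minimal plain-Python sketch. It is not EasyDeL's data layout or kernel: `PAGE_SIZE`, `pages`, `page_table`, and `gather_sequence` are hypothetical names chosen for illustration. It only shows how a per-sequence page table maps logical token positions to slots in a shared page pool.

```python
# Illustrative toy only: NOT EasyDeL's cache layout or its Pallas kernels.
PAGE_SIZE = 4  # tokens per page (assumed value)


def gather_sequence(pages, page_table, seq_len, page_size=PAGE_SIZE):
    """Reassemble one sequence's cached entries from non-contiguous pages."""
    out = []
    for pos in range(seq_len):
        page_id = page_table[pos // page_size]  # logical page -> physical page
        out.append(pages[page_id][pos % page_size])  # slot inside that page
    return out


# Shared physical pool: 4 pages with PAGE_SIZE slots each.
pages = [[None] * PAGE_SIZE for _ in range(4)]

# A 6-token sequence whose logical pages 0 and 1 live in physical pages 3 and 1.
page_table = [3, 1]
tokens = ["k0", "k1", "k2", "k3", "k4", "k5"]
for pos, tok in enumerate(tokens):
    pages[page_table[pos // PAGE_SIZE]][pos % PAGE_SIZE] = tok

recovered = gather_sequence(pages, page_table, seq_len=len(tokens))
```

A real implementation would store key and value tensors per page and perform the lookup inside the attention kernel rather than materializing the gathered sequence.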

metadata

Configuration metadata for the attention mechanism. While this class uses AttentionMetadata, it primarily relies on the additional PagedAttentionMetadata passed during the forward call for paged-specific information.

Type

AttentionMetadata

forward_cpu(*args, **kwargs) -> AttentionOutput

CPU forward pass. Not implemented for Paged Attention.

Raises

NotImplementedError – Paged Attention currently relies on Pallas for TPUs and does not have a CPU implementation.

forward_cuda(*args, **kwargs) -> AttentionOutput

CUDA GPU forward pass.

Raises

NotImplementedError – Paged Attention currently relies on Pallas for TPUs and does not have a specific CUDA implementation. (Future work might add this.)

forward_gpu(*args, **kwargs) -> AttentionOutput

Generic GPU forward pass.

Raises

NotImplementedError – Paged Attention relies on specific kernels (currently Pallas for TPU) and does not have a generic GPU implementation.

forward_native(*args, **kwargs) -> AttentionOutput

Native (CPU) forward pass.

Raises

NotImplementedError – Paged Attention requires specialized kernels and does not have a native CPU implementation.

forward_rocm(*args, **kwargs) -> AttentionOutput

ROCm GPU forward pass. Not implemented for Paged Attention.
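Because every non-TPU entry point is documented as raising NotImplementedError, callers that may run on other backends need a fallback path. The sketch below illustrates one way to do that with a stand-in class. `TpuOnlyAttention` and `run_with_fallback` are hypothetical helpers written for this example, not part of EasyDeL, and the string return values are placeholders for real outputs.

```python
class TpuOnlyAttention:
    """Hypothetical stand-in that mirrors the documented behaviour:
    only the TPU path is implemented."""

    def forward_tpu(self, *args, **kwargs):
        return "tpu-output"  # placeholder for a real AttentionOutput

    def forward_gpu(self, *args, **kwargs):
        raise NotImplementedError("no generic GPU implementation")

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError("no native CPU implementation")


def run_with_fallback(impl, backend, fallback, *args, **kwargs):
    """Call impl.forward_<backend>; use `fallback` if it is not implemented."""
    method = getattr(impl, f"forward_{backend}")
    try:
        return method(*args, **kwargs)
    except NotImplementedError:
        return fallback(*args, **kwargs)


impl = TpuOnlyAttention()
on_tpu = run_with_fallback(impl, "tpu", lambda: "fallback-output")
on_gpu = run_with_fallback(impl, "gpu", lambda: "fallback-output")
```

In practice the fallback would be a different attention implementation that supports the target backend.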

forward_tpu(q: Array, k: Array, v: Array, cache_view: PagedAttentionCacheView, cache_metadata: PagedAttentionMetadata, **ignore) -> AttentionOutput

TPU forward pass for Paged Attention.

Determines the execution mode (prefill, decode, or mixed) based on the provided cache_metadata and dispatches the computation to the corresponding internal TPU method (_prefill_tpu, _decode_tpu, _mixed_tpu).

Parameters
  • q (Array) – Query tensor. Shape depends on mode (prefill/decode/mixed).

  • k (Array) – Key tensor (ignored).

  • v (Array) – Value tensor (ignored).

  • cache_view (PagedAttentionCacheView) – Contains the paged KV cache tensors.

  • cache_metadata (PagedAttentionMetadata) – Contains metadata describing the state and mode (prefill/decode/mixed) of the current batch.

  • **ignore – Ignored keyword arguments.

Returns

An object containing the computed attention outputs.

Attention weights are typically not computed or returned in paged attention.

Return type

AttentionOutput
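The mode dispatch described for forward_tpu can be sketched as follows. This is an assumption-laden toy, not the library's code: the metadata fields `num_prefill_tokens` and `num_decode_tokens` are hypothetical (the real PagedAttentionMetadata may expose the mode differently), and the three handlers only return tuples so the routing is visible.

```python
from dataclasses import dataclass


@dataclass
class ToyMetadata:
    # Hypothetical fields; the real PagedAttentionMetadata may differ.
    num_prefill_tokens: int
    num_decode_tokens: int


def select_mode(meta):
    """Pick 'prefill', 'decode', or 'mixed' from the batch composition."""
    has_prefill = meta.num_prefill_tokens > 0
    has_decode = meta.num_decode_tokens > 0
    if has_prefill and has_decode:
        return "mixed"
    if has_prefill:
        return "prefill"
    if has_decode:
        return "decode"
    raise ValueError("batch contains neither prefill nor decode tokens")


class ToyPagedAttn:
    """Stand-in whose handlers just tag the query with the chosen mode."""

    def _prefill_tpu(self, q):
        return ("prefill", q)

    def _decode_tpu(self, q):
        return ("decode", q)

    def _mixed_tpu(self, q):
        return ("mixed", q)

    def forward_tpu(self, q, cache_metadata):
        handlers = {
            "prefill": self._prefill_tpu,
            "decode": self._decode_tpu,
            "mixed": self._mixed_tpu,
        }
        return handlers[select_mode(cache_metadata)](q)


attn = ToyPagedAttn()
result = attn.forward_tpu("q", ToyMetadata(num_prefill_tokens=3, num_decode_tokens=2))
```

Keeping mode selection in a separate function makes the routing easy to test independently of the kernels.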

get_impl_metadata() -> AttentionMetadata

Retrieves the metadata associated with this attention implementation instance.

Returns

The metadata object provided during initialization.

Return type

AttentionMetadata

classmethod get_impl_name() -> Union[str, Tuple[str]]

Returns the registered name for this attention implementation.

Returns

The name “paged_attention”.

Return type

tp.Union[str, tp.Tuple[str]]
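Since the return type allows either a single name or a tuple of names, code that consumes it has to handle both shapes. The following sketch shows one way to normalize the result before a lookup. The `registry` dict, `register` decorator, and `ToyImpl` class are hypothetical and do not describe EasyDeL's actual registration mechanism.

```python
from typing import Dict, Tuple, Union


def normalize_names(name: Union[str, Tuple[str, ...]]) -> Tuple[str, ...]:
    """Return the implementation name(s) as a tuple, whichever form was given."""
    return (name,) if isinstance(name, str) else tuple(name)


# Hypothetical name -> class registry, for illustration only.
registry: Dict[str, type] = {}


def register(cls):
    """Register a class under every name reported by get_impl_name()."""
    for n in normalize_names(cls.get_impl_name()):
        registry[n] = cls
    return cls


@register
class ToyImpl:
    @classmethod
    def get_impl_name(cls):
        return "paged_attention"
```

Normalizing to a tuple up front lets the rest of the code treat single names and aliases uniformly.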