easydel.layers.attention_operator.modules.paged_attn
- class easydel.layers.attention_operator.modules.paged_attn.PagedAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
Attention implementation using the Paged Attention mechanism with TPU Pallas kernels.
This class provides an attention mechanism suitable for scenarios where the Key-Value cache is managed in non-contiguous pages (Paged KV Cache). It leverages specialized kernels for efficient execution on TPUs, handling prefill and decode phases separately or in a mixed mode.
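To make the idea of non-contiguous pages concrete, here is a minimal pure-Python sketch of page-table addressing. The names `PAGE_SIZE`, `page_table`, and `locate` are illustrative assumptions and not part of EasyDeL's API; the actual cache layout is whatever PagedAttentionCacheView holds.

```python
# Illustrative only: a toy page table, not EasyDeL's actual cache layout.
PAGE_SIZE = 4  # tokens per page (assumed value)

# Hypothetical page table for one sequence: logical page index -> physical page id.
# The physical pages are deliberately non-contiguous.
page_table = [7, 2, 9]


def locate(token_pos: int, table: list, page_size: int = PAGE_SIZE) -> tuple:
    """Map a logical token position to (physical_page, offset_within_page)."""
    logical_page, offset = divmod(token_pos, page_size)
    return table[logical_page], offset


# Token 5 falls in logical page 1 (tokens 4..7), which this table maps to physical page 2.
slot = locate(5, page_table)
```

A kernel that understands this indirection can read keys and values page by page without first copying them into one contiguous buffer.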
- metadata
Configuration metadata for the attention mechanism. While this class uses AttentionMetadata, it primarily relies on the additional PagedAttentionMetadata passed during the forward call for paged-specific information.
- Type
AttentionMetadata
- forward_cpu(*args, **kwargs) AttentionOutput
CPU forward pass.
- Raises
NotImplementedError – Paged Attention currently relies on Pallas for TPUs and does not have a CPU implementation.
- forward_cuda(*args, **kwargs) AttentionOutput
CUDA GPU forward pass.
- Raises
NotImplementedError – Paged Attention currently relies on Pallas for TPUs and does not have a specific CUDA implementation. (Future work might add this.)
- forward_gpu(*args, **kwargs) AttentionOutput
Generic GPU forward pass.
- Raises
NotImplementedError – Paged Attention relies on specific kernels (currently Pallas for TPU) and does not have a generic GPU implementation.
- forward_native(*args, **kwargs) AttentionOutput
Native (CPU) forward pass.
- Raises
NotImplementedError – Paged Attention requires specialized kernels and does not have a native CPU implementation.
- forward_rocm(*args, **kwargs) AttentionOutput
ROCm GPU forward pass. Not implemented for Paged Attention.
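Because the non-TPU entry points above raise NotImplementedError, calling code may want to guard against unsupported backends. The following is a minimal sketch using a stand-in class; `ToyPagedAttn` and `run_attention` are hypothetical names that only mimic the documented behaviour and do not import EasyDeL.

```python
class ToyPagedAttn:
    """Stand-in that mimics the documented per-backend behaviour."""

    def forward_tpu(self, q):
        # Placeholder result; the real method dispatches to Pallas kernels.
        return ("tpu-output", q)

    def forward_gpu(self, *args, **kwargs):
        raise NotImplementedError("Paged Attention has no generic GPU implementation.")

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError("Paged Attention has no native CPU implementation.")


def run_attention(impl, backend, q):
    """Call forward_<backend>; report unsupported backends instead of crashing."""
    method = getattr(impl, f"forward_{backend}")
    try:
        return method(q)
    except NotImplementedError as err:
        return ("unsupported", str(err))
```

In practice a caller would more likely fall back to a different attention implementation than return a marker tuple; the tuple is used here only to keep the sketch self-contained.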
- forward_tpu(q: Array, k: Array, v: Array, cache_view: PagedAttentionCacheView, cache_metadata: PagedAttentionMetadata, **ignore) AttentionOutput
TPU forward pass for Paged Attention.
Determines the execution mode (prefill, decode, or mixed) based on the provided cache_metadata and dispatches the computation to the corresponding internal TPU method (_prefill_tpu, _decode_tpu, _mixed_tpu).
- Parameters
q (Array) – Query tensor. Shape depends on mode (prefill/decode/mixed).
k (Array) – Key tensor (ignored).
v (Array) – Value tensor (ignored).
cache_view (PagedAttentionCacheView) – Contains the paged KV cache tensors.
cache_metadata (PagedAttentionMetadata) – Contains metadata describing the state and mode (prefill/decode/mixed) of the current batch.
**ignore – Ignored keyword arguments.
- Returns
An object containing the computed attention outputs. Attention weights are typically not computed or returned in paged attention.
- Return type
AttentionOutput
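The mode selection described for forward_tpu can be sketched as follows. This page does not list the fields of PagedAttentionMetadata, so `ToyMetadata`, its two counters, and `select_mode` are assumptions made purely for illustration; the real dispatch criteria may differ.

```python
from dataclasses import dataclass


@dataclass
class ToyMetadata:
    """Hypothetical stand-in for PagedAttentionMetadata (real fields may differ)."""

    num_prefill_tokens: int  # assumed field: tokens being prefilled in this batch
    num_decode_tokens: int   # assumed field: tokens being decoded in this batch


def select_mode(meta: ToyMetadata) -> str:
    """Return which internal path a dispatcher would take for this batch."""
    has_prefill = meta.num_prefill_tokens > 0
    has_decode = meta.num_decode_tokens > 0
    if has_prefill and has_decode:
        return "mixed"    # corresponds to the _mixed_tpu path
    if has_prefill:
        return "prefill"  # corresponds to the _prefill_tpu path
    if has_decode:
        return "decode"   # corresponds to the _decode_tpu path
    raise ValueError("empty batch: nothing to prefill or decode")
```

For example, `select_mode(ToyMetadata(num_prefill_tokens=8, num_decode_tokens=0))` selects the prefill path, while a batch containing both kinds of tokens selects the mixed path.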
- get_impl_metadata() AttentionMetadata
Retrieves the metadata associated with this attention implementation instance.
- Returns
The metadata object provided during initialization.
- Return type
AttentionMetadata