Source code for easydel.kernels.tpu_ops.paged_attention._paged_attention

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import typing as tp

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

from ._forward_pallas import (
	DEFAULT_MASK_VALUE,
	paged_flash_attention_kernel,
	paged_flash_attention_kernel_inline_seq_dim,
	prefill_attention_impl,
)


@jax.jit
def _build_contiguous_kv_vectorized(
	pages,
	page_indices,
) -> tuple[jnp.ndarray, jnp.ndarray]:
	batch_size = page_indices.shape[0]
	num_heads, _, page_size, head_dim = pages.shape

	def gather_for_head(head_pages, indices_per_batch):
		return jax.vmap(lambda idx: head_pages[idx, :, :])(indices_per_batch)

	gathered_per_head = jax.vmap(gather_for_head, in_axes=(0, None))(pages, page_indices)
	gathered_swapped = gathered_per_head.transpose(1, 0, 2, 3, 4)
	max_seq_len = page_indices.shape[1] * page_size
	return gathered_swapped.reshape(batch_size, num_heads, max_seq_len, head_dim)


@functools.partial(
	jax.jit,
	static_argnames=[
		"block_size",
		"num_total_blocks",
		"max_blocks_per_seq",
		"num_kv_heads",
		"head_dim",
	],
)
def _build_paged_kv(
	contiguous_k: jnp.ndarray,  # Shape: (batch, seq_len, num_kv_heads, head_dim)
	contiguous_v: jnp.ndarray,  # Shape: (batch, seq_len, num_kv_heads, head_dim)
	seq_lengths: jnp.ndarray,  # Shape: (batch,). True length of each sequence.
	block_size: int,
	num_total_blocks: int,  # Desired size of the physical cache.
	max_blocks_per_seq: int,  # Desired size of the block table per sequence.
	num_kv_heads: int,
	head_dim: int,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
	"""
	Builds paged KV caches (physical cache + block tables) from contiguous KV caches.
	"""
	batch_size, _, _, _ = contiguous_k.shape  # Use different name
	assert contiguous_v.shape == contiguous_k.shape
	assert seq_lengths.shape == (batch_size,)

	assert contiguous_k.shape[-2] == num_kv_heads
	assert contiguous_k.shape[-1] == head_dim

	# --- 1. Calculate block requirements and allocate physical indices ---
	# (create_table_body using static iota and masking should be correct now)
	num_blocks_per_seq = jnp.ceil(seq_lengths / block_size).astype(jnp.int32)
	cum_blocks = jnp.cumsum(num_blocks_per_seq)
	start_indices = jnp.concatenate(
		[jnp.zeros(1, dtype=cum_blocks.dtype), cum_blocks[:-1]]
	)
	block_tables = jnp.zeros((batch_size, max_blocks_per_seq), dtype=jnp.int32)

	def create_table_body(i, tables):
		num_blocks = num_blocks_per_seq[i]
		start_idx = start_indices[i]
		max_iota = jax.lax.iota(dtype=jnp.int32, size=max_blocks_per_seq)
		potential_phys_indices = start_idx + max_iota
		mask = max_iota < num_blocks
		phys_indices_for_row = jnp.where(mask, potential_phys_indices, 0)
		tables_updated = tables.at[i].set(phys_indices_for_row)
		return tables_updated

	block_tables = jax.lax.fori_loop(0, batch_size, create_table_body, block_tables)

	physical_k_cache = jnp.zeros(
		(num_total_blocks, block_size, num_kv_heads, head_dim), dtype=contiguous_k.dtype
	)
	physical_v_cache = jnp.zeros(
		(num_total_blocks, block_size, num_kv_heads, head_dim), dtype=contiguous_v.dtype
	)

	def scatter_body_outer(i, caches):
		phys_k, phys_v = caches
		seq_len = seq_lengths[i]  # Traced sequence length for this item
		num_blocks = num_blocks_per_seq[i]  # Traced number of blocks for this item

		def scatter_body_inner(j, inner_caches):
			k_cache, v_cache = inner_caches
			# `j` is concrete loop index, `block_size` is static
			start_token = j * block_size

			# Get physical index (traced)
			physical_idx = block_tables[i, j]

			# --- FIX HERE: Use static slice_sizes for dynamic_slice ---
			# Define the STATIC slice size we want to extract
			static_slice_sizes = (block_size, num_kv_heads, head_dim)

			# Define the start indices (start_token is traced, which is allowed)
			slice_start_indices = (start_token, 0, 0)

			k_potential_block = jax.lax.dynamic_slice(
				contiguous_k[i],
				slice_start_indices,
				static_slice_sizes,
			)
			v_potential_block = jax.lax.dynamic_slice(
				contiguous_v[i], slice_start_indices, static_slice_sizes
			)

			block_indices = jnp.arange(block_size)
			original_indices = start_token + block_indices
			mask = original_indices < seq_len

			mask_expanded = mask[:, None, None]

			k_block_padded = jnp.where(mask_expanded, k_potential_block, 0.0)
			v_block_padded = jnp.where(mask_expanded, v_potential_block, 0.0)
			k_cache_updated = k_cache.at[physical_idx].set(k_block_padded)
			v_cache_updated = v_cache.at[physical_idx].set(v_block_padded)

			return k_cache_updated, v_cache_updated

		phys_k, phys_v = jax.lax.fori_loop(
			0,
			num_blocks,
			scatter_body_inner,
			(phys_k, phys_v),
		)
		return phys_k, phys_v

	physical_k_cache, physical_v_cache = jax.lax.fori_loop(
		0,
		batch_size,
		scatter_body_outer,
		(physical_k_cache, physical_v_cache),
	)

	return physical_k_cache, physical_v_cache, block_tables


[docs]class PagedAttention:
[docs] def build_paged_kv( self, contiguous_k: jnp.ndarray, # Shape: (batch, seq_len, num_kv_heads, head_dim) contiguous_v: jnp.ndarray, # Shape: (batch, seq_len, num_kv_heads, head_dim) seq_lengths: jnp.ndarray, # Shape: (batch,). True length of each sequence. block_size: int, num_total_blocks: int, # Desired size of the physical cache. max_blocks_per_seq: int, # Desired size of the block table per sequence. num_kv_heads: int, head_dim: int, ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: return _build_paged_kv( contiguous_k, contiguous_v, seq_lengths, block_size, num_total_blocks, max_blocks_per_seq, num_kv_heads, head_dim, )
[docs] def build_contiguous_kv_vectorized( self, pages, page_indices ) -> tuple[jnp.ndarray, jnp.ndarray]: """ Builds contiguous KV caches from paged KV caches using vectorized operations. The output sequence length dimension will be max_blocks_per_seq * block_size. The caller needs external knowledge (e.g., original sequence lengths) to correctly interpret or mask the padding positions in the returned tensors. Returns: A tuple containing (contiguous_k, contiguous_v). """ return _build_contiguous_kv_vectorized(pages, page_indices)
[docs]def prefill_attention( q: jax.Array, k_pages: jax.Array, v_pages: jax.Array, length: jax.Array, page_indices: jax.Array, sm_scale: tp.Optional[float] = None, ): """Computes paged attention for the prefill phase. This function wraps the `prefill_attention_impl` Pallas kernel, handling data layout transformations and launching the kernel. It processes one chunk of the query sequence against the corresponding KV cache pages. Args: q: Query tensor for a chunk of the sequence. k_pages: Key cache stored in paged layout in HBM. v_pages: Value cache stored in paged layout in HBM. length: The total sequence length for the item being processed. page_indices: Array mapping sequence positions to page indices in k_pages/v_pages. sm_scale: normal softmax scale. By default it is None or auto. Returns: The attention output for the query chunk, shape [chunk_size, num_attn_heads * head_dim]. """ chunk_size, num_attn_heads, head_dim = q.shape num_kv_heads, _, page_size, _ = k_pages.shape assert num_attn_heads % num_kv_heads == 0 assert chunk_size % page_size == 0 attn_group_size = num_attn_heads // num_kv_heads page_per_chunk = chunk_size // page_size if sm_scale is None: sm_scale = head_dim**-0.5 q = q.transpose((1, 0, 2)) q = q * sm_scale q_block_spec = pl.BlockSpec( (attn_group_size, chunk_size, head_dim), lambda i, *_: (i, 0, 0) ) lm_block_spec = pl.BlockSpec((attn_group_size, chunk_size, 1), lambda *_: (0, 0, 0)) lm_shape = jax.ShapeDtypeStruct( shape=(attn_group_size, chunk_size, 1), dtype=jnp.float32 ) out, _, _ = pl.pallas_call( prefill_attention_impl, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=3, in_specs=[ q_block_spec, pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), ], out_specs=[ q_block_spec, lm_block_spec, lm_block_spec, ], scratch_shapes=[ pltpu.VMEM((2, page_per_chunk, page_size, head_dim), k_pages.dtype), pltpu.VMEM((2, page_per_chunk, page_size, head_dim), v_pages.dtype), pltpu.SemaphoreType.DMA, ], grid=(num_kv_heads,), ), out_shape=[ jax.ShapeDtypeStruct(q.shape, q.dtype), lm_shape, lm_shape, ], )( jnp.reshape(length, (1,)), page_indices, jnp.asarray([0], jnp.int32), q, k_pages, v_pages, ) out = out.transpose((1, 0, 2)).reshape(chunk_size, -1).astype(q.dtype) return out
[docs]@functools.partial( jax.jit, static_argnames=[ "pages_per_compute_block", "attn_logits_soft_cap", "mask_value", "megacore_mode", "inline_seq_dim", ], ) def paged_attention( q: jax.Array, k_pages: jax.Array, v_pages: jax.Array, lengths: jax.Array, page_indices: jax.Array, *, sm_scale: float = 1, mask_value: float = DEFAULT_MASK_VALUE, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int, megacore_mode: str | None = None, inline_seq_dim: bool = True, ) -> jax.Array: """Paged grouped query attention. Args: q: A [batch_size, num_heads, head_dim] jax.Array. k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. lengths: A i32[batch_size] jax.Array the length of each example. page_indices: A i32[batch_size, pages_per_sequence] jax.Array. Each entry should be in the range of [0, total_num_pages), indicating where to locate the page in `k_pages` or `v_pages`. sm_scale: normal softmax scale. By default it is 1.0. mask_value: The value used for padding in attention. By default it is a very negative floating point number. attn_logits_soft_cap: The value used for soft capping the attention logits. pages_per_compute_block: how many pages to be processed in one flash attention block in the pallas kernel. megacore_mode: if set, enable megacore to parallelize the computation. Must be one of ['kv_head', 'batch', None]. Caveat: set this only if megacore is enabled, otherwise the kernel may hang. If you are not sure, leave it to None. * None: disable megacore parallelism. * kv_head: megacore parallelism on KV heads; requires number of KV heads divisible by 2. * batch: megacore parallelism on batch dimension; requires batch divisible by 2. inline_seq_dim: whether to fuse kernel instances along the sequence dim into one kernel. Returns: The output of attention([batch_size, num_heads, head_dim]). """ batch_size, num_heads, head_dim = q.shape num_kv_heads, _, page_size, head_dim_k = k_pages.shape batch_size_paged_indices, pages_per_sequence = page_indices.shape if sm_scale is None: sm_scale = head_dim**-0.5 if k_pages.shape != v_pages.shape: raise ValueError( f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and" f" {v_pages.shape}" # pytype: disable=attribute-error ) if num_heads % num_kv_heads != 0: raise ValueError( "Number of Q heads must be divisible by number of KV heads. Got" f" {num_heads} and {num_kv_heads}." ) if head_dim_k != head_dim: raise ValueError( f"head_dim of Q must be the same as that of K/V. Got {head_dim} and {head_dim_k}." ) if pages_per_sequence % pages_per_compute_block != 0: raise ValueError( "pages_per_compute_block must be divisible by pages per sequence. Got" f" {pages_per_compute_block} and {pages_per_sequence}." ) if lengths.shape != (batch_size,): raise ValueError("`lengths` and `q` must have the same batch size") if batch_size_paged_indices != batch_size: raise ValueError("`page_indices` and `q` must have the same batch size") if lengths.dtype != jnp.int32: raise ValueError("The dtype of `lengths` must be int32. Got {lengths.dtype}") if megacore_mode == "kv_head": if num_kv_heads % 2 != 0: raise ValueError( "number of KV heads must be even when megacore_mode is 'kv_head'" ) num_cores = 2 elif megacore_mode == "batch": if batch_size % 2 != 0: raise ValueError("batch size must be even when megacore_mode is 'batch'") num_cores = 2 elif megacore_mode is None: num_cores = 1 else: raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]") if (num_heads // num_kv_heads) % 8 != 0: q = q.reshape(batch_size, num_heads, 1, head_dim) if megacore_mode == "kv_head": q_block_spec = pl.BlockSpec( (None, num_heads // num_kv_heads, None, head_dim), lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0, 0), ) elif megacore_mode == "batch": q_block_spec = pl.BlockSpec( (None, num_heads // num_kv_heads, None, head_dim), lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0, 0), ) else: q_block_spec = pl.BlockSpec( (None, num_heads // num_kv_heads, None, head_dim), lambda core_index, b, h, *_: (b, h, 0, 0), ) q_dtype_for_kernel_launch = jnp.float32 else: if megacore_mode == "kv_head": q_block_spec = pl.BlockSpec( (None, num_heads // num_kv_heads, head_dim), lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0), ) elif megacore_mode == "batch": q_block_spec = pl.BlockSpec( (None, num_heads // num_kv_heads, head_dim), lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0), ) else: q_block_spec = pl.BlockSpec( (None, num_heads // num_kv_heads, head_dim), lambda core_index, b, h, *_: (b, h, 0), ) q_dtype_for_kernel_launch = q.dtype dimension_semantics: tp.Sequence[tp.Literal["parallel", "arbitrary"]] if inline_seq_dim: kernel = paged_flash_attention_kernel_inline_seq_dim grid = ( num_cores, batch_size // num_cores if megacore_mode == "batch" else batch_size, num_kv_heads // num_cores if megacore_mode == "kv_head" else num_kv_heads, ) dimension_semantics = ("parallel", "arbitrary", "arbitrary") else: kernel = paged_flash_attention_kernel grid = ( num_cores, batch_size // num_cores if megacore_mode == "batch" else batch_size, num_kv_heads // num_cores if megacore_mode == "kv_head" else num_kv_heads, pages_per_sequence // pages_per_compute_block, ) dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary") in_specs = [ q_block_spec, pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), ] scratch_shapes = ( pltpu.VMEM((2, pages_per_compute_block, page_size, head_dim), k_pages.dtype), pltpu.VMEM((2, pages_per_compute_block, page_size, head_dim), v_pages.dtype), pltpu.SemaphoreType.DMA, ) out, _, _ = pl.pallas_call( functools.partial( kernel, pages_per_sequence=pages_per_sequence, batch_size=batch_size, pages_per_compute_block=pages_per_compute_block, mask_value=mask_value, attn_logits_soft_cap=attn_logits_soft_cap, megacore_mode=megacore_mode, ), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=4, in_specs=in_specs, out_specs=[ q_block_spec, q_block_spec, q_block_spec, ], grid=grid, scratch_shapes=scratch_shapes, ), compiler_params=pltpu.TPUCompilerParams(dimension_semantics=dimension_semantics), out_shape=[ jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch), jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), ], )( lengths, page_indices.reshape(-1), jnp.zeros((1,), jnp.int32), jnp.zeros((1,), jnp.int32), q.astype(q_dtype_for_kernel_launch) * sm_scale, k_pages, v_pages, ) return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype)