# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This is a copied version of
# https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
# Large negative but *finite* additive mask used by the prefill kernel's causal
# mask: 70% of the most negative float32. Presumably finite rather than -inf so
# that the online-softmax updates never evaluate (-inf) - (-inf) = NaN.
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)
class MultiPageAsyncCopyDescriptor:
    """Manages asynchronous copies of multiple K/V pages from HBM to VMEM.

    One ``pltpu.make_async_copy`` descriptor is built per page when the object
    is constructed. ``start`` launches all of them. ``wait_and_get_loaded``
    waits on all of them and then returns the destination buffer's contents.

    Attributes:
        _vmem_buffer: Destination buffer; page ``i`` is copied into
            ``_vmem_buffer.at[i]``.
        _num_pages_to_load: The number of pages to copy.
        _pages_hbm_ref: Source page cache, already indexed by ``head_index``
            when one was given.
        _sem: Semaphore shared by every copy issued by this descriptor.
        _page_indices: Reference to the array holding page indices into
            ``_pages_hbm_ref``.
        _page_indices_start_offset: Position in ``_page_indices`` of the first
            page to copy.
        _async_copies: One async-copy descriptor per page, in page order.
    """

    def __init__(
        self,
        pages_hbm_ref,
        vmem_buffer,
        sem,
        page_indices,
        page_indices_start_offset,
        num_pages_to_load,
        head_index=None,
    ):
        """Initializes the MultiPageAsyncCopyDescriptor.

        Args:
            pages_hbm_ref: Pallas Ref to the source K/V pages in HBM.
            vmem_buffer: Pallas Ref to the destination buffer in VMEM; its
                leading axis is indexed by page number ``0..num_pages_to_load``.
            sem: Pallas Ref for the semaphore used by all the copies.
            page_indices: Pallas Ref to the array holding the indices of the
                pages to be loaded from HBM.
            page_indices_start_offset: Starting offset in ``page_indices``.
            num_pages_to_load: The number of pages to copy.
            head_index: Head to select from ``pages_hbm_ref`` before indexing
                pages. ``None`` (the default) means ``pages_hbm_ref`` is used
                as-is, i.e. it has no head dimension.
        """
        self._vmem_buffer = vmem_buffer
        self._num_pages_to_load = num_pages_to_load
        self._pages_hbm_ref = (
            pages_hbm_ref if head_index is None else pages_hbm_ref.at[head_index]
        )
        self._sem = sem
        self._page_indices = page_indices
        self._page_indices_start_offset = page_indices_start_offset
        # The descriptors are built eagerly so that start() and the later
        # wait_and_get_loaded() refer to exactly the same copies.
        self._async_copies = [
            self._make_async_copy(i) for i in range(num_pages_to_load)
        ]

    def _make_async_copy(self, i):
        """Creates the async copy of the i-th requested page into slot ``i``."""
        page_index = self._page_indices[self._page_indices_start_offset + i]
        return pltpu.make_async_copy(
            self._pages_hbm_ref.at[page_index],
            self._vmem_buffer.at[i],
            self._sem,
        )

    def start(self):
        """Starts all the configured asynchronous copy operations."""
        for async_copy in self._async_copies:
            async_copy.start()

    def wait_and_get_loaded(self) -> jax.Array:
        """Waits for all copies to complete and returns the loaded data.

        Returns:
            The VMEM buffer contents cast to float32 and reshaped to
            ``(-1, head_dim)``, where ``head_dim`` is the buffer's last
            dimension. For a buffer of shape
            ``(num_pages_to_load, page_size, head_dim)`` (assumed layout) this
            is ``(num_pages_to_load * page_size, head_dim)``.
        """
        for async_copy in self._async_copies:
            async_copy.wait()
        head_dim = self._vmem_buffer.shape[-1]
        loaded = self._vmem_buffer[...].astype(jnp.float32)
        return loaded.reshape(-1, head_dim)
def paged_flash_attention_kernel(
    lengths_ref,
    page_indices_ref,
    buffer_index_ref,
    step_ref,
    q_ref,
    k_pages_hbm_ref,
    v_pages_hbm_ref,
    o_ref,
    m_ref,
    l_ref,
    k_vmem_buffer,
    v_vmem_buffer,
    sem,
    *,
    batch_size: int,
    pages_per_compute_block: int,
    pages_per_sequence: int,
    mask_value: float,
    attn_logits_soft_cap: float | None,
    megacore_mode: str | None,
    program_ids=(),
):
    """Pallas kernel computing one block of paged attention with online softmax.

    Each invocation handles block ``i`` (``pages_per_compute_block`` pages,
    ``bk`` tokens) of sequence ``b`` for KV head ``h``. The query is
    contracted against the keys with ``"hd,td->ht"``, so ``q_ref`` is used as
    a 2-D ``(query_rows, head_dim)`` block. K/V pages are fetched from HBM with
    double buffering: while block ``i`` is computed, the next block that
    actually has tokens is prefetched into the other buffer slot.

    The grid is presumably ``(num_cores, batch_size // b_step,
    num_kv_heads // h_step, num_blocks)``, matching the four
    ``pl.program_id`` reads below. ``megacore_mode`` selects whether cores
    interleave over batch items (``"batch"``) or KV heads (``"kv_head"``).

    Args:
        lengths_ref: Sequence length of each batch item.
        page_indices_ref: Flat page-index table; the pages of sequence ``b``
            start at ``b * pages_per_sequence``.
        buffer_index_ref: One-element ref holding the active double-buffer
            slot (0 or 1); toggled whenever a next block is prefetched.
        step_ref: One-element counter of blocks processed so far. The data of
            the block at step 0 is fetched here; every later block's data was
            prefetched by the previous step.
        q_ref: Query block for the current token.
        k_pages_hbm_ref: Key cache pages, shaped
            ``(num_kv_heads, num_pages, page_size, head_dim)``.
        v_pages_hbm_ref: Value cache pages with the same layout.
        o_ref: Running, already-normalized attention output.
        m_ref: Running maximum logit of the online softmax.
        l_ref: Running softmax denominator, i.e. the sum of
            ``exp(logit - running_max)``.
        k_vmem_buffer: Two-slot VMEM buffer for key pages.
        v_vmem_buffer: Two-slot VMEM buffer for value pages.
        sem: Semaphore shared by all K/V copies.
        batch_size: Total batch size.
        pages_per_compute_block: Number of KV cache pages per block.
        pages_per_sequence: Number of page-table entries per sequence.
        mask_value: Additive bias applied to logits at positions
            ``>= length``.
        attn_logits_soft_cap: If not None, logits are replaced by
            ``cap * tanh(logits / cap)`` before masking.
        megacore_mode: ``"batch"``, ``"kv_head"``, or None (no interleaving).
        program_ids: Optional ``(core_index, b, h, i)`` used instead of
            ``pl.program_id(0..3)``; passed by the inline-sequence kernel.
    """
    if program_ids:
        core_index, b, h, i = program_ids
    else:
        core_index, b, h, i = (
            pl.program_id(0),
            pl.program_id(1),
            pl.program_id(2),
            pl.program_id(3),
        )
    num_kv_heads, _, page_size, _ = k_pages_hbm_ref.shape
    bk = page_size * pages_per_compute_block
    num_cores = pl.num_programs(0)

    # Under megacore sharding each core takes every num_cores-th batch item
    # (or KV head), starting at its own core index.
    b_step = num_cores if megacore_mode == "batch" else 1
    b_start = core_index if megacore_mode == "batch" else 0
    h_step = num_cores if megacore_mode == "kv_head" else 1
    h_start = core_index if megacore_mode == "kv_head" else 0

    h = h * h_step + h_start
    b = b * b_step + b_start
    length = lengths_ref[b]

    def compute_block_indices(b, h, i):
        """Returns the first (b, h, i) at or after the given one that has work.

        Block ``i`` of ``(b, h)`` is kept if it still contains tokens.
        Otherwise the search moves to block 0 of the next head on this core,
        and then to the next batch item on this core whose length is non-zero
        (starting again at ``h_start``). A returned ``b >= batch_size`` means
        there is nothing left to prefetch.
        """

        def advance_b():
            next_b = b + b_step

            def advance_to_next_non_zero_length():
                next_next_b = next_b + b_step
                return lax.fori_loop(
                    lax.div(next_next_b, b_step),
                    lax.div(batch_size, b_step),
                    lambda _, b: jnp.where(lengths_ref[b] == 0, b + b_step, b),
                    next_next_b,
                )

            return (
                lax.cond(
                    jnp.logical_and(next_b < batch_size, lengths_ref[next_b] == 0),
                    advance_to_next_non_zero_length,
                    lambda: next_b,
                ),
                h_start,
                0,
            )

        def advance_h():
            next_h = h + h_step
            return lax.cond(next_h < num_kv_heads, lambda: (b, next_h, 0), advance_b)

        return lax.cond(i * bk < lengths_ref[b], lambda: (b, h, i), advance_h)

    def create_kv_async_copy_descriptors(b, h, i, buffer_index):
        """Builds the K and V copy descriptors for block (b, h, i) into a slot."""
        page_offset = b * pages_per_sequence + i * pages_per_compute_block
        pages_to_load = pages_per_compute_block
        async_copy_k = MultiPageAsyncCopyDescriptor(
            k_pages_hbm_ref,
            k_vmem_buffer.at[buffer_index],
            sem,
            page_indices_ref,
            page_offset,
            pages_to_load,
            h,
        )
        async_copy_v = MultiPageAsyncCopyDescriptor(
            v_pages_hbm_ref,
            v_vmem_buffer.at[buffer_index],
            sem,
            page_indices_ref,
            page_offset,
            pages_to_load,
            h,
        )
        return async_copy_k, async_copy_v

    # Blocks entirely past the end of the sequence do nothing.
    @pl.when(i * bk < length)
    def flash_attention():  # pylint: disable=unused-variable
        step = step_ref[0]
        buffer_index = buffer_index_ref[0]

        @pl.when(i == 0)
        def init():  # pylint: disable=unused-variable
            m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
            l_ref[...] = jnp.zeros_like(l_ref)
            o_ref[...] = jnp.zeros_like(o_ref)

        # Only the very first block processed issues its own copy; every later
        # block was prefetched by its predecessor.
        @pl.when(step == 0)
        def prefetch_first_block():  # pylint: disable=unused-variable
            async_copy_k, async_copy_v = create_kv_async_copy_descriptors(
                b, h, i, buffer_index
            )
            async_copy_k.start()
            async_copy_v.start()

        next_b, next_h, next_i = compute_block_indices(b, h, i + 1)

        # Overlap the next block's DMA with this block's compute by loading it
        # into the other slot of the double buffer.
        @pl.when(next_b < batch_size)
        def prefetch_next_block():  # pylint: disable=unused-variable
            next_buffer_index = jnp.where(buffer_index == 0, 1, 0)
            async_copy_next_k, async_copy_next_v = create_kv_async_copy_descriptors(
                next_b, next_h, next_i, next_buffer_index
            )
            async_copy_next_k.start()
            async_copy_next_v.start()
            buffer_index_ref[0] = next_buffer_index

        # Rebuild this block's descriptors so we can wait on its copies.
        async_copy_k, async_copy_v = create_kv_async_copy_descriptors(b, h, i, buffer_index)
        q = q_ref[...].astype(jnp.float32)
        k = async_copy_k.wait_and_get_loaded()
        qk = jnp.einsum("hd,td->ht", q, k, preferred_element_type=jnp.float32)
        if attn_logits_soft_cap is not None:
            capped_qk = jnp.tanh(qk / attn_logits_soft_cap)
            qk = capped_qk * attn_logits_soft_cap

        # Mask key positions at or beyond the sequence length.
        mask = i * bk + jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) < length
        qk = qk + jnp.where(mask, 0.0, mask_value)
        m_curr = qk.max(axis=-1)

        # Online softmax: merge this block's (max, sum) with the running ones.
        s_curr = jnp.exp(qk - m_curr[..., None])
        m_prev, l_prev = m_ref[...], l_ref[...]
        l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,))
        m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
        m_next = jnp.maximum(m_prev, m_curr)
        alpha = jnp.exp(m_prev - m_next)
        beta = jnp.exp(m_curr - m_next)
        l_next = alpha * l_prev + beta * l_curr
        # Guard the division below against an all-zero denominator.
        l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)

        v = async_copy_v.wait_and_get_loaded()
        o_curr_times_l_curr = jnp.dot(s_curr, v)

        m_ref[...], l_ref[...] = m_next, l_next_safe
        # o_ref holds a normalized output, so it is un-normalized by l_prev,
        # rescaled, combined with this block, and renormalized by l_next.
        o_ref[...] = (
            (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe
        ).astype(o_ref.dtype)

        step_ref[0] = step + 1
def paged_flash_attention_kernel_inline_seq_dim(
    lengths_ref,
    page_indices_ref,
    buffer_index_ref,
    step_ref,
    q_ref,
    k_pages_hbm_ref,
    v_pages_hbm_ref,
    o_ref,
    m_ref,
    l_ref,
    k_vmem_buffer,
    v_vmem_buffer,
    sem,
    *,
    batch_size: int,
    pages_per_compute_block: int,
    pages_per_sequence: int,
    mask_value: float,
    attn_logits_soft_cap: float | None,
    megacore_mode: str | None,
):
    """Paged-attention kernel that loops over sequence blocks internally.

    Performs the same per-block computation as `paged_flash_attention_kernel`.
    The difference is how blocks are visited: the block index ``i`` is driven
    by a ``lax.fori_loop`` of ``ceil(length / bk)`` iterations instead of a
    fourth grid axis. Only ``pl.program_id(0..2)`` are read here, so the grid
    is presumably ``(num_cores, batch_size // b_step, num_kv_heads // h_step)``.

    Args:
        lengths_ref: Sequence length of each batch item.
        page_indices_ref: Flat page-index table (see the per-block kernel).
        buffer_index_ref: One-element ref holding the active double-buffer
            slot (0 or 1).
        step_ref: One-element counter of blocks processed so far.
        q_ref: Query block for the current token.
        k_pages_hbm_ref: Key cache pages; ``shape[-2]`` is the page size.
        v_pages_hbm_ref: Value cache pages.
        o_ref: Running, normalized attention output.
        m_ref: Running maximum logit.
        l_ref: Running softmax denominator.
        k_vmem_buffer: Two-slot VMEM buffer for key pages.
        v_vmem_buffer: Two-slot VMEM buffer for value pages.
        sem: Semaphore shared by all K/V copies.
        batch_size: Total batch size.
        pages_per_compute_block: Number of KV cache pages per block.
        pages_per_sequence: Number of page-table entries per sequence.
        mask_value: Additive bias applied to logits past the sequence length.
        attn_logits_soft_cap: If not None, apply tanh capping to logits.
        megacore_mode: ``"batch"``, ``"kv_head"``, or None.
    """
    core_index, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2)

    # Initialize the accumulators up front. The per-block kernel only resets
    # them when it runs block i == 0, which never happens for a zero-length
    # sequence (the loop below then has zero iterations).
    m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
    l_ref[...] = jnp.zeros_like(l_ref)
    o_ref[...] = jnp.zeros_like(o_ref)

    def body(i, _):
        paged_flash_attention_kernel(
            lengths_ref,
            page_indices_ref,
            buffer_index_ref,
            step_ref,
            q_ref,
            k_pages_hbm_ref,
            v_pages_hbm_ref,
            o_ref,
            m_ref,
            l_ref,
            k_vmem_buffer,
            v_vmem_buffer,
            sem,
            batch_size=batch_size,
            pages_per_compute_block=pages_per_compute_block,
            pages_per_sequence=pages_per_sequence,
            mask_value=mask_value,
            attn_logits_soft_cap=attn_logits_soft_cap,
            megacore_mode=megacore_mode,
            program_ids=(core_index, b, h, i),
        )
        return ()

    # Tokens per block.
    bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]

    # Resolve the batch item this program handles, using the same mapping as
    # the per-block kernel (b * num_cores + core_index in "batch" mode).
    if megacore_mode == "batch":
        num_cores = pl.num_programs(0)
        length = lengths_ref[b * num_cores + core_index]
    else:
        length = lengths_ref[b]

    # ceil(length / bk) blocks cover the sequence.
    lax.fori_loop(0, lax.div(length + bk - 1, bk), body, ())
def prefill_attention_impl(
    length_ref,  # shape: (1,), smem,
    page_indices_ref,  # shape: (max_seq_len // page_size), smem,
    buffer_index_ref,  # shape: (1,), smem,
    q_ref,  # shape: (group_size, chunk, head_dim), vmem,
    k_pages_hbm_ref,  # shape: (num_kv_heads, num_pages, page_size, head_dim), hbm
    v_pages_hbm_ref,  # shape: (num_kv_heads, num_pages, page_size, head_dim), hbm
    out_ref,  # shape: (group_size, chunk, head_dim), vmem,
    l_ref,  # shape: (group_size, chunk, 1), vmem,
    m_ref,  # shape: (group_size, chunk, 1), vmem,
    k_vmem_buffer,  # shape: (2, page_per_chunk, page_size, head_dim), vmem,
    v_vmem_buffer,  # shape: (2, page_per_chunk, page_size, head_dim), vmem,
    sem,
):
    """Pallas kernel for the paged-attention prefill phase.

    Attends one chunk of query tokens, for every query head in the group of
    KV head ``h = pl.program_id(0)``, against the paged K/V cache. The cache
    covers ``[0, length)`` and is visited one chunk at a time with online
    softmax. The query chunk is treated as the chunk holding the last valid
    token (index ``ceil(length / chunk_size) - 1``), and a causal mask is
    applied relative to that position. K/V chunks are double buffered, and
    prefetching runs across KV-head boundaries: the last chunk of head ``h``
    prefetches chunk 0 of head ``h + 1``.

    Assumes ``chunk_size`` is a multiple of ``page_size`` (TODO confirm
    against callers), because each chunk loads ``chunk_size // page_size``
    whole pages.

    Args:
        length_ref: Number of valid tokens in the sequence.
        page_indices_ref: Page indices for this sequence.
        buffer_index_ref: Active double-buffer slot (0 or 1); persists across
            grid steps.
        q_ref: The current chunk of query vectors.
        k_pages_hbm_ref: Key cache pages.
        v_pages_hbm_ref: Value cache pages.
        out_ref: Output accumulator, normalized by ``l`` after the last chunk.
        l_ref: Running softmax denominator per query row.
        m_ref: Running maximum logit per query row.
        k_vmem_buffer: Two-slot VMEM buffer for key chunks.
        v_vmem_buffer: Two-slot VMEM buffer for value chunks.
        sem: Semaphore shared by all K/V copies.
    """
    h = pl.program_id(0)
    num_kv_heads, _, page_size, head_dim = k_pages_hbm_ref.shape
    group_size, chunk_size = q_ref.shape[0], q_ref.shape[1]
    pages_per_chunk = chunk_size // page_size
    length = length_ref[0]

    # Index of the chunk containing the last valid token: ceil(length/chunk)-1.
    q_chunk_idx = jax.lax.div(length, chunk_size)
    remainder = jax.lax.rem(length, chunk_size)
    q_chunk_idx -= jnp.where(remainder > 0, 0, 1)

    out_ref[...] = jnp.zeros_like(out_ref)

    def create_kv_async_copy_descriptors(h, i, buffer_index):
        """Builds the K and V copy descriptors for chunk ``i`` of head ``h``."""
        page_offset = i * pages_per_chunk
        # Must match MultiPageAsyncCopyDescriptor's signature in this file:
        # (pages_hbm_ref, vmem_buffer, sem, page_indices, start_offset,
        #  num_pages, head_index). The previous call also passed two extra
        # `None` scale arguments, which made every call raise TypeError.
        async_copy_k = MultiPageAsyncCopyDescriptor(
            k_pages_hbm_ref,
            k_vmem_buffer.at[buffer_index],
            sem,
            page_indices_ref,
            page_offset,
            pages_per_chunk,
            head_index=h,
        )
        async_copy_v = MultiPageAsyncCopyDescriptor(
            v_pages_hbm_ref,
            v_vmem_buffer.at[buffer_index],
            sem,
            page_indices_ref,
            page_offset,
            pages_per_chunk,
            head_index=h,
        )
        return async_copy_k, async_copy_v

    def next_block_indices(h, i):
        """Returns the (head, chunk) visited after (h, i)."""
        return jax.lax.cond(
            (i + 1) * chunk_size < length, lambda: (h, i + 1), lambda: (h + 1, 0)
        )

    def per_kv_chunk_body(i, _):
        @pl.when((i * chunk_size) < length)
        def body():
            buffer_index = buffer_index_ref[0]

            @pl.when(i == 0)
            def init():
                m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
                l_ref[...] = jnp.zeros_like(l_ref)

                # Only the very first step of the grid (h == 0, i == 0) loads
                # its own chunk; every later step was prefetched by its
                # predecessor.
                @pl.when(h == 0)
                def prefetch_first_kv():
                    async_copy_k, async_copy_v = create_kv_async_copy_descriptors(
                        h, i, buffer_index
                    )
                    async_copy_k.start()
                    async_copy_v.start()

            next_h, next_i = next_block_indices(h, i)

            # Prefetch the chunk for the next iteration into the other slot.
            @pl.when((next_h < num_kv_heads) & (next_i <= q_chunk_idx))
            def prefetch_next_block():
                next_buffer_index = jnp.where(buffer_index == 0, 1, 0)
                async_copy_next_k, async_copy_next_v = create_kv_async_copy_descriptors(
                    next_h, next_i, next_buffer_index
                )
                async_copy_next_k.start()
                async_copy_next_v.start()
                buffer_index_ref[0] = next_buffer_index

            async_copy_k, async_copy_v = create_kv_async_copy_descriptors(h, i, buffer_index)
            k = async_copy_k.wait_and_get_loaded()
            v = async_copy_v.wait_and_get_loaded()

            # Causal mask: absolute key position <= absolute query position.
            mask_shape = (chunk_size, chunk_size)
            row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
            row_ids += q_chunk_idx * chunk_size
            col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
            col_ids += i * chunk_size
            causal_mask = col_ids <= row_ids
            causal_mask_value = jnp.where(causal_mask, 0.0, DEFAULT_MASK_VALUE)

            def per_group_body(group_idx, _):
                q = q_ref[group_idx]
                s = (
                    jnp.einsum("td,sd->ts", q, k, preferred_element_type=jnp.float32)
                    + causal_mask_value
                )
                # Online softmax update for this query head.
                s_max = jnp.max(s, axis=1, keepdims=True)
                prev_m = m_ref[group_idx]
                prev_l = l_ref[group_idx]
                cur_m = jnp.maximum(prev_m, s_max)
                cur_m_to_attn_size = jax.lax.broadcast_in_dim(
                    cur_m, (chunk_size, chunk_size), (0, 1)
                )
                p = jnp.exp(s - cur_m_to_attn_size)
                cur_l = jnp.exp(prev_m - cur_m) * prev_l + jnp.sum(p, axis=1, keepdims=True)
                out = out_ref[group_idx]
                # out_ref stays un-normalized until the final chunk.
                out_ref[group_idx, :, :] = (
                    out
                    * jax.lax.broadcast_in_dim(
                        jnp.exp(prev_m - cur_m), (chunk_size, head_dim), (0, 1)
                    )
                    + p @ v
                ).astype(out_ref.dtype)  # p @ v "ts,sd->td"
                m_ref[group_idx, :, :] = cur_m
                l_ref[group_idx, :, :] = cur_l
                return ()

            jax.lax.fori_loop(0, group_size, per_group_body, ())

            # After the last KV chunk, normalize by the softmax denominator.
            @pl.when(((i + 1) * chunk_size) >= length)
            def rescale():
                out_ref[...] = (
                    out_ref[...]
                    / jax.lax.broadcast_in_dim(
                        l_ref[...], (group_size, chunk_size, head_dim), (0, 1, 2)
                    )
                ).astype(out_ref.dtype)

        return ()

    # Loop over the ceil(length / chunk_size) K/V cache chunks.
    jax.lax.fori_loop(
        0, lax.div(length + chunk_size - 1, chunk_size), per_kv_chunk_body, ()
    )