# Source code for easydel.kernels.gpu_ops.flash_attention_triton._utils

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from functools import partial
import math
import typing as tp

import jax
import jax.numpy as jnp
import triton
import triton.language as tl

from .._utils import get_strides


@triton.jit
def padded_load(
	ptrs,
	offs_a,
	offs_b,
	PA0: tl.constexpr,
	PA1: tl.constexpr,
	LA0: tl.constexpr,
	LA1: tl.constexpr,
):
	"""Load a 2-D tile from `ptrs`, optionally bounds-masking either axis.

	Args:
		ptrs: Pointer tile handed to `tl.load`.
		offs_a: Offsets along the first tile axis (broadcast as `[:, None]`).
		offs_b: Offsets along the second tile axis (broadcast as `[None, :]`).
		PA0: Compile-time flag; when true, entries with `offs_a >= LA0` are
			not read. (Presumably "axis 0 is padded" -- confirm with callers.)
		PA1: Compile-time flag; when true, entries with `offs_b >= LA1` are
			not read.
		LA0: Compile-time bound compared against `offs_a`.
		LA1: Compile-time bound compared against `offs_b`.

	Returns:
		The loaded tile; every masked-out entry is filled with `0.0`.
	"""
	# PA0/PA1 are constexpr, so each flag combination selects exactly one of
	# the four `tl.load` variants below.
	if PA0:
		if PA1:
			# Both axes need a bounds check.
			x = tl.load(
				ptrs,
				mask=(offs_a[:, None] < LA0) & (offs_b[None, :] < LA1),
				other=0.0,
			)
		else:
			# Only the first axis needs a bounds check.
			x = tl.load(
				ptrs,
				mask=offs_a[:, None] < LA0,
				other=0.0,
			)
	else:
		if PA1:
			# Only the second axis needs a bounds check.
			x = tl.load(
				ptrs,
				mask=offs_b[None, :] < LA1,
				other=0.0,
			)
		else:
			# No bounds check requested: plain unmasked load.
			x = tl.load(ptrs)
	return x


def calc_bias_strides(
	bias: tp.Optional[jnp.ndarray],
	batch: int,
	nheads_q: int,
	QSeq: int,
	KSeq: int,
) -> tp.Tuple[int, ...]:
	"""Compute element strides of an attention bias, with broadcasting.

	Args:
		bias: Optional 4-D bias whose last two dimensions must equal
			`(QSeq, KSeq)`. Its first dimension must be `1` or `batch`, and
			its second must be `1` or `nheads_q`.
		batch: Batch size of the attention call.
		nheads_q: Number of query heads.
		QSeq: Query sequence length.
		KSeq: Key sequence length.

	Returns:
		`(stride_bz, stride_bh, stride_bm)`: strides, in elements, of the
		first, second and third bias dimensions. A dimension of size 1 gets
		stride `0` so it is broadcast. All three are `0` when `bias` is `None`.

	Raises:
		ValueError: If `bias` is not 4-D, or if its sequence, batch or head
			dimension does not match the expected sizes.
	"""
	if bias is None:
		return 0, 0, 0

	# Reject wrong-rank inputs up front; otherwise `bias.shape[3]` would raise
	# a bare IndexError (rank < 4) or the dimensions would be silently
	# mis-interpreted (rank > 4).
	if len(bias.shape) != 4:
		raise ValueError(
			f"Bias tensor must be 4-dimensional, but got {len(bias.shape)} "
			f"dimension(s). Full bias shape: {bias.shape}"
		)

	if hasattr(bias, "strides"):
		# `strides` is in bytes here; convert to element counts.
		elem_strides = tuple(s // bias.itemsize for s in bias.strides)
	else:
		# NOTE(review): `get_strides` is assumed to return element strides --
		# the original scaled them by itemsize and divided back; confirm.
		elem_strides = tuple(get_strides(bias))

	if bias.shape[2] != QSeq or bias.shape[3] != KSeq:
		raise ValueError(
			f"Bias tensor has incompatible sequence dimensions. "
			f"Expected shape [..., {QSeq}, {KSeq}], but got [..., {bias.shape[2]}, {bias.shape[3]}]. "
			f"Full bias shape: {bias.shape}"
		)

	if bias.shape[0] == 1:
		stride_bz = 0  # broadcast over batch
	elif bias.shape[0] == batch:
		stride_bz = elem_strides[0]
	else:
		raise ValueError(
			f"Batch dimension mismatch in bias tensor. "
			f"Expected either 1 (for broadcasting) or {batch} (batch size), "
			f"but got {bias.shape[0]}. Consider reshaping your bias tensor."
		)

	if bias.shape[1] == 1:
		stride_bh = 0  # broadcast over heads
	elif bias.shape[1] == nheads_q:
		stride_bh = elem_strides[1]
	else:
		raise ValueError(
			f"Head dimension mismatch in bias tensor. "
			f"Expected either 1 (for broadcasting) or {nheads_q} (number of heads), "
			f"but got {bias.shape[1]}. Check that your bias tensor matches the model configuration."
		)

	stride_bm = elem_strides[2]
	return stride_bz, stride_bh, stride_bm
@partial(jax.jit, static_argnames=["max_tokens"])
def attention_pack_with_static_shape(
	x: jnp.ndarray,
	attention_mask: jnp.ndarray,
	max_tokens: tp.Optional[int] = None,
) -> jnp.ndarray:
	"""Pack the valid tokens of a padded batch into one contiguous row.

	Uses a static output size so the function stays compatible with `jax.jit`.

	Args:
		x: Padded tensor indexed as `[batch, position, head, head_dim]`; its
			leading two dimensions are expected to match `attention_mask`.
		attention_mask: `[batch_size, seqlen]` mask. Only its per-row sum is
			used: the first `sum(mask[b])` positions of row `b` are treated
			as the valid tokens (i.e. padding is assumed to be on the right).
		max_tokens: Static length of the packed axis. Defaults to
			`batch_size * seqlen`.

	Returns:
		Array of shape `[1, max_tokens, num_heads, head_dim]` holding the valid
		tokens of row 0, then row 1, and so on; unused slots are zero. Tokens
		whose packed index would be `>= max_tokens` are dropped.
	"""
	batch_size, seqlen = attention_mask.shape
	num_heads, head_dim = x.shape[2], x.shape[3]
	if max_tokens is None:
		max_tokens = batch_size * seqlen

	seqlens = jnp.sum(attention_mask, axis=1).astype(jnp.int32)
	# Exclusive prefix sum: packed index of the first token of each row.
	offsets = jnp.cumsum(seqlens) - seqlens

	positions = jnp.arange(seqlen, dtype=jnp.int32)[None, :]
	valid = positions < seqlens[:, None]
	# Invalid tokens are routed to the out-of-range index `max_tokens`, which
	# the scatter below discards.
	target = jnp.where(valid, offsets[:, None] + positions, max_tokens)

	packed = jnp.zeros((1, max_tokens, num_heads, head_dim), dtype=x.dtype)
	# A single scatter replaces the per-token loop, which rewrote the whole
	# buffer on every iteration. Valid targets are unique, so there are no
	# write conflicts.
	return packed.at[0, target].set(x[:batch_size, :seqlen], mode="drop")
@partial(jax.jit, static_argnames=["seqlen", "batch_size"])
def attention_unpack_with_static_shape(
	x: jnp.ndarray,
	cum_seqlens: jnp.ndarray,
	batch_size: int,
	seqlen: int,
) -> jnp.ndarray:
	"""Unpack a packed tensor back to its padded per-batch layout.

	Args:
		x: Packed tensor of shape `[1, packed_tokens, num_heads, head_dim]`.
		cum_seqlens: Cumulative sequence lengths, shape `[batch_size + 1]`;
			row `b` occupies packed indices `cum_seqlens[b]:cum_seqlens[b + 1]`.
		batch_size: Number of batch rows (static).
		seqlen: Maximum sequence length of the output (static).

	Returns:
		Unpacked tensor of shape `[batch_size, seqlen, num_heads, head_dim]`;
		positions beyond a row's length are zero.
	"""
	starts = cum_seqlens[:batch_size]
	lengths = cum_seqlens[1 : batch_size + 1] - starts

	positions = jnp.arange(seqlen, dtype=starts.dtype)[None, :]
	valid = positions < lengths[:, None]
	src = starts[:, None] + positions

	# One clamped gather replaces the nested per-position loops, which rewrote
	# the whole output on every iteration. Clamping keeps padded positions
	# in range; their gathered values are discarded by the mask below.
	gathered = jnp.take(x[0], src, axis=0, mode="clip")
	return jnp.where(valid[:, :, None, None], gathered, jnp.zeros((), dtype=x.dtype))
def basic_attention_refrence(
	q: jnp.ndarray,
	k: jnp.ndarray,
	v: jnp.ndarray,
	attn_bias: tp.Optional[jnp.ndarray] = None,
	query_padding_mask: tp.Optional[jnp.ndarray] = None,
	key_padding_mask: tp.Optional[jnp.ndarray] = None,
	dropout_prob: float = 0.0,
	dropout_key: tp.Optional[jax.random.PRNGKey] = None,
	window_size: tp.Tuple[int, int] = (-1, -1),
	causal: bool = False,
	softcap: float = 0.0,
):
	"""Plain (non-fused) attention, computed in float32.

	Args:
		q: Queries, `[batch, q_len, num_heads, head_dim]`.
		k: Keys, `[batch, k_len, num_kv_heads, head_dim]`. When
			`num_heads // num_kv_heads > 1` the KV heads are repeated.
		v: Values, laid out like `k`.
		attn_bias: Optional bias added to the scores; must broadcast against
			`[batch, num_heads, q_len, k_len]`.
		query_padding_mask: Optional `[batch, q_len]` mask, truthy = real token.
			Outputs for masked queries are zero.
		key_padding_mask: Optional `[batch, k_len]` mask, truthy = real token.
		dropout_prob: Dropout probability on the attention weights; applied
			only when it is `> 0` and `dropout_key` is given.
		dropout_key: PRNG key used to sample the dropout mask.
		window_size: `(left, right)` local-attention window; a negative entry
			leaves that side unbounded.
		causal: If true, the right window is forced to `0`.
		softcap: If `> 0`, scores become `softcap * tanh(scores / softcap)`.

	Returns:
		Attention output of shape `[batch, q_len, num_heads, head_dim]`, cast
		back to the dtype of `q`.
	"""
	if causal:
		window_size = (window_size[0], 0)

	# `~mask` is a logical NOT only for boolean arrays. On a 0/1 integer mask
	# it is a bitwise NOT (-1 / -2, both truthy) and would mask every
	# position, so normalise the masks to bool first.
	if query_padding_mask is not None:
		query_padding_mask = jnp.asarray(query_padding_mask).astype(jnp.bool_)
	if key_padding_mask is not None:
		key_padding_mask = jnp.asarray(key_padding_mask).astype(jnp.bool_)

	dtype_og = q.dtype
	q, k, v = q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32)
	QSeq, KSeq = q.shape[1], k.shape[1]

	# Repeat KV heads so they line up with the query heads.
	repeats = q.shape[2] // k.shape[2]
	if repeats > 1:
		k = jnp.repeat(k, repeats=repeats, axis=2)
		v = jnp.repeat(v, repeats=repeats, axis=2)

	d = q.shape[-1]
	q_scaled = q / math.sqrt(d)
	scores = jnp.einsum("bthd,bshd->bhts", q_scaled, k)

	if softcap > 0:
		scores = scores / softcap
		scores = jnp.tanh(scores)
		scores = scores * softcap

	neg_fill = jnp.finfo(scores.dtype).min
	if key_padding_mask is not None:
		key_mask = (~key_padding_mask).reshape(key_padding_mask.shape[0], 1, 1, KSeq)
		scores = jnp.where(key_mask, neg_fill, scores)

	has_window = window_size[0] >= 0 or window_size[1] >= 0
	if has_window:
		row_idx = jnp.arange(QSeq).reshape(-1, 1)
		col_idx = jnp.arange(KSeq)
		# Effective (unpadded) lengths; per-batch when a padding mask is given.
		if key_padding_mask is None:
			sk = KSeq
		else:
			sk = jnp.sum(key_padding_mask, axis=-1).reshape(-1, 1, 1, 1)
		if query_padding_mask is None:
			sq = QSeq
		else:
			sq = jnp.sum(query_padding_mask, axis=-1).reshape(-1, 1, 1, 1)

		if window_size[0] < 0:
			# Unbounded on the left: only the right edge masks anything.
			local_mask = col_idx > row_idx + sk - sq + window_size[1]
		else:
			if key_padding_mask is None:
				sk_full = jnp.full_like(col_idx, KSeq)
			else:
				sk_full = sk
			local_mask = jnp.logical_or(
				col_idx > jnp.minimum(row_idx + sk - sq + window_size[1], sk_full),
				col_idx < row_idx + sk - sq - window_size[0],
			)
		scores = jnp.where(local_mask, neg_fill, scores)

	if attn_bias is not None:
		scores = scores + attn_bias

	attention = jax.nn.softmax(scores, axis=-1).astype(v.dtype)

	if has_window:
		# Softmax over a fully masked row is uniform; zero such rows instead.
		all_masked = jnp.all(local_mask, axis=-1, keepdims=True)
		attention = jnp.where(all_masked, 0.0, attention)

	if query_padding_mask is not None:
		query_mask = (~query_padding_mask).reshape(query_padding_mask.shape[0], 1, QSeq, 1)
		attention = jnp.where(query_mask, 0.0, attention)

	if dropout_prob > 0 and dropout_key is not None:
		# Computed only when dropout is applied, so `dropout_prob == 1` without
		# a key no longer divides by zero.
		dropout_scaling = 1.0 / (1 - dropout_prob)
		dropout_mask = jax.random.bernoulli(
			dropout_key, p=1 - dropout_prob, shape=attention.shape
		)
		attention_drop = attention * dropout_mask * dropout_scaling
	else:
		attention_drop = attention

	output = jnp.einsum("bhts,bshd->bthd", attention_drop, v)

	if query_padding_mask is not None:
		query_mask_expanded = (~query_padding_mask).reshape(
			query_padding_mask.shape[0],
			QSeq,
			1,
			1,
		)
		output = jnp.where(query_mask_expanded, 0.0, output)

	return output.astype(dtype_og)