# Source code for easydel.kernels.gpu_ops.flash_attention_triton._backward_triton

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import typing as tp

import chex
import jax
import triton
import triton.language as tl
from eformer.callib import triton_call
from jax import numpy as jnp
from triton import Config

from .._utils import (
	dtype_index,
	get_sharding,
	get_strides,
	safe_autotune,
)
from ._utils import (
	attention_pack_with_static_shape,
	attention_unpack_with_static_shape,
	calc_bias_strides,
	padded_load,
)


# Fill value for scores rejected by a boolean bias (`tl.where(bias, qk, BIG_NEG)`):
# the most negative int32, i.e. the same value as jnp.iinfo(jnp.int32).min.
BIG_NEG: tl.constexpr = -(2**31)
# NOTE: despite its name this constant is log2(e) = 1 / ln(2) ~= 1.4427, not ln(2).
# The kernels multiply scores by it so that exp(x) can be evaluated as exp2(x * log2(e)).
LN2: tl.constexpr = 1.44269504089


def config_prune_kernel(
	configs: tp.List[Config],
	named_args: tp.Dict[str, tp.Any],
	**kwargs,
) -> tp.List[Config]:
	"""Drop autotune configs whose tiles are larger than the problem.

	A config is kept only when both of its query-axis tiles (``BLOCK_M1``,
	``BLOCK_M2``) fit within ``named_args["QSeq"]`` and both of its key-axis
	tiles (``BLOCK_N1``, ``BLOCK_N2``) fit within ``named_args["KSeq"]``.

	Args:
		configs: Candidate autotune configurations.
		named_args: Kernel arguments by name; ``QSeq`` and ``KSeq`` are read.
		**kwargs: Ignored; accepted for the autotuner's pruning-hook signature.

	Returns:
		The surviving configs, or a single all-32 fallback config when every
		candidate was pruned (the autotuner needs at least one config).
	"""
	q_len = named_args["QSeq"]
	k_len = named_args["KSeq"]
	kept_configs = [
		config
		for config in configs
		if max(config.kwargs["BLOCK_M1"], config.kwargs["BLOCK_M2"]) <= q_len
		and max(config.kwargs["BLOCK_N1"], config.kwargs["BLOCK_N2"]) <= k_len
	]
	if kept_configs:
		return kept_configs
	return [
		Config(
			{
				"BLOCK_M1": 32,
				"BLOCK_N1": 32,
				"BLOCK_M2": 32,
				"BLOCK_N2": 32,
			},
			num_warps=4,
			num_stages=0,
		)
	]
# Backward pre-pass. For every query row it stores, in float32,
# Delta[row] = sum_d(O[row, d] * dO[row, d]). The main backward kernel
# reads this as `D`.
@safe_autotune(
	configs=[
		Config({"BLOCK_M": 16}, num_warps=4, num_stages=0),
		Config({"BLOCK_M": 32}, num_warps=4, num_stages=0),
		Config({"BLOCK_M": 64}, num_warps=4, num_stages=0),
		Config({"BLOCK_M": 128}, num_warps=4, num_stages=0),
	],
	key=["CQSeq", "DRuntime"],
)
@triton.jit
def _attn_bwd_preprocess(
	Po,
	Do,
	stride_oz,
	stride_om,
	stride_oh,
	stride_dez,
	stride_dem,
	stride_deh,
	nheads,
	QSeq,
	max_seqlen_q_rounded,
	cum_seqlens_q,
	headdim,
	CQSeq,  # Re-compile argument
	DRuntime,  # Re-compile argument
	Delta,
	VARLEN: tl.constexpr,
	BLOCK_M: tl.constexpr,
	BLOCK_HEADDIM: tl.constexpr,
):
	# Grid: axis 0 tiles query rows by BLOCK_M; axis 1 enumerates (batch, head).
	start_m = tl.program_id(0)
	off_zh = tl.program_id(1)
	off_z = off_zh // nheads
	off_h = off_zh % nheads
	offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
	offs_d = tl.arange(0, BLOCK_HEADDIM)
	if VARLEN:
		# Packed layout: every sequence lives in batch slot 0, and
		# cum_seqlens_q holds the start offset of sequence `off_z`.
		cu_seq_start_q = tl.load(cum_seqlens_q + off_z)
		actual_seqlen_q = tl.load(cum_seqlens_q + off_z + 1) - cu_seq_start_q
		off_z = 0
	else:
		actual_seqlen_q = QSeq
		cu_seq_start_q = 0
	o_ptrs = (
		Po
		+ off_z * stride_oz
		+ off_h * stride_oh
		+ cu_seq_start_q * stride_om
		+ offs_m[:, None] * stride_om
		+ offs_d[None, :]
	)
	do_ptrs = (
		Do
		+ off_z * stride_dez
		+ off_h * stride_deh
		+ cu_seq_start_q * stride_dem
		+ offs_m[:, None] * stride_dem
		+ offs_d[None, :]
	)
	mask = (offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim)
	o = tl.load(o_ptrs, mask=mask, other=0.0).to(tl.float32)
	do = tl.load(do_ptrs, mask=mask, other=0.0).to(tl.float32)
	delta = tl.sum(o * do, axis=1)
	# Unmasked store: the grid covers at most max_seqlen_q_rounded rows per (z, h).
	tl.store(Delta + off_zh * max_seqlen_q_rounded + offs_m, delta)


# Accumulates one BLOCK_M tile of query rows (starting at index_start_m) into the
# dk/dv accumulators of the caller's key/value tile. Returns the updated (dk, dv).
# NOTE(review): `dropout_offs` is offset here but never used to drop
# probabilities, so USE_DROPOUT has no effect on the gradients in this file.
@triton.jit
def _attn_bwd_dkdv(
	index_start_m,
	k,
	v,
	dk,
	dv,
	M,
	D,
	offs_m,
	offs_n,
	offs_d,
	q_ptrs,
	bias_ptrs,
	dropout_offs,
	do_ptrs,
	softmax_scale,
	stride_qm,
	stride_bm,
	stride_dom,
	actual_seqlen_q,
	actual_seqlen_k,
	fully_masked_lines,
	headdim,
	MASKED: tl.constexpr,
	IS_CAUSAL: tl.constexpr,
	BIAS_ON: tl.constexpr,
	BOOL_BIAS: tl.constexpr,
	USE_DROPOUT: tl.constexpr,
	PAD_ROWS: tl.constexpr,
	PAD_COLS: tl.constexpr,
	HEADS_PADDED: tl.constexpr,
):
	q_ptrs = q_ptrs + index_start_m * stride_qm
	do_ptrs = do_ptrs + index_start_m * stride_dom
	if BIAS_ON:
		bias_ptrs = bias_ptrs + index_start_m * stride_bm
	if USE_DROPOUT:
		dropout_offs += index_start_m * actual_seqlen_k

	offs_m_curr = index_start_m + offs_m
	q = padded_load(
		q_ptrs,
		offs_m_curr,
		offs_d,
		PAD_ROWS or HEADS_PADDED,
		PAD_ROWS or HEADS_PADDED,
		actual_seqlen_q,
		headdim,
	)
	me_i = tl.load(M + offs_m_curr)
	if BIAS_ON:
		bias = padded_load(
			bias_ptrs,
			offs_m_curr,
			offs_n,
			PAD_ROWS or HEADS_PADDED,
			PAD_ROWS or HEADS_PADDED,
			actual_seqlen_q,
			actual_seqlen_k,
		)

	qk = tl.dot(q, tl.trans(k))
	if BIAS_ON:
		if BOOL_BIAS:
			qk = tl.where(bias, qk, BIG_NEG)
		else:
			# Divided by softmax_scale because qk is multiplied by it below.
			qk += bias / softmax_scale

	# Key positions shifted so the causal diagonal is bottom-right aligned
	# when actual_seqlen_q != actual_seqlen_k.
	offs_n_causal = offs_n - actual_seqlen_k + actual_seqlen_q
	if MASKED:
		if PAD_COLS:
			if IS_CAUSAL:
				qk = tl.where(
					tl.minimum(actual_seqlen_q - 1, offs_m_curr)[:, None]
					>= offs_n_causal[None, :],
					qk,
					float("-inf"),
				)
			else:
				qk = tl.where(
					actual_seqlen_q - 1 >= offs_n_causal[None, :],
					qk,
					float("-inf"),
				)
		elif IS_CAUSAL:
			qk = tl.where(
				offs_m_curr[:, None] >= offs_n_causal[None, :],
				qk,
				float("-inf"),
			)
	tl.debug_barrier()

	# M is consumed as a base-2 log-sum-exp: p = exp2(qk * scale * log2(e) - M).
	p = tl.exp2(qk * (softmax_scale * LN2) - me_i[:, None])
	if MASKED:
		if fully_masked_lines > 0:
			# Rows with no visible key would otherwise yield exp2(-inf - (-inf)) = NaN.
			p = tl.where(offs_m_curr[:, None] < fully_masked_lines, 0, p)

	do = padded_load(
		do_ptrs,
		offs_m_curr,
		offs_d,
		PAD_ROWS,
		HEADS_PADDED,
		actual_seqlen_q,
		headdim,
	)
	dv += tl.dot(tl.trans(p).to(do.dtype), do)
	dp = tl.dot(do, tl.trans(v))
	Di = tl.load(D + offs_m_curr)
	ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
	dk += tl.dot(tl.trans(ds), q)
	return dk, dv


# Computes dk/dv for the BLOCK_N key/value tile starting at index_start_n.
# It first sweeps the query tiles that straddle the causal diagonal with masking,
# then the remaining query tiles without causal masking, and finally stores the results.
@triton.jit
def _attn_bwd_block_dkdv(
	index_start_n,
	Q,
	K,
	V,
	B,
	Dropout,
	Do,
	Dk,
	Dv,
	M,
	D,
	softmax_scale,
	stride_qm,
	stride_kn,
	stride_vn,
	stride_bm,
	stride_dom,
	stride_dkn,
	stride_dvn,
	actual_seqlen_q,
	actual_seqlen_k,
	headdim,
	IS_CAUSAL: tl.constexpr,
	BIAS_ON: tl.constexpr,
	BOOL_BIAS: tl.constexpr,
	USE_DROPOUT: tl.constexpr,
	PAD_COLS: tl.constexpr,
	HEADS_PADDED: tl.constexpr,
	BLOCK_M: tl.constexpr,
	BLOCK_N: tl.constexpr,
	BLOCK_HEADDIM: tl.constexpr,
):
	# Under causal masking, query rows before this key tile's diagonal never
	# attend to it, so the sweep starts at the (tile-aligned) diagonal row.
	index_begin_m = (
		max(index_start_n + actual_seqlen_q - actual_seqlen_k, 0) if IS_CAUSAL else 0
	)
	index_begin_m = (index_begin_m // BLOCK_M) * BLOCK_M
	index_end_m = actual_seqlen_q
	fully_masked_lines = (actual_seqlen_q - actual_seqlen_k) if IS_CAUSAL else 0
	if (index_begin_m >= actual_seqlen_q) or (index_start_n >= actual_seqlen_k):
		return

	offs_n = index_start_n + tl.arange(0, BLOCK_N)
	offs_m = tl.arange(0, BLOCK_M)
	offs_d = tl.arange(0, BLOCK_HEADDIM)
	q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_d[None, :])
	k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
	v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
	dk_ptrs = Dk + (offs_n[:, None] * stride_dkn + offs_d[None, :])
	dv_ptrs = Dv + (offs_n[:, None] * stride_dvn + offs_d[None, :])
	do_ptrs = Do + (offs_m[:, None] * stride_dom + offs_d[None, :])
	if BIAS_ON:
		bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n[None, :])
	else:
		bias_ptrs = None
	if USE_DROPOUT:
		dropout_offs = Dropout + offs_m[:, None] * actual_seqlen_k + offs_n[None, :]
	else:
		dropout_offs = None

	dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
	dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
	k = padded_load(
		k_ptrs,
		offs_n,
		offs_d,
		PA0=PAD_COLS,
		PA1=HEADS_PADDED,
		LA0=actual_seqlen_k,
		LA1=headdim,
	)
	v = padded_load(
		v_ptrs,
		offs_n,
		offs_d,
		PA0=PAD_COLS,
		PA1=HEADS_PADDED,
		LA0=actual_seqlen_k,
		LA1=headdim,
	)

	# fmt:off
	# fr: last query row whose diagonal still falls inside this key tile;
	# fb: that row rounded up to a BLOCK_M boundary (end of the masked region).
	fr = max(0, index_start_n + BLOCK_N - 1 + actual_seqlen_q - actual_seqlen_k)
	fb = BLOCK_M * ((min(fr, actual_seqlen_q) + BLOCK_M - 1) // BLOCK_M)
	num_masked_blocks = (fb - index_begin_m) // BLOCK_M if IS_CAUSAL else 0
	index_next_start_m = index_begin_m
	# fmt:on

	if num_masked_blocks > 0:
		for _ in range(0, num_masked_blocks):
			dk, dv = _attn_bwd_dkdv(
				index_next_start_m,
				k,
				v,
				dk,
				dv,
				M,
				D,
				offs_m,
				offs_n,
				offs_d,
				q_ptrs,
				bias_ptrs,
				dropout_offs,
				do_ptrs,
				softmax_scale,
				stride_qm,
				stride_bm,
				stride_dom,
				actual_seqlen_q,
				actual_seqlen_k,
				fully_masked_lines,
				headdim,
				MASKED=True,
				IS_CAUSAL=IS_CAUSAL,
				BIAS_ON=BIAS_ON,
				BOOL_BIAS=BOOL_BIAS,
				USE_DROPOUT=USE_DROPOUT,
				PAD_ROWS=True,
				PAD_COLS=PAD_COLS,
				HEADS_PADDED=HEADS_PADDED,
			)
			index_next_start_m += BLOCK_M

	if index_next_start_m < index_end_m:
		for index_start_m in range(index_next_start_m, index_end_m, BLOCK_M):
			dk, dv = _attn_bwd_dkdv(
				index_start_m,
				k,
				v,
				dk,
				dv,
				M,
				D,
				offs_m,
				offs_n,
				offs_d,
				q_ptrs,
				bias_ptrs,
				dropout_offs,
				do_ptrs,
				softmax_scale,
				stride_qm,
				stride_bm,
				stride_dom,
				actual_seqlen_q,
				actual_seqlen_k,
				fully_masked_lines,
				headdim,
				MASKED=False,
				IS_CAUSAL=IS_CAUSAL,
				BIAS_ON=BIAS_ON,
				BOOL_BIAS=BOOL_BIAS,
				USE_DROPOUT=USE_DROPOUT,
				PAD_ROWS=True,
				PAD_COLS=PAD_COLS,
				HEADS_PADDED=HEADS_PADDED,
			)

	if HEADS_PADDED:
		if PAD_COLS:
			tl.store(
				dk_ptrs,
				dk,
				mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim),
			)
			tl.store(
				dv_ptrs,
				dv,
				mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim),
			)
		else:
			tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
			tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
	else:
		if PAD_COLS:
			tl.store(dk_ptrs, dk, mask=offs_n[:, None] < actual_seqlen_k)
			tl.store(dv_ptrs, dv, mask=offs_n[:, None] < actual_seqlen_k)
		else:
			tl.store(dk_ptrs, dk)
			tl.store(dv_ptrs, dv)


# Accumulates one BLOCK_N key/value tile (starting at index_start_n) into the
# dq accumulator of the caller's query tile. Returns the updated dq.
# NOTE(review): dropout_prob / dropout_seed / dropout_offs are threaded through
# but never applied to `p` here.
@triton.jit
def _attn_bwd_dq(
	index_start_n,
	q,
	dq,
	do,
	me_i,
	de_i,
	offs_m,
	offs_n,
	offs_d,
	k_ptrs,
	v_ptrs,
	bias_ptrs,
	dropout_offs,
	softmax_scale,
	dropout_prob,
	dropout_seed,
	stride_kn,
	stride_vn,
	actual_seqlen_q,
	actual_seqlen_k,
	headdim,
	MASKED: tl.constexpr,
	IS_CAUSAL: tl.constexpr,
	BIAS_ON: tl.constexpr,
	BOOL_BIAS: tl.constexpr,
	USE_DROPOUT: tl.constexpr,
	PAD_COLS: tl.constexpr,
	HEADS_PADDED: tl.constexpr,
):
	k_ptrs = k_ptrs + index_start_n * stride_kn
	v_ptrs = v_ptrs + index_start_n * stride_vn
	offs_n_curr = index_start_n + offs_n
	if BIAS_ON:
		bias_ptrs += index_start_n
	if USE_DROPOUT:
		dropout_offs += index_start_n
	k = padded_load(
		k_ptrs, offs_n_curr, offs_d, PAD_COLS, HEADS_PADDED, actual_seqlen_k, headdim
	)
	v = padded_load(
		v_ptrs, offs_n_curr, offs_d, PAD_COLS, HEADS_PADDED, actual_seqlen_k, headdim
	)
	if BIAS_ON:
		bias = padded_load(
			bias_ptrs,
			offs_m,
			offs_n_curr,
			True,
			PAD_COLS,
			actual_seqlen_q,
			actual_seqlen_k,
		)

	qk = tl.dot(q, tl.trans(k))
	if BIAS_ON:
		if BOOL_BIAS:
			qk = tl.where(bias, qk, BIG_NEG)
		else:
			# Divided by softmax_scale because qk is multiplied by it below.
			qk += bias / softmax_scale

	# Bottom-right aligned causal diagonal (see _attn_bwd_dkdv).
	offs_n_causal = offs_n_curr - actual_seqlen_k + actual_seqlen_q
	if MASKED:
		if PAD_COLS:
			if IS_CAUSAL:
				qk = tl.where(
					tl.minimum(actual_seqlen_q - 1, offs_m)[:, None]
					>= offs_n_causal[None, :],
					qk,
					float("-inf"),
				)
			else:
				qk = tl.where(
					actual_seqlen_q - 1 >= offs_n_causal[None, :],
					qk,
					float("-inf"),
				)
		elif IS_CAUSAL:
			qk = tl.where(
				offs_m[:, None] >= offs_n_causal[None, :],
				qk,
				float("-inf"),
			)
	tl.debug_barrier()

	# Same base-2 softmax reconstruction as in _attn_bwd_dkdv.
	p = tl.exp2(qk * (softmax_scale * LN2) - me_i[:, None])
	dp = tl.dot(do, tl.trans(v))
	ds = (p * (dp - de_i[:, None]) * softmax_scale).to(q.dtype)
	dq += tl.dot(ds, k)
	return dq


# Computes dq for the BLOCK_M query tile starting at index_start_m.
# It first sweeps key tiles that need no masking, then the masked and padded
# tail, and finally stores the results.
@triton.jit
def _attn_bwd_block_dq(
	index_start_m,
	Q,
	K,
	V,
	B,
	Dropout,
	Do,
	Dq,
	M,
	D,
	softmax_scale,
	dropout_prob,
	dropout_seed,
	stride_qm,
	stride_kn,
	stride_vn,
	stride_bm,
	stride_dom,
	stride_dqm,
	actual_seqlen_q,
	actual_seqlen_k,
	headdim,
	VARLEN: tl.constexpr,
	IS_CAUSAL: tl.constexpr,
	BIAS_ON: tl.constexpr,
	BOOL_BIAS: tl.constexpr,
	USE_DROPOUT: tl.constexpr,
	PAD_ROWS: tl.constexpr,
	HEADS_PADDED: tl.constexpr,
	BLOCK_M: tl.constexpr,
	BLOCK_N: tl.constexpr,
	BLOCK_HEADDIM: tl.constexpr,
	EVEN_N: tl.constexpr,
):
	if IS_CAUSAL:
		# Last key column (exclusive) any row of this tile can see.
		index_end_n = min(
			actual_seqlen_k - actual_seqlen_q + index_start_m + BLOCK_M,
			actual_seqlen_k,
		)
		if index_end_n < 0:
			return
	else:
		index_end_n = actual_seqlen_k

	fully_masked_lines = actual_seqlen_q - actual_seqlen_k if IS_CAUSAL else 0
	mask_reached = fully_masked_lines >= index_start_m + BLOCK_M
	# NOTE(review): the early returns above and below leave these dq rows
	# unwritten; unless the caller zero-initialises the output buffer they
	# will hold uninitialised memory rather than 0 — confirm.
	if (index_start_m >= actual_seqlen_q) or mask_reached:
		return

	offs_m = tl.arange(0, BLOCK_M) + index_start_m
	offs_n = tl.arange(0, BLOCK_N)
	offs_d = tl.arange(0, BLOCK_HEADDIM)
	q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_d[None, :])
	k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
	v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
	dq_ptrs = Dq + (offs_m[:, None] * stride_dqm + offs_d[None, :])
	do_ptrs = Do + (offs_m[:, None] * stride_dom + offs_d[None, :])
	if BIAS_ON:
		bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n[None, :])
	else:
		bias_ptrs = None
	if USE_DROPOUT:
		dropout_offs = Dropout + (offs_m[:, None] * stride_bm + offs_n[None, :])
	else:
		dropout_offs = None

	dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
	q = padded_load(
		q_ptrs,
		offs_m,
		offs_d,
		PA0=PAD_ROWS,
		PA1=HEADS_PADDED,
		LA0=actual_seqlen_q,
		LA1=headdim,
	)
	do = padded_load(
		do_ptrs,
		offs_m,
		offs_d,
		PA0=PAD_ROWS,
		PA1=HEADS_PADDED,
		LA0=actual_seqlen_q,
		LA1=headdim,
	)
	me_i = tl.load(M + offs_m)
	de_i = tl.load(D + offs_m)

	# Work out how many leading key tiles need neither causal nor padding masks.
	uneven_n = actual_seqlen_k % BLOCK_N != 0
	attention_padding = VARLEN & uneven_n
	if IS_CAUSAL:
		first_masked_col = index_start_m + 1 + actual_seqlen_k - actual_seqlen_q
	elif attention_padding:
		first_masked_col = actual_seqlen_k
	else:
		first_masked_col = index_end_n
	nb_full_blocks = first_masked_col // BLOCK_N

	index_next_start_n = 0
	if nb_full_blocks > 0:
		for _ in range(0, nb_full_blocks):
			index_next_start_n = tl.multiple_of(index_next_start_n, BLOCK_N)
			dq = _attn_bwd_dq(
				index_next_start_n,
				q,
				dq,
				do,
				me_i,
				de_i,
				offs_m,
				offs_n,
				offs_d,
				k_ptrs,
				v_ptrs,
				bias_ptrs,
				dropout_offs,
				softmax_scale,
				dropout_prob,
				dropout_seed,
				stride_kn,
				stride_vn,
				actual_seqlen_q,
				actual_seqlen_k,
				headdim,
				IS_CAUSAL=IS_CAUSAL,
				BIAS_ON=BIAS_ON,
				BOOL_BIAS=BOOL_BIAS,
				USE_DROPOUT=USE_DROPOUT,
				MASKED=False,
				PAD_COLS=False,
				HEADS_PADDED=HEADS_PADDED,
			)
			index_next_start_n += BLOCK_N

	if index_next_start_n < index_end_n:
		for index_start_n in range(index_next_start_n, index_end_n, BLOCK_N):
			pad_cols = (not EVEN_N) or (
				VARLEN and (index_start_n + BLOCK_N > actual_seqlen_k)
			)
			dq = _attn_bwd_dq(
				index_start_n,
				q,
				dq,
				do,
				me_i,
				de_i,
				offs_m,
				offs_n,
				offs_d,
				k_ptrs,
				v_ptrs,
				bias_ptrs,
				dropout_offs,
				softmax_scale,
				dropout_prob,
				dropout_seed,
				stride_kn,
				stride_vn,
				actual_seqlen_q,
				actual_seqlen_k,
				headdim,
				IS_CAUSAL=IS_CAUSAL,
				BIAS_ON=BIAS_ON,
				BOOL_BIAS=BOOL_BIAS,
				USE_DROPOUT=USE_DROPOUT,
				MASKED=True,
				PAD_COLS=pad_cols,
				HEADS_PADDED=HEADS_PADDED,
			)

	if fully_masked_lines > 0:
		# Rows that see no key have NaN intermediate values; their gradient is 0.
		dq = tl.where(offs_m[:, None] < fully_masked_lines, 0, dq)

	if HEADS_PADDED:
		if PAD_ROWS:
			tl.store(
				dq_ptrs,
				dq,
				mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim),
			)
		else:
			tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim)
	else:
		if PAD_ROWS:
			tl.store(dq_ptrs, dq, mask=offs_m[:, None] < actual_seqlen_q)
		else:
			tl.store(dq_ptrs, dq)


# Main backward kernel. On grid axis 0, the first NUM_BLOCKS_KV programs each
# compute dk/dv for one key tile, and the remaining programs each compute dq for
# one query tile. Grid axis 1 enumerates (batch, query head). Under grouped-query
# attention, query head h reads kv head h // num_repeats.
@safe_autotune(
	configs=[
		Config(
			{"BLOCK_M1": 16, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 16},
			num_warps=4,
			num_stages=0,
		),
		Config(
			{"BLOCK_M1": 32, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 32},
			num_warps=4,
			num_stages=0,
		),
		Config(
			{"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 32},
			num_warps=4,
			num_stages=0,
		),
		Config(
			{"BLOCK_M1": 64, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64},
			num_warps=4,
			num_stages=0,
		),
		Config(
			{"BLOCK_M1": 64, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64},
			num_warps=4,
			num_stages=0,
		),
	],
	key=[
		"CQSeq",
		"CKSeq",
		"DRuntime",
		"VARLEN",
		"USE_DROPOUT",
		"IS_CAUSAL",
		"BIAS_ON",
		"BLOCK_HEADDIM",
	],
	prune_configs_by={"early_config_prune": config_prune_kernel},
)
@triton.heuristics(
	{
		"EVEN_M1": lambda args: args["QSeq"] % args["BLOCK_M1"] == 0,
		"EVEN_N1": lambda args: args["KSeq"] % args["BLOCK_N1"] == 0,
		"EVEN_M2": lambda args: args["QSeq"] % args["BLOCK_M2"] == 0,
		"EVEN_N2": lambda args: args["KSeq"] % args["BLOCK_N2"] == 0,
		"HEADS_PADDED": lambda args: args["headdim"] != args["BLOCK_HEADDIM"],
		"NUM_BLOCKS_KV": lambda args: math.ceil(args["KSeq"] / args["BLOCK_N1"]),
	}
)
@triton.jit
def _attn_bwd(
	Q,
	K,
	V,
	B,
	Do,
	M,
	D,
	softmax_scale,
	dropout_prob,
	dropout_seed,
	stride_qz,
	stride_qm,
	stride_qh,
	stride_kz,
	stride_kn,
	stride_kh,
	stride_vz,
	stride_vn,
	stride_vh,
	stride_bz,
	stride_bm,
	stride_bh,
	stride_doz,
	stride_dom,
	stride_doh,
	stride_dqz,
	stride_dqm,
	stride_dqh,
	stride_dkz,
	stride_dkn,
	stride_dkh,
	stride_dvz,
	stride_dvn,
	stride_dvh,
	nheads_q,
	num_repeats,
	QSeq,
	cum_seqlens_q,
	KSeq,
	cum_seqlens_k,
	seqlen_q_rounded,
	headdim,
	CQSeq,
	CKSeq,
	DRuntime,
	Dq,
	Dk,
	Dv,
	VARLEN: tl.constexpr,
	IS_CAUSAL: tl.constexpr,
	BIAS_ON: tl.constexpr,
	BOOL_BIAS: tl.constexpr,
	USE_DROPOUT: tl.constexpr,
	BLOCK_HEADDIM: tl.constexpr,
	# Heuristics
	EVEN_M1: tl.constexpr,
	EVEN_N1: tl.constexpr,
	EVEN_M2: tl.constexpr,
	EVEN_N2: tl.constexpr,
	NUM_BLOCKS_KV: tl.constexpr,
	HEADS_PADDED: tl.constexpr,
	# AutoTune
	BLOCK_M1: tl.constexpr,
	BLOCK_N1: tl.constexpr,
	BLOCK_M2: tl.constexpr,
	BLOCK_N2: tl.constexpr,
):
	pid = tl.program_id(0)
	off_zh = tl.program_id(1)
	off_z = off_zh // nheads_q
	off_head_q = off_zh % nheads_q
	off_head_kv = off_head_q // num_repeats

	if VARLEN:
		cu_seq_start_q = tl.load(cum_seqlens_q + off_z)
		cu_seq_start_k = tl.load(cum_seqlens_k + off_z)
		actual_seqlen_q = tl.load(cum_seqlens_q + off_z + 1) - cu_seq_start_q
		actual_seqlen_k = tl.load(cum_seqlens_k + off_z + 1) - cu_seq_start_k
		off_z = 0
	else:
		cu_seq_start_q = 0
		cu_seq_start_k = 0
		actual_seqlen_q = QSeq
		actual_seqlen_k = KSeq

	Q += off_z * stride_qz + off_head_q * stride_qh + cu_seq_start_q * stride_qm
	K += off_z * stride_kz + off_head_kv * stride_kh + cu_seq_start_k * stride_kn
	V += off_z * stride_vz + off_head_kv * stride_vh + cu_seq_start_k * stride_vn
	Do += off_z * stride_doz + off_head_q * stride_doh + cu_seq_start_q * stride_dom
	Dq += off_z * stride_dqz + off_head_q * stride_dqh + cu_seq_start_q * stride_dqm
	# dk/dv are written per *query* head and reduced over GQA repeats on the host.
	Dk += off_z * stride_dkz + off_head_q * stride_dkh + cu_seq_start_k * stride_dkn
	Dv += off_z * stride_dvz + off_head_q * stride_dvh + cu_seq_start_k * stride_dvn
	if BIAS_ON:
		B += off_z * stride_bz + off_head_q * stride_bh + cu_seq_start_q * stride_bm
	if USE_DROPOUT:
		Dropout = actual_seqlen_k * (
			cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_z)
		)
	else:
		Dropout = None
	D += off_zh * seqlen_q_rounded
	M += off_zh * seqlen_q_rounded

	if pid < NUM_BLOCKS_KV:
		i_start_n = pid
		pad_cols = (not EVEN_N1) or (
			VARLEN and ((i_start_n + 1) * BLOCK_N1 > actual_seqlen_k)
		)
		_attn_bwd_block_dkdv(
			i_start_n * BLOCK_N1,
			Q,
			K,
			V,
			B,
			Dropout,
			Do,
			Dk,
			Dv,
			M,
			D,
			softmax_scale,
			stride_qm,
			stride_kn,
			stride_vn,
			stride_bm,
			stride_dom,
			stride_dkn,
			stride_dvn,
			actual_seqlen_q,
			actual_seqlen_k,
			headdim,
			IS_CAUSAL=IS_CAUSAL,
			BIAS_ON=BIAS_ON,
			BOOL_BIAS=BOOL_BIAS,
			USE_DROPOUT=USE_DROPOUT,
			PAD_COLS=pad_cols,
			HEADS_PADDED=HEADS_PADDED,
			BLOCK_M=BLOCK_M1,
			BLOCK_N=BLOCK_N1,
			BLOCK_HEADDIM=BLOCK_HEADDIM,
		)
	else:
		i_start_m = pid - NUM_BLOCKS_KV
		pad_rows = (not EVEN_M2) or (
			VARLEN and ((i_start_m + 1) * BLOCK_M2 > actual_seqlen_q)
		)
		_attn_bwd_block_dq(
			i_start_m * BLOCK_M2,
			Q,
			K,
			V,
			B,
			Dropout,
			Do,
			Dq,
			M,
			D,
			softmax_scale,
			dropout_prob,
			dropout_seed,
			stride_qm,
			stride_kn,
			stride_vn,
			stride_bm,
			stride_dom,
			stride_dqm,
			actual_seqlen_q,
			actual_seqlen_k,
			headdim,
			VARLEN=VARLEN,
			IS_CAUSAL=IS_CAUSAL,
			BIAS_ON=BIAS_ON,
			BOOL_BIAS=BOOL_BIAS,
			USE_DROPOUT=USE_DROPOUT,
			PAD_ROWS=pad_rows,
			HEADS_PADDED=HEADS_PADDED,
			BLOCK_M=BLOCK_M2,
			BLOCK_N=BLOCK_N2,
			BLOCK_HEADDIM=BLOCK_HEADDIM,
			EVEN_N=EVEN_N2,
		)


def _row_major_strides(shape: tp.Sequence[int]) -> tp.Tuple[int, ...]:
	"""Return the element strides of a contiguous (C-order) array of ``shape``."""
	strides = []
	running = 1
	for dim in reversed(shape):
		strides.append(running)
		running *= dim
	return tuple(reversed(strides))


def _bwd_attention_kernel_call(
	dO: chex.Array,
	q: chex.Array,
	k: chex.Array,
	v: chex.Array,
	bias: tp.Optional[chex.Array],
	attention_mask: tp.Optional[chex.Array],
	o: chex.Array,
	M: chex.Array,
	dropout_prob: float,
	causal: bool,
	softmax_scale: tp.Optional[float],
	dropout_seed: tp.Optional[int],
	varlen_mode: bool,
):
	"""Run the Triton backward pass of flash attention.

	Args:
		dO: Gradient of the attention output, laid out like ``q``.
		q: Queries, ``(batch, QSeq, nheads_q, head_dim)``.
		k: Keys, ``(batch, KSeq, nheads_kv, head_dim)``. ``nheads_q`` must be a
			multiple of ``nheads_kv`` (grouped-query attention).
		v: Values, laid out like ``k``.
		bias: Optional additive bias; may not be combined with ``attention_mask``.
		attention_mask: Optional per-token mask, presumably ``(batch, seq)``.
			When ``varlen_mode`` is set and batch > 1, it is used to pack the
			sequences. Otherwise it is converted into a boolean bias.
		o: Forward attention output, laid out like ``q``.
		M: Per-row softmax statistics from the forward pass, of shape
			``(batch, nheads_q, ceil(QSeq / 128) * 128)``. The kernels use it as
			``exp2(score * softmax_scale * log2(e) - M)``.
		dropout_prob: Dropout probability; ``> 0`` sets ``USE_DROPOUT``.
		causal: Whether to apply (bottom-right aligned) causal masking.
		softmax_scale: Score scale; defaults to ``1 / sqrt(head_dim)``.
		dropout_seed: Optional dropout seed.
		varlen_mode: Pack variable-length sequences using ``attention_mask``.

	Returns:
		Tuple ``(dq, dk, dv)``. ``dq`` is float32. ``dk`` and ``dv`` use the
		dtypes of ``k`` and ``v`` and are reduced back to ``nheads_kv`` heads.
		No gradient is returned for ``bias``.
	"""
	if attention_mask is not None and varlen_mode:
		assert bias is None, (
			"Attention mask is not supported along with attention bias. Just use bias instead."
		)
		assert q.shape[1] == k.shape[1], "Attention mask is not supported with QSeq != KSeq"
		varlen_mode = attention_mask.shape[0] > 1
		# NOTE(review): `.item()` needs a concrete mask, so this path cannot run
		# on traced values under `jax.jit`.
		useless_padding = attention_mask.shape[1] - attention_mask.sum(-1).max().item()
		if useless_padding > 0:
			# Trim trailing positions that no sequence uses (sequence axis 1).
			dO = dO[:, :-useless_padding]
			q = q[:, :-useless_padding]
			k = k[:, :-useless_padding]
			v = v[:, :-useless_padding]
			attention_mask = attention_mask[:, :-useless_padding]
			o = o[:, :-useless_padding]
	else:
		varlen_mode = False
		useless_padding = 0

	batch_size, QSeq, nheads_q, head_dim = q.shape
	_, KSeq, nheads_kv, _ = k.shape
	max_seqlen_q_rounded = math.ceil(QSeq / 128) * 128
	softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
	assert nheads_q % nheads_kv == 0, f"{nheads_q = } is not divisible by {nheads_kv =}"
	# NOTE(review): after trimming `useless_padding`, M presumably still has the
	# untrimmed (rounded) length, so this check can fail — confirm.
	assert M.shape == (batch_size, nheads_q, max_seqlen_q_rounded)

	BOOL_BIAS = False
	if not varlen_mode and attention_mask is not None:
		assert bias is None, "when using attention mask (bool) you can't use bias"
		BOOL_BIAS = True
		bias = jnp.astype(attention_mask, jnp.bool)

	if varlen_mode:
		cum_seqlens_q = jnp.zeros(shape=(attention_mask.shape[0] + 1,), dtype="i4")
		cum_seqlens_k = jnp.zeros(shape=(attention_mask.shape[0] + 1,), dtype="i4")
		cum_seqlens_k = cum_seqlens_k.at[1:].set(
			jnp.cumsum(
				attention_mask.sum(axis=1, dtype="i4"),
				axis=0,
				dtype="i4",
			)
		)
		cum_seqlens_q = cum_seqlens_q.at[1:].set(
			jnp.cumsum(
				attention_mask.sum(axis=1, dtype="i4"),
				axis=0,
				dtype="i4",
			)
		)
		max_seqlen_q: int = attention_mask.shape[1]
		max_seqlen_k: int = attention_mask.shape[1]
		dO = attention_pack_with_static_shape(dO, attention_mask)
		q = attention_pack_with_static_shape(q, attention_mask)
		k = attention_pack_with_static_shape(k, attention_mask)
		v = attention_pack_with_static_shape(v, attention_mask)
		o = attention_pack_with_static_shape(o, attention_mask)
		QSeq = q.shape[1]
		KSeq = k.shape[1]
	else:
		cum_seqlens_q = None
		cum_seqlens_k = None
		max_seqlen_q = QSeq
		max_seqlen_k = KSeq

	bz, bh, bm = calc_bias_strides(
		bias,
		batch_size,
		nheads_q,
		QSeq,
		KSeq,
	)

	num_repeats = nheads_q // nheads_kv
	BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)

	oz, om, oh, _ = get_strides(o)
	doz, dom, doh, _ = get_strides(dO)
	qz, qm, qh, _ = get_strides(q)
	kz, kn, kh, _ = get_strides(k)
	vz, vn, vh, _ = get_strides(v)

	# dk/dv are produced per *query* head and summed over the GQA repeats
	# afterwards. Their strides must therefore come from these output shapes,
	# not from k/v, which only have nheads_kv heads (the two differ as soon as
	# num_repeats > 1).
	dk_shape = (k.shape[0], k.shape[1], q.shape[2], k.shape[3])
	dv_shape = (v.shape[0], v.shape[1], q.shape[2], v.shape[3])
	dkz, dkn, dkh, _ = _row_major_strides(dk_shape)
	dvz, dvn, dvh, _ = _row_major_strides(dv_shape)

	(delta,) = triton_call(
		o,
		dO,
		oz,
		om,
		oh,
		doz,
		dom,
		doh,
		nheads_q,
		QSeq,
		max_seqlen_q_rounded,
		cum_seqlens_q if cum_seqlens_q is not None else jnp.array((1,), dtype=jnp.float16),
		head_dim,
		max_seqlen_q // 32,
		dtype_index(q),
		VARLEN=varlen_mode,
		BLOCK_HEADDIM=BLOCK_HEADDIM,
		out_shape=[
			jax.ShapeDtypeStruct(
				shape=M.shape,
				dtype="f4",
				sharding=get_sharding(M),
			)
		],
		grid=lambda META: (
			triton.cdiv(max_seqlen_q, META["BLOCK_M"]),
			batch_size * nheads_q,
		),
		kernel=_attn_bwd_preprocess,
		name="triton::ops::_attn_bwd_preprocess",
	)

	dq, dk, dv = triton_call(
		q,
		k,
		v,
		bias if bias is not None else jnp.zeros((1,), jnp.float16),
		dO,
		M,
		delta,
		softmax_scale,
		dropout_prob,
		dropout_seed if dropout_seed is not None else jnp.zeros((1,), jnp.float16),
		qz,
		qm,
		qh,
		kz,
		kn,
		kh,
		vz,
		vn,
		vh,
		bz,
		bm,
		bh,
		doz,
		dom,
		doh,
		qz,  # dq has the same shape as q.
		qm,
		qh,
		dkz,
		dkn,
		dkh,
		dvz,
		dvn,
		dvh,
		nheads_q,
		num_repeats,
		QSeq,
		cum_seqlens_q if cum_seqlens_q is not None else jnp.zeros((1,), jnp.float16),
		KSeq,
		cum_seqlens_k if cum_seqlens_k is not None else jnp.zeros((1,), jnp.float16),
		max_seqlen_q_rounded,
		head_dim,
		max_seqlen_q // 32,
		max_seqlen_k // 32,
		dtype_index(q),
		BIAS_ON=(bias is not None),
		VARLEN=varlen_mode,
		IS_CAUSAL=causal,
		USE_DROPOUT=(dropout_prob > 0),
		BLOCK_HEADDIM=BLOCK_HEADDIM,
		BOOL_BIAS=BOOL_BIAS,
		kernel=_attn_bwd,
		grid=lambda META: (
			triton.cdiv(KSeq, META["BLOCK_N1"]) + triton.cdiv(QSeq, META["BLOCK_M2"]),
			batch_size * nheads_q,
		),
		out_shape=[
			jax.ShapeDtypeStruct(
				shape=q.shape,
				dtype="f4",
				sharding=get_sharding(q),
			),
			jax.ShapeDtypeStruct(
				shape=dk_shape,
				dtype=k.dtype,
			),
			jax.ShapeDtypeStruct(
				shape=dv_shape,
				dtype=v.dtype,
			),
		],
		name="triton::ops::_attn_bwd",
	)

	if num_repeats > 1:
		# Query head h maps to kv head h // num_repeats, so the head axis splits
		# as (nheads_kv, num_repeats).
		dk = dk.reshape(dk.shape[0], dk.shape[1], nheads_kv, num_repeats, -1)
		dk = jnp.sum(dk, axis=3)
		dv = dv.reshape(dv.shape[0], dv.shape[1], nheads_kv, num_repeats, -1)
		dv = jnp.sum(dv, axis=3)

	if varlen_mode:
		dq = attention_unpack_with_static_shape(dq, cum_seqlens_q, batch_size, max_seqlen_q)
		dk = attention_unpack_with_static_shape(dk, cum_seqlens_k, batch_size, max_seqlen_k)
		dv = attention_unpack_with_static_shape(dv, cum_seqlens_k, batch_size, max_seqlen_k)

	if useless_padding > 0:
		# Restore the positions trimmed from sequence axis 1 at the start; the
		# gradients are 4-D (batch, seq, heads, head_dim).
		pad_width = ((0, 0), (0, useless_padding), (0, 0), (0, 0))
		dq = jnp.pad(dq, pad_width)
		dk = jnp.pad(dk, pad_width)
		dv = jnp.pad(dv, pad_width)
	return dq, dk, dv