# Source code for easydel.kernels.cpu_ops.ring_attention_jax._utils

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing as tp

import chex
import jax
import jax.lax as lax
from jax import numpy as jnp


def _chunk_attention_bias(
	blocksize_q: int,
	blocksize_k: int,
	bias: tp.Optional[chex.Array],
	segment_ids: tp.Optional[chex.Array],
	deterministic: bool,
	attn_dropout: tp.Optional[chex.Array],
	pdrop: float,
	blocksize_c: tp.Optional[int],
	dtype: jnp.dtype,
	query_chunk_idx: int,
	key_chunk_idx: int,
):
	"""Builds the additive attention bias for one (query chunk, key chunk) tile.

	The tile starts at row ``query_chunk_idx * blocksize_q`` and column
	``key_chunk_idx * blocksize_k``. The user bias is sliced for that tile and
	then every masked position (segment mismatch, causal, dropout) is
	overwritten with ``jnp.finfo(dtype).min``. Masks overwrite rather than
	add to / take the minimum with the bias so that:

	* positions masked by more than one source stay finite (``min + min``
	  would overflow to ``-inf``), and
	* positive bias values on visible positions are kept as they are.

	Args:
		blocksize_q: Size of query chunks.
		blocksize_k: Size of key chunks.
		bias: Optional 4-D bias array, documented as
			(batch, num_heads, q_len, kv_len). The last two dims may be
			smaller than the block sizes (e.g. 1), in which case the whole
			dim is taken and left to broadcast.
		segment_ids: Optional 2-D segment ids array (batch, seq_len). Query
			and key positions with different ids are masked.
		deterministic: If True, no dropout mask is applied.
		attn_dropout: 4-D dropout mask; truthy entries are masked out. Only
			read when ``deterministic`` is False and ``pdrop > 0``.
		pdrop: Dropout probability; dropout is skipped when it is not > 0.
		blocksize_c: Size of causal blocks, or None to disable causal
			masking. A key is masked when its causal-block index is greater
			than the query's.
		dtype: dtype of the returned bias.
		query_chunk_idx: Index of the query chunk.
		key_chunk_idx: Index of the key chunk.

	Returns:
		Attention bias for the chunk, cast to ``dtype``. Its shape is the
		broadcast of whichever components were applied; when none apply it is
		a ``(1, 1, 1, 1)`` zero array.
	"""
	query_offset = query_chunk_idx * blocksize_q
	key_offset = key_chunk_idx * blocksize_k
	# Most negative finite value of `dtype`; stands in for "-inf" in the bias.
	mask_value = jnp.finfo(dtype).min

	# Broadcastable zero placeholder used when no explicit bias is given.
	chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
	if bias is not None:
		chunk_bias = lax.dynamic_slice(
			bias,
			start_indices=(0, 0, query_offset, key_offset),
			slice_sizes=(
				*bias.shape[:2],
				# min() keeps size-1 (broadcast) dims intact.
				min(bias.shape[-2], blocksize_q),
				min(bias.shape[-1], blocksize_k),
			),
		)

	if segment_ids is not None:
		q_segment_ids = lax.dynamic_slice(
			segment_ids,
			start_indices=(0, query_offset),
			slice_sizes=(segment_ids.shape[0], blocksize_q),
		)
		k_segment_ids = lax.dynamic_slice(
			segment_ids,
			start_indices=(0, key_offset),
			slice_sizes=(segment_ids.shape[0], blocksize_k),
		)
		# True where query and key belong to different segments.
		segment_ids_mask = jnp.not_equal(
			q_segment_ids[:, :, None], k_segment_ids[:, None, :]
		)
		segment_ids_mask = segment_ids_mask[:, None]  # B1QK
		chunk_bias = jnp.where(segment_ids_mask, mask_value, chunk_bias)

	if blocksize_c is not None:
		query_idx = query_offset + lax.broadcasted_iota(
			dtype=jnp.int32, shape=(blocksize_q, 1), dimension=0
		)
		key_idx = key_offset + lax.broadcasted_iota(
			dtype=jnp.int32, shape=(1, blocksize_k), dimension=1
		)
		# Compare causal-block indices, not raw positions: all positions
		# inside the same causal block can see each other.
		causal_mask = (query_idx // blocksize_c) < (key_idx // blocksize_c)
		causal_mask = causal_mask[None, None]  # 11QK
		chunk_bias = jnp.where(causal_mask, mask_value, chunk_bias)

	if not deterministic and pdrop > 0.0:
		attn_dropout_slice = lax.dynamic_slice(
			attn_dropout,
			start_indices=(0, 0, query_offset, key_offset),
			slice_sizes=(
				*attn_dropout.shape[:2],
				min(attn_dropout.shape[-2], blocksize_q),
				min(attn_dropout.shape[-1], blocksize_k),
			),
		)
		# Overwrite instead of `+= mask * min`: adding `min` to an already
		# masked position overflows to -inf.
		chunk_bias = jnp.where(
			attn_dropout_slice.astype(jnp.bool_), mask_value, chunk_bias
		)
	return chunk_bias.astype(dtype)


def below_or_on_diag(
	r: int, r_blk_size: int, c: int, c_blk_size: int, blocksize_c: int
):
	"""Checks whether block (r, c) has any position below or on the diagonal.

	A block qualifies when its bottom-left corner (last row, first column) is
	below or on the diagonal. Row and column indices are first coarsened to
	the causal block granularity, which is the larger of ``blocksize_c`` and
	the respective block size.

	Args:
		r: Row (query) block index.
		r_blk_size: Block size of the row.
		c: Column (key) block index.
		c_blk_size: Block size of the column.
		blocksize_c: Size of causal blocks.

	Returns:
		True if the block is below or on the diagonal, False otherwise. The
		result is whatever the comparison on the ``jax.lax.div`` outputs
		yields (a JAX boolean scalar for scalar inputs).
	"""
	causal_block_size_q = max(blocksize_c, r_blk_size)
	causal_block_size_k = max(blocksize_c, c_blk_size)
	# Convert block indices into causal-block indices.
	r = jax.lax.div(r, causal_block_size_q // r_blk_size)
	c = jax.lax.div(c, causal_block_size_k // c_blk_size)
	last_row = (r + 1) * causal_block_size_q - 1
	first_col = c * causal_block_size_k
	# `>=` (not `>`): a corner lying exactly on the diagonal is still "on".
	return last_row >= first_col