Source code for easydel.kernels.tpu_ops.ring_attention_pallas._utils

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import typing as tp

import jax
import jax.numpy as jnp
from jax.experimental.pallas import BlockSpec

from easydel.utils import traversals as etr

# Module-level settings shared by the ring-attention Pallas kernel helpers.

# Boolean switch; by its name presumably toggles Pallas interpret mode -- confirm at call sites.
INTERPRET = False

# Smallest block size these kernels work with (value only; usage not visible here).
MIN_BLOCK_SIZE = 128

# Contracts axis 1 of both operands with no batch axes. In jax.lax.dot_general's
# ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) layout this is A @ B.T
# for 2-D operands -- presumably passed as `dimension_numbers`; verify at call sites.
TRANS_B_DIM_NUMBERS = (((1,), (1,)), ((), ()))

# Large, finite negative fill value: 70% of the float32 maximum, negated.
_F32_MAX = float(jnp.finfo(jnp.float32).max)
DEFAULT_MASK_VALUE = -0.7 * _F32_MAX
del _F32_MAX

# Presumably the TPU vector-register lane / sublane counts -- confirm.
NUM_LANES = 128
NUM_SUBLANES = 8


class PatchBlockSpec(BlockSpec):
    """``BlockSpec`` whose constructor takes ``index_map`` before ``block_shape``.

    Both arguments are handed to ``BlockSpec.__init__`` by keyword, so the
    result does not depend on the positional order ``BlockSpec`` itself uses.
    """

    def __init__(self, index_map, block_shape):
        forwarded = {"index_map": index_map, "block_shape": block_shape}
        super().__init__(**forwarded)
class SegmentIds(tp.NamedTuple):
    """Segment labels for the query sequence and the key/value sequence.

    Each entry is an integer id. The ids are used to build a segment mask:
    two tokens may attend to each other only when they carry the same id,
    which prevents attention across different segments of the input.

    Attributes:
        query: Segment ids along the Q sequence, shape ``[q_seq_len]``.
        kv: Segment ids along the KV sequence, shape ``[kv_seq_len]``.
    """

    query: jax.Array  # [q_seq_len]
    kv: jax.Array  # [kv_seq_len]
@etr.auto_pytree
class BlockSizes:
    """Block (tile) sizes for the attention kernels.

    The four forward sizes are required. The ``*_dkv`` and ``*_dq`` sizes are
    optional; judging by their names they configure the backward kernels
    (presumably the dK/dV and dQ passes -- confirm against the kernels).

    ``__post_init__`` validates every (major, minor) pair whose two members
    are both set: the minor size must be positive, no larger than the major
    size, and must divide it evenly. Violations raise ``ValueError``.
    """

    # Forward-pass sizes (required).
    block_q: int
    block_k_major: int
    block_k: int
    block_b: int

    # dkv-related sizes (optional).
    block_q_major_dkv: int | None = None
    block_k_major_dkv: int | None = None
    block_k_dkv: int | None = None
    block_q_dkv: int | None = None

    # dq-related sizes (optional).
    block_k_major_dq: int | None = None
    block_k_dq: int | None = None
    block_q_dq: int | None = None

    def __post_init__(self):
        def verify_major_minor(prefix, suffix, major, minor):
            # Without this guard a minor size of 0 would surface as a
            # ZeroDivisionError below, and a negative one could pass silently.
            if minor <= 0:
                raise ValueError(f"{prefix}{suffix}={minor} should be positive")
            # Equality is allowed, so the message must say "or equal to".
            if minor > major:
                raise ValueError(
                    f"{prefix}{suffix}={minor} should be smaller than or equal to"
                    f" {prefix}_major{suffix}={major}"
                )
            if major % minor != 0:
                raise ValueError(
                    f"{prefix}{suffix}={minor} should divide {prefix}_major{suffix}={major}"
                )

        verify_major_minor("block_k", "", self.block_k_major, self.block_k)
        # Backward pairs are only checked when both members are provided.
        if self.block_q_major_dkv is not None and self.block_q_dkv is not None:
            verify_major_minor("block_q", "_dkv", self.block_q_major_dkv, self.block_q_dkv)
        if self.block_k_major_dkv is not None and self.block_k_dkv is not None:
            verify_major_minor("block_k", "_dkv", self.block_k_major_dkv, self.block_k_dkv)
        if self.block_k_major_dq is not None and self.block_k_dq is not None:
            verify_major_minor("block_k", "_dq", self.block_k_major_dq, self.block_k_dq)

    @property
    def has_backward_blocks(self) -> bool:
        """True when all seven optional backward block sizes are set."""
        backward_blocks = (
            self.block_q_major_dkv,
            self.block_k_major_dkv,
            self.block_q_dkv,
            self.block_k_dkv,
            self.block_k_major_dq,
            self.block_k_dq,
            self.block_q_dq,
        )
        return all(b is not None for b in backward_blocks)

    @classmethod
    def get_default(cls, batch_size, num_heads, q_seq_len, kv_len, d_model):
        """Return a configuration with every block size set to 128 and ``block_b=1``.

        The arguments are accepted for interface compatibility but ignored.
        Uses ``cls`` so that subclasses get an instance of their own type.
        """
        del batch_size, num_heads, q_seq_len, kv_len, d_model  # Unused.
        return cls(
            block_q=128,
            block_k_major=128,
            block_k=128,
            block_b=1,
            block_q_major_dkv=128,
            block_k_major_dkv=128,
            block_k_dkv=128,
            block_q_dkv=128,
            block_k_major_dq=128,
            block_k_dq=128,
            block_q_dq=128,
        )
def _verify_block(blocksizename, dim_name, block, dim, should_divide=True): if block > dim: raise ValueError( f"{blocksizename}={block} should be smaller or equal to {dim_name}={dim}" ) if should_divide and dim % block != 0: raise ValueError(f"{dim_name}={dim} should be divisible by {blocksizename}={block}")
def below_or_on_diag(
    r: int,
    r_blk_size: int,
    c: int,
    c_blk_size: int,
    blocksize_c: int,
):
    """Check whether block (r, c) has any element on or below the diagonal.

    Block indices are first mapped onto a "causal block" grid whose block
    sizes are ``max(blocksize_c, r_blk_size)`` for rows and
    ``max(blocksize_c, c_blk_size)`` for columns. The block then counts as
    on/below the diagonal when its bottom-left corner (last row, first
    column) satisfies ``row >= column``.

    Args:
        r: Row-block index (an integer scalar; presumably a grid index, possibly traced).
        r_blk_size: Block size along the row axis.
        c: Column-block index (an integer scalar).
        c_blk_size: Block size along the column axis.
        blocksize_c: Causal block size; blocks are grouped up to at least this size.

    Returns:
        A boolean scalar. It is a JAX array, because ``jax.lax.div`` is
        applied to the indices. True iff the block touches the lower
        triangle, including the diagonal.
    """
    causal_blocksize_q = max(blocksize_c, r_blk_size)
    causal_blocksize_k = max(blocksize_c, c_blk_size)
    # Re-express the indices in units of causal blocks.
    r = jax.lax.div(r, causal_blocksize_q // r_blk_size)
    c = jax.lax.div(c, causal_blocksize_k // c_blk_size)
    # Use `>=`, not `>`, so a corner lying exactly on the diagonal counts.
    # With `>`, 1x1 blocks (r == c) would be wrongly reported as above it.
    return ((r + 1) * causal_blocksize_q - 1) >= (c * causal_blocksize_k)