easydel.kernels.tpu_ops.ring_attention_pallas._utils
- class easydel.kernels.tpu_ops.ring_attention_pallas._utils.BlockSizes(block_q: int, block_k_major: int, block_k: int, block_b: int, block_q_major_dkv: int | None = None, block_k_major_dkv: int | None = None, block_k_dkv: int | None = None, block_q_dkv: int | None = None, block_k_major_dq: int | None = None, block_k_dq: int | None = None, block_q_dq: int | None = None)
Bases:
object
- block_b: int
- block_k: int
- block_k_major: int
- block_q: int
- property has_backward_blocks: bool
- replace(**kwargs)
- class easydel.kernels.tpu_ops.ring_attention_pallas._utils.PatchBlockSpec(index_map, block_shape)
Bases:
BlockSpec
- class easydel.kernels.tpu_ops.ring_attention_pallas._utils.SegmentIds(query: Array, kv: Array)
Bases:
NamedTuple
SegmentIds for Q and KV sequences.
SegmentIds are used to generate a segment mask, which prevents attention between different segments in the input sequence. Each array is a list of integer ids. Only tokens with the same id can attend to each other.
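The masking rule described above can be illustrated with a minimal pure-Python sketch. `ToySegmentIds` and `segment_mask` are hypothetical names invented for this illustration and are not part of the module; plain lists stand in for the `Array` fields, and the sketch does not claim to reproduce how the kernel builds its mask.

```python
from typing import List, NamedTuple


class ToySegmentIds(NamedTuple):
    """Hypothetical stand-in for SegmentIds, using lists instead of arrays."""

    query: List[int]
    kv: List[int]


def segment_mask(ids: ToySegmentIds) -> List[List[bool]]:
    """Return mask[i][j], True when query token i may attend to kv token j.

    Assumption: attention is allowed exactly when the two segment ids match.
    """
    return [[q == k for k in ids.kv] for q in ids.query]


# Two packed sequences: tokens 0-1 belong to segment 0, token 2 to segment 1.
ids = ToySegmentIds(query=[0, 0, 1], kv=[0, 0, 1])
mask = segment_mask(ids)
for row in mask:
    print(row)
```

With array inputs, the same comparison is usually written as a broadcast equality such as `query[:, None] == kv[None, :]`.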
- easydel.kernels.tpu_ops.ring_attention_pallas._utils.below_or_on_diag(r: int, r_blk_size: int, c: int, c_blk_size: int, blocksize_c: int)
Checks if the element at (r, c) is below or on the diagonal.
- Parameters
r – Row index.
r_blk_size – Block size of the row.
c – Column index.
c_blk_size – Block size of the column.
blocksize_c – Size of causal blocks.
- Returns
True if the element is below or on the diagonal, False otherwise.
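As a rough illustration of the kind of check described above, the sketch below tests whether a block touches the region on or below the diagonal. It rests on several assumptions: `r` and `c` are treated as block indices rather than element indices, the hypothetical helper `block_below_or_on_diag` is not the library function, and the `blocksize_c` argument is omitted because its exact semantics are not documented here.

```python
def block_below_or_on_diag(r: int, r_blk_size: int, c: int, c_blk_size: int) -> bool:
    """Hypothetical helper: True if block (r, c) holds any element with row >= column.

    Assumption: a block counts as "below or on the diagonal" when its
    bottom-left corner lies on or below the diagonal.
    """
    last_row = (r + 1) * r_blk_size - 1  # bottom row covered by block row r
    first_col = c * c_blk_size  # leftmost column covered by block column c
    return last_row >= first_col


# For a 2x2 grid of 4x4 blocks, list the blocks a causal kernel would need to visit.
visited = [
    (r, c)
    for r in range(2)
    for c in range(2)
    if block_below_or_on_diag(r, 4, c, 4)
]
print(visited)
```

A check like this lets a causal kernel skip blocks that lie entirely above the diagonal, since every element in such a block would be masked out.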