easydel.kernels.tpu_ops.ring_attention_pallas._utils

class easydel.kernels.tpu_ops.ring_attention_pallas._utils.BlockSizes(block_q: int, block_k_major: int, block_k: int, block_b: int, block_q_major_dkv: int | None = None, block_k_major_dkv: int | None = None, block_k_dkv: int | None = None, block_q_dkv: int | None = None, block_k_major_dq: int | None = None, block_k_dq: int | None = None, block_q_dq: int | None = None)

Bases: object

block_b: int
block_k: int
block_k_dkv: int | None = None
block_k_dq: int | None = None
block_k_major: int
block_k_major_dkv: int | None = None
block_k_major_dq: int | None = None
block_q: int
block_q_dkv: int | None = None
block_q_dq: int | None = None
block_q_major_dkv: int | None = None
classmethod from_dict(data)

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)

Create an instance from a JSON string.

classmethod get_default(batch_size, num_heads, q_seq_len, kv_len, d_model)
property has_backward_blocks: bool
replace(**kwargs)
to_dict()

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)

Convert the instance to a JSON string.
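The serialization helpers listed above suggest a simple round trip through a dictionary or JSON string. The following is a minimal sketch of that pattern, not the library's implementation: `BlockSizesSketch` is a hypothetical stand-in dataclass that copies the documented field names and defaults, and the behavior of `has_backward_blocks` (true only when every backward-pass block size is set) is an assumption, since this page does not describe it.

```python
import dataclasses
import json
from typing import Optional


@dataclasses.dataclass(frozen=True)
class BlockSizesSketch:
    """Hypothetical stand-in that mirrors the documented BlockSizes fields."""

    block_q: int
    block_k_major: int
    block_k: int
    block_b: int
    block_q_major_dkv: Optional[int] = None
    block_k_major_dkv: Optional[int] = None
    block_k_dkv: Optional[int] = None
    block_q_dkv: Optional[int] = None
    block_k_major_dq: Optional[int] = None
    block_k_dq: Optional[int] = None
    block_q_dq: Optional[int] = None

    @property
    def has_backward_blocks(self) -> bool:
        # Assumption: backward blocks count as present only if all are set.
        backward = (
            self.block_q_major_dkv,
            self.block_k_major_dkv,
            self.block_k_dkv,
            self.block_q_dkv,
            self.block_k_major_dq,
            self.block_k_dq,
            self.block_q_dq,
        )
        return all(b is not None for b in backward)

    def replace(self, **kwargs):
        # Return a copy with the given fields overridden.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self):
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data):
        return cls(**data)

    def to_json(self, **kwargs):
        # Extra keyword arguments are forwarded to json.dumps.
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str):
        return cls.from_dict(json.loads(json_str))


sizes = BlockSizesSketch(block_q=128, block_k_major=128, block_k=128, block_b=1)
restored = BlockSizesSketch.from_json(sizes.to_json(indent=2))
tuned = sizes.replace(block_q=256)
```

In this sketch, `restored` compares equal to `sizes`, and `tuned` differs only in `block_q`. The block sizes used here are illustrative values, not recommended settings.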

class easydel.kernels.tpu_ops.ring_attention_pallas._utils.PatchBlockSpec(index_map, block_shape)

Bases: BlockSpec

class easydel.kernels.tpu_ops.ring_attention_pallas._utils.SegmentIds(query: Array, kv: Array)

Bases: NamedTuple

SegmentIds for Q and KV sequences.

SegmentIds are used to generate a segment mask, which prevents attention between different segments of the input sequence. Each array holds one integer id per token. Only tokens with the same id can attend to each other.

query: Array

Segment ids along the Q sequence (field number 0).

Type

jax.Array

kv: Array

Segment ids along the KV sequence (field number 1).

Type

jax.Array
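The masking rule described for SegmentIds can be illustrated with a small example. This is a sketch under stated assumptions: `SegmentIdsSketch` and `segment_mask` are hypothetical names, and plain Python lists stand in for `jax.Array` so the example runs without any dependencies. It is not the kernel's own mask construction.

```python
from typing import List, NamedTuple, Sequence


class SegmentIdsSketch(NamedTuple):
    """Hypothetical stand-in that uses Python sequences instead of jax.Array."""

    query: Sequence[int]
    kv: Sequence[int]


def segment_mask(segment_ids: SegmentIdsSketch) -> List[List[bool]]:
    # mask[i][j] is True when query token i and KV token j share a segment id.
    # With array inputs the same comparison is commonly written as a
    # broadcast: query[:, None] == kv[None, :].
    return [[q == k for k in segment_ids.kv] for q in segment_ids.query]


# Two packed sequences: segment 0 and segment 1.
ids = SegmentIdsSketch(query=[0, 0, 1], kv=[0, 1, 1])
mask = segment_mask(ids)
# mask == [[True, False, False],
#          [True, False, False],
#          [False, True, True]]
```

Query tokens in segment 0 can attend only to the first KV token, while the query token in segment 1 can attend only to the last two.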

easydel.kernels.tpu_ops.ring_attention_pallas._utils.below_or_on_diag(r: int, r_blk_size: int, c: int, c_blk_size: int, blocksize_c: int)

Checks if the element at (r, c) is below or on the diagonal.

Parameters
  • r – Row index.

  • r_blk_size – Block size of the row.

  • c – Column index.

  • c_blk_size – Block size of the column.

  • blocksize_c – Size of causal blocks.

Returns

True if the element is below or on the diagonal, False otherwise.
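A block-level version of this check can be sketched as follows. The sketch assumes that `r` and `c` are block indices and that a block counts as below or on the diagonal when at least one of its elements satisfies row >= column. The name `block_below_or_on_diag` is hypothetical, and the sketch deliberately omits `blocksize_c`, whose exact role this page does not specify, so the library's comparison may differ.

```python
def block_below_or_on_diag(r: int, r_blk_size: int, c: int, c_blk_size: int) -> bool:
    """Return True if block (r, c) contains an element with row >= column."""
    # Largest row index covered by row block r.
    last_row = (r + 1) * r_blk_size - 1
    # Smallest column index covered by column block c.
    first_col = c * c_blk_size
    # The bottom-left corner of the block is the element most likely to sit
    # on or below the diagonal, so checking it alone is sufficient.
    return last_row >= first_col


on_diag = block_below_or_on_diag(0, 128, 0, 128)      # True
above_diag = block_below_or_on_diag(0, 128, 1, 128)   # False
below_diag = block_below_or_on_diag(1, 128, 0, 128)   # True
tall_block = block_below_or_on_diag(0, 256, 1, 128)   # True
```

Under causal masking, a check like this lets a kernel skip blocks that lie entirely above the diagonal, since every element in them would be masked out.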