Source code for easydel.layers.attention_operator.modules.ring

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import typing as tp
from functools import partial

import jax
import jax.lax as lax
from eformer.escale import with_sharding_constraint
from einops import rearrange
from jax import Array
from jax import numpy as jnp
from jax import random as jr
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as Ps

from easydel.kernels.tpu_ops import pallas_ring_attention

from .._attention_impl import (
	AttentionImpl,
	AttentionMetadata,
	AttentionOutput,
	AttentionRegistry,
)
from .vanilla import VanillaAttn


[docs]def blockwise_attn( query, key, value, bias=None, deterministic=True, dropout_rng=None, attn_pdrop=0.0, causal=True, query_chunk_size=2048, key_chunk_size=2048, dtype=jnp.float32, policy=jax.checkpoint_policies.nothing_saveable(), # noqa: B008 precision=None, float32_logits=True, prevent_cse=True, ): query = query / jnp.sqrt(query.shape[-1]).astype(dtype) if float32_logits: query = query.astype(jnp.float32) key = key.astype(jnp.float32) batch, q_len, num_heads, dim_per_head = query.shape batch, kv_len, num_heads, dim_per_head = key.shape batch, kv_len, num_heads, dim_per_head = value.shape num_q = q_len // query_chunk_size num_kv = kv_len // key_chunk_size query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head)) key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head)) value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head)) query = jnp.moveaxis(query, 1, 0) key = jnp.moveaxis(key, 1, 0) value = jnp.moveaxis(value, 1, 0) if bias is not None: for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)): assert bias_dim == 1 or bias_dim == broadcast_dim if not deterministic and attn_pdrop > 0.0: attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng) attn_dropout = jax.random.bernoulli( attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len) ) else: attn_dropout = None _chunk_bias_fn = functools.partial( _chunk_attention_bias, query_chunk_size, key_chunk_size, bias, deterministic, attn_dropout, attn_pdrop, causal, dtype, ) def scan_attention(args): query_chunk, query_chunk_idx = args @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy) def scan_kv_block(carry, args): key_chunk, value_chunk, key_chunk_idx = args (numerator, denominator, prev_max_score) = carry attn_weights = jnp.einsum( "bqhd,bkhd->bqhk", query_chunk, key_chunk, precision=precision ) bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx) bias_chunk = jnp.moveaxis(bias_chunk, 1, 2) attn_weights = attn_weights + bias_chunk max_score = jnp.max(attn_weights, axis=-1, keepdims=True) max_score = jnp.maximum(prev_max_score, max_score) max_score = jax.lax.stop_gradient(max_score) exp_weights = jnp.exp(attn_weights - max_score) exp_values = jnp.einsum( "bqhv,bvhd->bqhd", exp_weights, value_chunk, precision=precision ) correction = jnp.exp(prev_max_score - max_score) numerator = numerator * correction + exp_values denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True) return Carry(numerator, denominator, max_score), None def skip_upper_half(carry, args): key_chunk, value_chunk, key_chunk_idx = args skip_block = jnp.array(False) if causal: skip_block = query_chunk_idx < key_chunk_idx return jax.lax.cond( skip_block, lambda carry, args: (carry, None), scan_kv_block, carry, args, ) init_carry = Carry( jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype), jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype), (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype), ) (numerator, denominator, max_score), _ = lax.scan( skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv)) ) outputs = (numerator / denominator).astype(dtype) return outputs _, res = lax.scan( lambda _, x: ((), scan_attention(x)), (), xs=(query, jnp.arange(0, num_q)) ) res = rearrange(res, "n b c h d -> b (n c) h d") return res
[docs]class Carry(tp.NamedTuple): numerator: jax.Array denominator: jax.Array max_so_far: jax.Array
def _chunk_attention_bias( query_chunk_size, key_chunk_size, bias, deterministic, attn_dropout, attn_pdrop, causal, dtype, query_chunk_idx, key_chunk_idx, ): query_offset = query_chunk_idx * query_chunk_size key_offset = key_chunk_idx * key_chunk_size chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype) if bias is not None: chunk_bias = lax.dynamic_slice( bias, start_indices=(0, 0, query_offset, key_offset), slice_sizes=( *bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size), ), ) if causal: query_idx = lax.broadcasted_iota( dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0 ) key_idx = lax.broadcasted_iota( dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1 ) offset = query_offset - key_offset query_idx += offset causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape) if not deterministic and attn_pdrop > 0.0: attn_dropout_slice = lax.dynamic_slice( attn_dropout, start_indices=(0, 0, query_offset, key_offset), slice_sizes=( *attn_dropout.shape[:2], min(attn_dropout.shape[-2], query_chunk_size), min(attn_dropout.shape[-1], key_chunk_size), ), ) chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min return chunk_bias.astype(dtype) # TODO: Recheck this
[docs]@AttentionRegistry.register class RingAttn(AttentionImpl): """ Attention implementation using ring-passing algorithm or blockwise scan. This implementation supports: - Native (scan-based) blockwise attention via `blockwise_attn`. - TPU-specific ring attention using `pallas_ring_attention` kernel. It is registered under the name "ring". """
[docs] @classmethod def get_impl_name(cls) -> tp.Union[str, tp.Tuple[str]]: """Returns the registered name: "ring".""" return "ring"
[docs] def get_impl_metadata(self) -> AttentionMetadata: """Returns the metadata associated with this instance.""" return self.metadata
[docs] @jax.named_scope("easydel-ringimpl-native-xla") def forward_native( self, q: Array, k: Array, v: Array, mask: tp.Optional[Array] = None, bias: tp.Optional[Array] = None, init_bias: tp.Optional[tp.Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: tp.Optional[jr.PRNGKey] = None, causal: bool = True, **ignore, ) -> AttentionOutput: """ Computes attention using the scan-based `blockwise_attn` function. Handles optional mask/bias, KV head repetition, and sharding constraints. Args: q: Query tensor (B, T, H, D). k: Key tensor (B, S, H_kv, D). v: Value tensor (B, S, H_kv, D). mask: Optional boolean attention mask (broadcastable to B, 1, T, S). bias: Optional attention bias (broadcastable to B, H, T, S). init_bias: Optional callable to initialize bias if mask/bias are None. deterministic: If False, enables dropout. Requires `dropout_rng`. dropout_rng: JAX PRNG key for dropout if `deterministic` is False. causal: Apply causal mask if True. **ignore: Ignored keyword arguments. Returns: AttentionOutput containing the attention result. """ sm_scale = self.metadata.softmax_scale sm_scale = sm_scale if sm_scale is not None else q.shape[-1] ** -0.5 dtype = self.metadata.runtime_dtype k, v = self.repeat_kv_heads(k, v, q.shape[2] // k.shape[2]) query_lenght = q.shape[1] value_lenght = v.shape[1] runtime_type = self.get_runtime_type(q=q, BTHD=True) ( query_partition_spec, key_partition_spec, value_partition_spec, bias_partition_spec, mask_partition_spec, attention_partition_spec, ) = self.metadata.get_partition_specs(runtime_type, True) blocksize_k = min(self.metadata.blocksize_k, value_lenght) blocksize_q = min(self.metadata.blocksize_q, query_lenght) with self.metadata.mesh: if mask is None and bias is None and init_bias is not None: bias = init_bias() if bias is None and mask is not None: bias = jnp.where(mask, 0, jnp.finfo(dtype).min) output = with_sharding_constraint( arr=blockwise_attn( query=with_sharding_constraint(arr=q, sharding=query_partition_spec), key=with_sharding_constraint(arr=k, sharding=key_partition_spec), value=with_sharding_constraint(arr=v, sharding=value_partition_spec), bias=with_sharding_constraint(arr=bias, sharding=bias_partition_spec), deterministic=deterministic, dtype=dtype, dropout_rng=dropout_rng, precision=jax.lax.Precision.DEFAULT, attn_pdrop=self.metadata.dropout_prob, key_chunk_size=blocksize_k, query_chunk_size=blocksize_q, prevent_cse=False, causal=causal, float32_logits=True, ), sharding=attention_partition_spec, ) return AttentionOutput(attention_weights=None, attention_outputs=output)
[docs] def forward_gpu(self, *args, **kwargs) -> AttentionOutput: """GPU forward pass. Currently delegates to `forward_native` (scan-based).""" # TODO: Implement GPU-specific ring attention kernel if available return self.forward_cuda(*args, **kwargs)
[docs] @jax.named_scope("easydel-ringimpl-tpu") def forward_tpu( self, q: Array, k: Array, v: Array, mask: tp.Optional[Array] = None, bias: tp.Optional[Array] = None, init_bias: tp.Optional[tp.Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: tp.Optional[jr.PRNGKey] = None, causal: bool = True, **ignore, ) -> AttentionOutput: """ Computes Ring Attention on TPU using the `pallas_ring_attention` kernel. Handles optional mask/bias, sharding, and passes configuration to the kernel. Args: q: Query tensor (B, T, H, D). k: Key tensor (B, S, H_kv, D). v: Value tensor (B, S, H_kv, D). mask: Optional boolean attention mask (broadcastable to B, 1, T, S). bias: Optional attention bias (broadcastable to B, H, T, S). init_bias: Optional callable to initialize bias if mask/bias are None. deterministic: If False, potentially enables dropout within the kernel (if supported). dropout_rng: JAX PRNG key (may be used by the kernel if dropout is enabled). causal: Apply causal mask if True. Passed to the kernel. **ignore: Ignored keyword arguments. Returns: AttentionOutput containing the attention result. """ sm_scale = self.metadata.softmax_scale sm_scale = sm_scale if sm_scale is not None else q.shape[-1] ** -0.5 dtype = self.metadata.runtime_dtype runtime_type = self.get_runtime_type(q=q, BTHD=True) ( query_partition_spec, key_partition_spec, value_partition_spec, bias_partition_spec, mask_partition_spec, attention_partition_spec, ) = self.metadata.get_partition_specs(runtime_type, True) if mask is None and bias is None and init_bias is not None: bias = init_bias() segment_ids = None if bias is None and mask is not None: bias = jnp.where(mask, 0, jnp.finfo(dtype).min) blocksize_k = min(self.metadata.blocksize_k, k.shape[1]) blocksize_q = min(self.metadata.blocksize_q, q.shape[1]) attn_output = shard_map( partial( pallas_ring_attention, axis_name=self.metadata.sequence_axis_name, float32_logits=True, cache_idx=None, query_chunk_size=blocksize_q, key_chunk_size=blocksize_k, causal_block_size=1 if causal else None, ), in_specs=( self.create_stable_sharding(query_partition_spec, dep=q), self.create_stable_sharding(key_partition_spec, dep=k), self.create_stable_sharding(value_partition_spec, dep=v), self.create_stable_sharding(bias_partition_spec, [0], dep=b), self.create_stable_sharding(Ps(query_partition_spec[0], None), dep=segment_ids), ), out_specs=self.create_stable_sharding(attention_partition_spec), mesh=self.metadata.mesh, check_rep=False, )( q.astype(dtype), k.astype(dtype), v.astype(dtype), bias, segment_ids, ) return AttentionOutput( attention_weights=None, attention_outputs=with_sharding_constraint( arr=attn_output, sharding=attention_partition_spec, ), )
[docs] def forward_cpu(self, *args, **kwargs) -> AttentionOutput: """CPU forward pass. Delegates to `forward_native` (scan-based).""" return self.forward_native(*args, **kwargs)
[docs] def forward_cuda(self, *args, **kwargs) -> AttentionOutput: """CUDA GPU forward pass. Currently delegates to `forward_native` (scan-based).""" # TODO: Implement GPU-specific ring attention kernel if available return self.forward_native(*args, **kwargs)
[docs] def forward_rocm(self, *args, **kwargs) -> AttentionOutput: """ROCm GPU forward pass. Currently delegates to `forward_native` (scan-based).""" # TODO: Implement ROCm-specific ring attention kernel if available return self.forward_native(*args, **kwargs)
def __call__( self, q: Array, k: Array, v: Array, mask: tp.Optional[Array] = None, bias: tp.Optional[Array] = None, init_bias: tp.Optional[tp.Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: tp.Optional[jr.PRNGKey] = None, causal: bool = True, **ignore, ) -> AttentionOutput: """ Executes the Ring Attention computation. Currently bypasses the backend dispatch and directly calls `forward_native`. (See TODO in the original code). Args: q: Query tensor. k: Key tensor. v: Value tensor. mask: Optional attention mask. bias: Optional attention bias. init_bias: Optional callable to initialize bias. deterministic: If False, enables dropout (requires dropout_rng). dropout_rng: JAX PRNG key for dropout if deterministic is False. causal: Apply causal mask if True. **ignore: Additional ignored keyword arguments. Returns: An `AttentionOutput` object containing the results. """ # TODO: Debug Ring Attention then restore super().__call__ dispatch. # The original code temporarily forces native execution. return self.forward_native( q=q, k=k, v=v, mask=mask, bias=bias, init_bias=init_bias, deterministic=deterministic, dropout_rng=dropout_rng, causal=causal, )
if __name__ == "__main__": from easydel.infra import EasyDeLBaseConfig b, qs, ks, qh, kh, d, vd = 1, 1024, 1024, 32, 32, 128, 128 q = jr.normal(jr.key(0), (b, qs, qh, d), "f2") k = jr.normal(jr.key(1), (b, ks, kh, d), "f2") v = jr.normal(jr.key(2), (b, ks, kh, vd), "f2") cu_mask = VanillaAttn._create_causal_mask(qs)[None, None, :, :].repeat(b, 0) metadata = AttentionMetadata( runtime_dtype=jnp.bfloat16, base_config=EasyDeLBaseConfig(axis_dims=(1, 1, 1, -1)), blocksize_k=128, blocksize_q=128, backend="cpu", ) vanilla = VanillaAttn(metadata) attn = RingAttn(metadata) out = attn( q=q, k=k, v=v, mask=cu_mask, ).attention_outputs vout = vanilla( q=q, k=k, v=v, mask=cu_mask, ).attention_outputs print(out[-1, -1, -1, -5:], out[-1, 0, -1, -5:]) print(vout[-1, -1, -1, -5:], vout[-1, 0, -1, -5:]) print(jnp.allclose(out, vout, atol=0.125))