Source code for easydel.layers.ops._gla

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import typing as tp
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from ._base_operation import BaseOperation


def ceildiv(a: int, b: int) -> int:
    """Return ``a`` divided by ``b``, rounded up.

    Relies on the identity ``ceil(a / b) == -floor(a / -b)`` so that only
    floor division is needed.

    Args:
        a: Dividend.
        b: Divisor.

    Returns:
        The quotient rounded toward positive infinity.
    """
    floor_of_negated = a // -b
    return -floor_of_negated
@partial(
    jax.jit,
    static_argnames=("chunk_size", "output_final_state", "dtype", "scale"),
)
def recurrent_gla(
    query: jnp.ndarray,  # shape: (B, S, H, D)
    key: jnp.ndarray,  # shape: (B, S, H, D)
    value: jnp.ndarray,  # shape: (B, S, H, V)
    gk: jnp.ndarray,  # shape: (B, S, H, D)
    scale: float = -1.0,
    initial_state: tp.Optional[jnp.ndarray] = None,  # shape: (B, H, D, V)
    chunk_size: int = 0,  # if > 0, process sequence in chunks
    dtype: jnp.dtype = jnp.float32,
    output_final_state: bool = False,
) -> tp.Tuple[jnp.ndarray, tp.Optional[jnp.ndarray]]:
    """Recurrent Gated Linear Attention with optional chunked sequence processing.

    Runs the gated linear attention recurrence over the sequence axis with
    ``lax.scan``, carrying a hidden state of shape (B, H, D, V). The sequence
    is scanned either in a single pass or as consecutive chunks whose final
    state is handed to the next chunk; both modes compute the same recurrence.

    Args:
        query: Query tensor of shape (B, S, H, D) where:
            - B is the batch size
            - S is the sequence length
            - H is the number of heads
            - D is the head dimension for queries and keys
        key: Key tensor of shape (B, S, H, D).
        value: Value tensor of shape (B, S, H, V) where V is the head
            dimension for values (can be different from D).
        gk: Log-space gate tensor of shape (B, S, H, D); it is exponentiated
            before use (e.g. log-sigmoid values).
        scale: Scaling factor applied to the queries. If -1.0, uses
            ``D ** -0.5``. Default: -1.0
        initial_state: Optional initial hidden state of shape (B, H, D, V).
            If None, the state starts at zeros. Default: None
        chunk_size: Number of time steps per scan. If <= 0, the entire
            sequence is scanned at once. Default: 0
        dtype: Data type all inputs are cast to for computation.
            Default: jnp.float32
        output_final_state: Whether to return the final hidden state.
            Default: False

    Returns:
        A tuple (output, final_state) where:
            - output: Tensor of shape (B, S, H, V) containing the attention
              output.
            - final_state: The final hidden state of shape (B, H, D, V) if
              ``output_final_state`` is True, otherwise None.

    Example:
        >>> B, S, H, D, V = 1, 32, 4, 64, 32
        >>> kq, kk, kv, kg = jax.random.split(jax.random.PRNGKey(0), 4)
        >>> q = jax.random.normal(kq, (B, S, H, D))
        >>> k = jax.random.normal(kk, (B, S, H, D))
        >>> v = jax.random.normal(kv, (B, S, H, V))
        >>> gk = jax.nn.log_sigmoid(jax.random.normal(kg, (B, S, H, D)))
        >>> output, _ = recurrent_gla(q, k, v, gk, chunk_size=8)

    Notes:
        - The function is JIT-compiled with ``chunk_size``,
          ``output_final_state``, ``dtype`` and ``scale`` as static
          arguments, so each distinct combination triggers a new trace.
        - With ``chunk_size > 0`` the chunk loop is a Python loop that is
          unrolled at trace time into one ``lax.scan`` per chunk.
        - An empty sequence (S == 0) yields an output of shape (B, 0, H, V)
          and leaves the hidden state unchanged, with or without chunking.

    Implementation Details:
        The computation for each position i follows:
        1. q_i = q[i] * scale
        2. gk_i = exp(gk[i])
        3. kv_i = k[i] ⊗ v[i]  # outer product
        4. h = h * gk_i + kv_i  # recurrent update
        5. o_i = sum(q_i * h)  # contraction over D
    """
    B, S, H, D = query.shape
    V = value.shape[-1]

    query = query.astype(dtype)
    key = key.astype(dtype)
    value = value.astype(dtype)
    gk = gk.astype(dtype)

    # Sentinel -1.0 means "use the conventional 1/sqrt(D) scaling".
    if scale == -1.0:
        scale = D**-0.5

    # Adding onto zeros (rather than using initial_state directly) keeps the
    # state at the full (B, H, D, V) shape even if initial_state broadcasts.
    h = jnp.zeros((B, H, D, V), dtype=dtype)
    if initial_state is not None:
        h = h + initial_state.astype(dtype)

    def scan_step(h_prev, step_inputs):
        """Advance the recurrence by one time step."""
        q_t, k_t, v_t, gk_t = step_inputs
        q_t = q_t * scale  # (B, H, D)
        decay = jnp.exp(gk_t)  # (B, H, D)
        kv_t = k_t[..., None] * v_t[..., None, :]  # (B, H, D, V)
        h_new = h_prev * decay[..., None] + kv_t  # (B, H, D, V)
        o_t = jnp.sum(q_t[..., None] * h_new, axis=-2)  # (B, H, V)
        return h_new, o_t

    # lax.scan iterates over the leading axis, so go time-major once up front.
    xs = tuple(x.transpose(1, 0, 2, 3) for x in (query, key, value, gk))

    if chunk_size > 0 and S > 0:
        o_chunks = []
        for start in range(0, S, chunk_size):
            # Slicing clamps at S, so a ragged final chunk needs no special case.
            chunk = tuple(x[start : start + chunk_size] for x in xs)
            h, o_chunk = lax.scan(scan_step, h, chunk)
            o_chunks.append(o_chunk)
        o = jnp.concatenate(o_chunks, axis=0)
    else:
        # Also covers S == 0, where the chunk list would be empty and
        # jnp.concatenate would raise.
        h, o = lax.scan(scan_step, h, xs)

    o = o.transpose(1, 0, 2, 3)  # back to (B, S, H, V)

    return o, (h if output_final_state else None)
class RecurrentGLA(BaseOperation):
    """Operation wrapper around :func:`recurrent_gla`.

    NOTE(review): dispatch to ``forward_native`` is presumably handled by
    ``BaseOperation``, which is defined outside this file — confirm there.
    """

    def forward_native(
        self,
        query: jnp.ndarray,  # shape: (B, S, H, D)
        key: jnp.ndarray,  # shape: (B, S, H, D)
        value: jnp.ndarray,  # shape: (B, S, H, V)
        gk: jnp.ndarray,  # shape: (B, S, H, D)
        scale: float = -1.0,
        initial_state: tp.Optional[jnp.ndarray] = None,  # shape: (B, H, D, V)
        chunk_size: int = 0,  # if > 0, process sequence in chunks
        dtype: jnp.dtype = jnp.float32,
        output_final_state: bool = False,
    ) -> tp.Tuple[jnp.ndarray, tp.Optional[jnp.ndarray]]:
        """Forward every argument unchanged to :func:`recurrent_gla`.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            gk: Gating tensor.
            scale: Query scaling factor, passed through as-is.
            initial_state: Optional initial hidden state.
            chunk_size: Chunk length for sequence processing.
            dtype: Computation data type.
            output_final_state: Whether the final hidden state is returned.

        Returns:
            Whatever :func:`recurrent_gla` returns for these arguments: an
            ``(output, final_state)`` tuple.
        """
        tensors = dict(query=query, key=key, value=value, gk=gk)
        options = dict(
            scale=scale,
            initial_state=initial_state,
            chunk_size=chunk_size,
            dtype=dtype,
            output_final_state=output_final_state,
        )
        return recurrent_gla(**tensors, **options)
def test_recurrent_gla():
    """Check that chunked and unchunked recurrent GLA produce matching results.

    The unchunked run goes through the ``RecurrentGLA`` operation and the
    chunked run calls ``recurrent_gla`` directly; both start from the same
    random initial state and must agree on outputs and final state.
    """
    # Problem sizes
    B, S, H, D, V = 1, 32, 4, 64, 32
    qk_shape = (B, S, H, D)
    out_shape = (B, S, H, V)
    state_shape = (B, H, D, V)

    # One independent PRNG key per random tensor
    subkeys = jax.random.split(jax.random.PRNGKey(0), 5)
    query = jax.random.normal(subkeys[0], qk_shape)
    key = jax.random.normal(subkeys[1], qk_shape)
    value = jax.random.normal(subkeys[2], out_shape)
    gk = jax.nn.log_sigmoid(jax.random.normal(subkeys[3], qk_shape))
    h0 = jax.random.normal(subkeys[4], state_shape)

    # Whole sequence in one scan, via the operation class
    op = RecurrentGLA()
    full_out, full_state = op(
        query,
        key,
        value,
        gk,
        initial_state=h0,
        output_final_state=True,
    )

    # Same inputs, processed eight steps at a time
    chunked_out, chunked_state = recurrent_gla(
        query,
        key,
        value,
        gk,
        initial_state=h0,
        chunk_size=8,
        output_final_state=True,
    )

    # Shapes first, outputs before states
    for out in (full_out, chunked_out):
        assert out.shape == out_shape
    for state in (full_state, chunked_state):
        assert state.shape == state_shape

    # Both processing modes must agree numerically
    tolerance = dict(rtol=1e-5, atol=1e-5)
    assert jnp.allclose(full_out, chunked_out, **tolerance)
    assert jnp.allclose(full_state, chunked_state, **tolerance)

    print("All tests passed!")
# Run the self-test only when this module is executed directly as a script.
if __name__ == "__main__": test_recurrent_gla()