easydel.layers.ops.gla

easydel.layers.ops.gla.ceildiv(a: int, b: int) → int
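This page gives no description for ceildiv. Judging only by its name and signature, it presumably returns the ceiling of a / b for integers, which is the usual way to count how many chunks cover a sequence. The sketch below is an assumption about that behavior, not the library's code, and ceildiv_sketch is a hypothetical name.

```python
def ceildiv_sketch(a: int, b: int) -> int:
    """Ceiling division for integers (hypothetical stand-in for ceildiv)."""
    # Negating before and after floor division rounds toward +infinity.
    return -(a // -b)


# Number of chunks of size 8 needed to cover sequences of length 32 and 33.
num_chunks_even = ceildiv_sketch(32, 8)
num_chunks_ragged = ceildiv_sketch(33, 8)
```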
easydel.layers.ops.gla.recurrent_gla(query: jax.Array, key: jax.Array, value: jax.Array, gk: jax.Array, scale: float = -1.0, initial_state: Optional[jax.Array] = None, chunk_size: int = 0, dtype: numpy.dtype = jnp.float32, output_final_state: bool = False) → Tuple[Array, Optional[Array]]

Recurrent Gated Linear Attention with optional chunked sequence processing.

This function implements a recurrent variant of gated linear attention that processes the input sequence either in a single pass or in chunks. At every position it updates a hidden state, which carries information forward through the rest of the sequence.

Parameters
  • query – Query tensor of shape (B, S, H, D), where B is the batch size, S is the sequence length, H is the number of heads, and D is the head dimension for queries and keys

  • key – Key tensor of shape (B, S, H, D)

  • value – Value tensor of shape (B, S, H, V), where V is the head dimension for values (may differ from D)

  • gk – Gating tensor of shape (B, S, H, D), typically log-sigmoid values

  • scale – Scaling factor for attention scores. If -1.0, uses 1/sqrt(D). Default: -1.0

  • initial_state – Optional initial hidden state of shape (B, H, D, V). If None, initializes to zeros. Default: None

  • chunk_size – Size of chunks for processing long sequences. If 0, processes the entire sequence at once. Default: 0

  • dtype – Data type for computation. Default: jnp.float32

  • output_final_state – Whether to return the final hidden state. Default: False
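As a small illustration of two of these parameters, the sketch below resolves the documented scale sentinel (-1.0 meaning 1/sqrt(D)) and shows why log-sigmoid values are a natural choice for gk: exponentiating them yields per-channel decay factors strictly between 0 and 1. NumPy is used only to keep the example dependency-light; the variable names are illustrative and not part of the library.

```python
import numpy as np

D = 64
scale = -1.0
# Resolve the documented sentinel: -1.0 means "use 1/sqrt(D)".
effective_scale = D ** -0.5 if scale == -1.0 else scale

rng = np.random.default_rng(0)
raw_gate = rng.standard_normal((1, 32, 4, D))  # (B, S, H, D)

# Numerically stable log-sigmoid: log(sigmoid(x)) = -log(1 + exp(-x)).
gk = -np.logaddexp(0.0, -raw_gate)

# exp(gk) equals sigmoid(raw_gate), so every decay factor lies in (0, 1).
decay = np.exp(gk)
```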

Returns

A tuple (output, final_state) where:

  • output: Tensor of shape (B, S, H, V) containing the attention output

  • final_state: The final hidden state of shape (B, H, D, V) if output_final_state is True; otherwise None.

Return type

Tuple[Array, Optional[Array]]

Example

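>>> import jax
>>> from easydel.layers.ops.gla import recurrent_gla
>>> B, S, H, D, V = 1, 32, 4, 64, 32
>>> # Use a separate PRNG key per tensor so q and k are not identical.
>>> rng_q, rng_k, rng_v, rng_g = jax.random.split(jax.random.PRNGKey(0), 4)
>>> q = jax.random.normal(rng_q, (B, S, H, D))
>>> k = jax.random.normal(rng_k, (B, S, H, D))
>>> v = jax.random.normal(rng_v, (B, S, H, V))
>>> gk = jax.nn.log_sigmoid(jax.random.normal(rng_g, (B, S, H, D)))
>>> output, _ = recurrent_gla(q, k, v, gk, chunk_size=8)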

Notes

  • The function is JIT-compiled with static arguments for chunk_size, output_final_state, and dtype.

  • For long sequences, using chunk_size > 0 can help manage memory usage by processing the sequence in smaller chunks.

  • The gating mechanism (gk) helps control the flow of information through the recurrent updates.

  • The hidden state maintains the running computation across the sequence or chunk boundaries.
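The notes on chunking and on the hidden state can be made concrete with a NumPy sketch. It assumes that chunking simply runs the same recurrence on consecutive slices and hands the state from one chunk to the next; whether the library's chunked path is organized exactly this way is an assumption. The helper names scan_steps and chunked are hypothetical, and query scaling is left out for brevity.

```python
import numpy as np


def scan_steps(q, k, v, gk, h):
    """Run the gated recurrence over one block of positions, starting from state h."""
    out = np.empty(q.shape[:3] + (v.shape[-1],))
    for i in range(q.shape[1]):
        decay = np.exp(gk[:, i])[..., None]                    # (B, H, D, 1)
        kv = k[:, i][..., :, None] * v[:, i][..., None, :]     # (B, H, D, V)
        h = h * decay + kv
        out[:, i] = np.einsum("bhd,bhdv->bhv", q[:, i], h)
    return out, h


def chunked(q, k, v, gk, chunk_size):
    """Process the sequence chunk by chunk, passing the state across boundaries."""
    B, S, H, D = q.shape
    h = np.zeros((B, H, D, v.shape[-1]))
    outs = []
    for start in range(0, S, chunk_size):
        sl = slice(start, start + chunk_size)
        o, h = scan_steps(q[:, sl], k[:, sl], v[:, sl], gk[:, sl], h)
        outs.append(o)
    return np.concatenate(outs, axis=1), h


rng = np.random.default_rng(0)
B, S, H, D, V = 1, 32, 4, 8, 6
q, k = rng.standard_normal((2, B, S, H, D))
v = rng.standard_normal((B, S, H, V))
gk = -np.logaddexp(0.0, -rng.standard_normal((B, S, H, D)))  # log-sigmoid gates

full_out, full_state = scan_steps(q, k, v, gk, np.zeros((B, H, D, V)))
chunk_out, chunk_state = chunked(q, k, v, gk, chunk_size=8)
```

Under these assumptions the chunked and single-pass results agree up to floating-point error, because the state handed across each boundary is exactly the state the single pass would have at that position.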

Implementation Details:

The attention computation for each position i follows:

  1. q_i = q[i] * scale
  2. gk_i = exp(gk[i])
  3. kv_i = k[i] ⊗ v[i]  # outer product, shape (D, V) per head
  4. h = h * gk_i + kv_i  # recurrent update, gk_i broadcast over V
  5. o_i = sum(q_i * h)  # output computation, summed over D
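The per-position computation above can be written out as a plain loop. The following NumPy reference is a sketch for checking shapes and semantics, not the library's JIT-compiled implementation. The name recurrent_gla_reference is hypothetical, and the scale sentinel follows the parameter description (-1.0 means 1/sqrt(D)).

```python
import numpy as np


def recurrent_gla_reference(q, k, v, gk, scale=-1.0, initial_state=None):
    """Step-by-step reference for the recurrence (illustrative only)."""
    B, S, H, D = q.shape
    V = v.shape[-1]
    scale = D ** -0.5 if scale == -1.0 else scale
    h = np.zeros((B, H, D, V)) if initial_state is None else initial_state.copy()
    out = np.empty((B, S, H, V))
    for i in range(S):
        q_i = q[:, i] * scale                                  # (B, H, D)
        gk_i = np.exp(gk[:, i])[..., None]                     # (B, H, D, 1)
        kv_i = k[:, i][..., :, None] * v[:, i][..., None, :]   # outer product, (B, H, D, V)
        h = h * gk_i + kv_i                                    # recurrent update
        out[:, i] = (q_i[..., None] * h).sum(axis=-2)          # sum over D, giving (B, H, V)
    return out, h


rng = np.random.default_rng(0)
B, S, H, D, V = 1, 32, 4, 64, 32
q = rng.standard_normal((B, S, H, D))
k = rng.standard_normal((B, S, H, D))
v = rng.standard_normal((B, S, H, V))
gk = -np.logaddexp(0.0, -rng.standard_normal((B, S, H, D)))  # log-sigmoid gates
out, final_state = recurrent_gla_reference(q, k, v, gk)
```

With a zero initial state, the first output reduces to (q_0 · k_0) * scale * v_0, which is a convenient sanity check for any implementation of the recurrence.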

Memory Complexity:
  • Without chunking: O(B * S * H * max(D, V))

  • With chunking: O(B * chunk_size * H * max(D, V))
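To get a feel for the two bounds, the sketch below evaluates them for the shapes used in the example. It counts only the leading-order term from the formulas above and ignores constants, the hidden state itself, and anything the compiler may allocate. The helper name activation_elements is hypothetical.

```python
def activation_elements(B, S, H, D, V, chunk_size=0):
    """Leading-order element count from the complexity formulas (constants ignored)."""
    # chunk_size == 0 means the whole sequence is processed at once.
    positions = S if chunk_size == 0 else min(chunk_size, S)
    return B * positions * H * max(D, V)


full = activation_elements(1, 32, 4, 64, 32)                    # 8192
chunked = activation_elements(1, 32, 4, 64, 32, chunk_size=8)   # 2048
```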

easydel.layers.ops.gla.test_recurrent_gla()