easydel.layers.ops.gla
- easydel.layers.ops.gla.recurrent_gla(query: jax.Array, key: jax.Array, value: jax.Array, gk: jax.Array, scale: float = -1.0, initial_state: Optional[jax.Array] = None, chunk_size: int = 0, dtype: numpy.dtype = jax.numpy.float32, output_final_state: bool = False) -> Tuple[jax.Array, Optional[jax.Array]]
Recurrent Gated Linear Attention with optional chunked sequence processing.
This function implements a recurrent variant of gated linear attention that processes the input sequence either in a single pass or in chunks. As it moves through the sequence, it keeps a hidden state of shape (B, H, D, V). At every position this state is decayed by the gate and then updated with the new key/value pair.
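Written as equations, the per-position steps listed under Implementation Details amount to the recurrence below, shown for one batch element and one head. The notation is ours: h_i is the hidden state after position i, s is scale, and h_0 is initial_state (zeros when None). The diag(exp(g_i)) factor expresses the gate scaling each of the D rows of the state.

```latex
\begin{aligned}
h_i &= \operatorname{diag}\!\big(\exp(g_i)\big)\, h_{i-1} + k_i v_i^{\top}, && h_i \in \mathbb{R}^{D \times V},\ q_i, k_i, g_i \in \mathbb{R}^{D},\ v_i \in \mathbb{R}^{V} \\
o_i &= (s\, q_i)^{\top} h_i, && o_i \in \mathbb{R}^{V}
\end{aligned}
```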
- Parameters
query – Query tensor of shape (B, S, H, D), where B is the batch size, S is the sequence length, H is the number of heads, and D is the head dimension shared by queries and keys
key – Key tensor of shape (B, S, H, D)
value – Value tensor of shape (B, S, H, V), where V is the value head dimension (may differ from D)
gk – Gate tensor of shape (B, S, H, D) in log space, typically log-sigmoid outputs. It is exponentiated inside the recurrence, so log-sigmoid inputs yield per-dimension decay factors in (0, 1)
scale – Scaling factor for attention scores. If -1.0, uses 1/sqrt(D). Default: -1.0
initial_state – Optional initial hidden state of shape (B, H, D, V). If None, initializes to zeros. Default: None
chunk_size – Size of chunks for processing long sequences. If 0, processes the entire sequence at once. Default: 0
dtype – Data type for computation. Default: jnp.float32
output_final_state – Whether to return the final hidden state. Default: False
- Returns
A tuple (output, final_state), where:
output: Tensor of shape (B, S, H, V) containing the attention output
final_state: The final hidden state of shape (B, H, D, V) if output_final_state is True; otherwise None
- Return type
Tuple[jax.Array, Optional[jax.Array]]
Example
>>> import jax
>>> from easydel.layers.ops.gla import recurrent_gla
>>> B, S, H, D, V = 1, 32, 4, 64, 32
>>> kq, kk, kv, kg = jax.random.split(jax.random.PRNGKey(0), 4)
>>> q = jax.random.normal(kq, (B, S, H, D))
>>> k = jax.random.normal(kk, (B, S, H, D))
>>> v = jax.random.normal(kv, (B, S, H, V))
>>> gk = jax.nn.log_sigmoid(jax.random.normal(kg, (B, S, H, D)))
>>> output, _ = recurrent_gla(q, k, v, gk, chunk_size=8)
Notes
The function is JIT-compiled with chunk_size, output_final_state, and dtype as static arguments, so each new combination of these values triggers a recompilation.
For long sequences, setting chunk_size > 0 lowers peak memory. The sequence is then processed in chunks of chunk_size positions rather than all S positions at once.
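The gate gk controls how much of the previous hidden state survives each step. exp(gk) multiplies the state before the new key/value outer product is added. Values of gk near 0 therefore retain past information, while large negative values forget it.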
The hidden state carries the running key/value summary through the whole sequence. When chunking is enabled, it is also passed across chunk boundaries.
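To see why carrying the state makes chunking transparent, the sketch below runs the same per-step recurrence either over the full sequence or chunk by chunk, threading h from one chunk into the next. It is a simplified, hypothetical illustration: it handles a single batch element and head, assumes the queries are already scaled, and uses helper names (_scan_steps, chunked_recurrence) invented here. It does not reflect how recurrent_gla organizes its chunked computation internally.

```python
import jax
import jax.numpy as jnp


def _scan_steps(h, q, k, v, gk):
    """Hypothetical helper: run the recurrence over one slice of the sequence.

    h: (D, V) state; q, k, gk: (L, D); v: (L, V). Queries are assumed pre-scaled.
    """
    def step(h, xs):
        q_t, k_t, v_t, g_t = xs
        h = h * jnp.exp(g_t)[:, None] + jnp.outer(k_t, v_t)  # decay rows, add k v^T
        o_t = jnp.sum(q_t[:, None] * h, axis=0)              # contract over D -> (V,)
        return h, o_t

    return jax.lax.scan(step, h, (q, k, v, gk))


def chunked_recurrence(q, k, v, gk, chunk_size):
    """Hypothetical helper: process chunks in order, carrying h across boundaries.

    chunk_size must be positive; a trailing partial chunk is handled by slicing.
    """
    S, D = q.shape
    h = jnp.zeros((D, v.shape[-1]), dtype=q.dtype)
    outputs = []
    for start in range(0, S, chunk_size):
        sl = slice(start, start + chunk_size)
        h, o = _scan_steps(h, q[sl], k[sl], v[sl], gk[sl])  # h flows into the next chunk
        outputs.append(o)
    return jnp.concatenate(outputs, axis=0), h
```

With identical inputs, chunked_recurrence(q, k, v, gk, chunk_size=4) and chunked_recurrence(q, k, v, gk, chunk_size=S) should agree in both outputs and final state up to floating-point rounding. This holds because each position sees exactly the state it would have seen without chunking.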
- Implementation Details:
The attention computation for each position i follows:
1. q_i = q[i] * scale
2. gk_i = exp(gk[i])
3. kv_i = k[i] ⊗ v[i] (outer product, shape (D, V))
4. h = h * gk_i + kv_i (recurrent update; gk_i scales each of the D rows of h)
5. o_i = sum(q_i * h) (sum over the D axis, giving a length-V output)
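The five steps can be checked against a plain JAX reference. The sketch below uses a hypothetical helper, gla_recurrence_reference, that runs the recurrence with jax.lax.scan in float32. It is not the easydel implementation: it never chunks and it ignores the dtype argument. It is only intended to follow the documented shapes and per-step semantics.

```python
import jax
import jax.numpy as jnp


def gla_recurrence_reference(q, k, v, gk, scale=-1.0, initial_state=None):
    """Hypothetical step-by-step reference for the recurrence above.

    Shapes: q, k, gk: (B, S, H, D); v: (B, S, H, V); state: (B, H, D, V).
    Returns (output of shape (B, S, H, V), final state of shape (B, H, D, V)).
    """
    B, S, H, D = q.shape
    V = v.shape[-1]
    if scale == -1.0:
        scale = D ** -0.5  # default: 1/sqrt(D)
    if initial_state is None:
        h0 = jnp.zeros((B, H, D, V), dtype=jnp.float32)
    else:
        h0 = initial_state.astype(jnp.float32)

    def step(h, xs):
        q_t, k_t, v_t, g_t = xs                        # (B, H, D) each, v_t: (B, H, V)
        q_t = q_t * scale                              # 1. scale the query
        decay = jnp.exp(g_t)                           # 2. gate in (0, 1] for g_t <= 0
        kv = k_t[..., :, None] * v_t[..., None, :]     # 3. outer product -> (B, H, D, V)
        h = h * decay[..., :, None] + kv               # 4. decay the D rows, add kv
        o_t = jnp.sum(q_t[..., :, None] * h, axis=-2)  # 5. contract over D -> (B, H, V)
        return h, o_t

    # lax.scan iterates over the leading axis, so move S to the front.
    xs = tuple(jnp.moveaxis(x.astype(jnp.float32), 1, 0) for x in (q, k, v, gk))
    h_final, o = jax.lax.scan(step, h0, xs)
    return jnp.moveaxis(o, 0, 1), h_final
```

Feeding the returned state back as initial_state for the next segment of a sequence continues the same recurrence. That appears to be what pairing output_final_state with initial_state is for.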
- Memory Complexity:
Without chunking: O(B * S * H * max(D, V))
With chunking: O(B * chunk_size * H * max(D, V))
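Plugging concrete shapes into the two bounds shows the effect of chunking. The helper below, gla_memory_elements, is hypothetical. It only evaluates the stated formulas (element counts, constant factors ignored) and does not measure actual memory use.

```python
def gla_memory_elements(B, S, H, D, V, chunk_size=0):
    """Hypothetical helper: element count implied by the bounds above."""
    # chunk_size == 0 means the whole sequence is processed at once.
    length = S if chunk_size <= 0 else min(chunk_size, S)
    return B * length * H * max(D, V)


# Shapes from the example above.
print(gla_memory_elements(1, 32, 4, 64, 32))                   # 8192
print(gla_memory_elements(1, 32, 4, 64, 32, chunk_size=8))     # 2048
# A hypothetical longer sequence: the bound shrinks by a factor of S / chunk_size.
print(gla_memory_elements(1, 8192, 4, 64, 32))                 # 2097152
print(gla_memory_elements(1, 8192, 4, 64, 32, chunk_size=64))  # 16384
```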