easydel.kernels.cpu_ops.__init__


easydel.kernels.cpu_ops.__init__.jax_flash_attention(query_state: Array, key_state: Array, value_state: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, *, dropout: float = 0.0, inference: bool = True, key: Optional[PRNGKey] = None, blocksize_q: Optional[int] = None, blocksize_k: Optional[int] = None, dtype: Optional[dtype] = None, precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset] = None, head_dim: Optional[int] = None, softmax_scale: Optional[float] = None) → Array

Computes multi-head attention using a FlashAttention implementation.

This implementation uses the FlashAttention algorithm, which processes the inputs in blocks (see blocksize_q and blocksize_k) so that attention is computed faster and with less memory. It is particularly beneficial for long sequences.
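To make the block-wise idea concrete, here is a minimal sketch in JAX. It is not EasyDeL's kernel: it assumes only the key/value axis is split (block_k stands in for blocksize_k), uses a plain Python loop, and ignores mask, bias, and dropout. The names naive_attention and blockwise_attention are hypothetical and introduced only for illustration. The memory saving comes from never forming the full (q_len, kv_len) score matrix at once; a running maximum and a running normaliser keep the softmax exact across blocks.

```python
import jax
import jax.numpy as jnp


def naive_attention(q, k, v):
    # q: (batch, q_len, heads, dim); k, v: (batch, kv_len, heads, dim)
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    scores = jnp.einsum("bqhd,bkhd->bhqk", q * scale, k)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


def blockwise_attention(q, k, v, block_k=4):
    # Hypothetical sketch: walk over key/value blocks and keep, per query row,
    # a running max (m), a running softmax normaliser (l) and a running
    # un-normalised output (acc).
    b, q_len, h, d = q.shape
    kv_len = k.shape[1]
    q = q * (1.0 / jnp.sqrt(d))
    m = jnp.full((b, h, q_len), -jnp.inf)
    l = jnp.zeros((b, h, q_len))
    acc = jnp.zeros((b, h, q_len, d))
    for start in range(0, kv_len, block_k):
        k_blk = k[:, start:start + block_k]
        v_blk = v[:, start:start + block_k]
        s = jnp.einsum("bqhd,bkhd->bhqk", q, k_blk)
        m_new = jnp.maximum(m, s.max(axis=-1))
        p = jnp.exp(s - m_new[..., None])
        # Rescale what was accumulated under the old running max.
        correction = jnp.exp(m - m_new)
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[..., None] + jnp.einsum("bhqk,bkhd->bhqd", p, v_blk)
        m = m_new
    out = acc / l[..., None]
    # Back to the (batch, q_len, heads, dim) layout used by the inputs.
    return jnp.transpose(out, (0, 2, 1, 3))
```

Under these assumptions the two functions should agree up to floating-point tolerance, while the block-wise version only ever holds a (q_len, block_k) slice of the scores per head.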

Parameters
  • query_state – Query, shape (batch_size, q_len, num_heads, head_dim).

  • key_state – Key, shape (batch_size, kv_len, num_heads, head_dim).

  • value_state – Value, shape (batch_size, kv_len, num_heads, head_dim).

  • mask –

    Optional attention mask. This can be any of the following:

    • No mask (default): All attention weights are computed.

    • Boolean mask (2D): shape (batch_size, q_len), with True for valid and False for masked positions.

    • Integer mask (2D): shape (batch_size, q_len), where the value at each position indicates the length of the sequence to attend to.

    • 4D mask: shape (batch_size, q_len, kv_len), with True for valid and False for masked positions.

  • bias – Optional attention bias.

  • dropout – Dropout rate, in the range [0, 1].

  • inference – Whether to run in inference mode.

  • key – PRNG key for dropout. Required during training when dropout > 0.

  • blocksize_q – Block size for query processing.

  • blocksize_k – Block size for key/value processing.

  • dtype – Optional dtype for the output.

  • precision – Optional precision for matrix multiplication.

  • head_dim – Optional head dimension used for query scaling: query_state = query_state / math.sqrt(float(head_dim or query_state.shape[-1])).

  • softmax_scale – Optional softmax scale, used for query_state = query_state * softmax_scale.
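The mask forms listed above can be pictured as different ways of describing one boolean grid over (query, key) pairs. The sketch below is a hypothetical helper, expand_mask, that is not part of EasyDeL. It assumes that a boolean 2D mask marks valid query positions and is broadcast over all keys, and that an integer 2D mask gives, for each query position, the number of leading key positions that may be attended to. The library may interpret these forms differently.

```python
import jax.numpy as jnp


def expand_mask(mask, kv_len):
    """Hypothetical helper: return a boolean (batch, q_len, kv_len) mask."""
    if mask.ndim == 3:
        # Already one entry per (query, key) pair.
        return mask.astype(bool)
    if jnp.issubdtype(mask.dtype, jnp.bool_):
        # Assumption: True marks a valid query row, broadcast over every key.
        return jnp.broadcast_to(mask[:, :, None], (*mask.shape, kv_len))
    # Assumption: each integer is the number of leading keys that are visible.
    positions = jnp.arange(kv_len)[None, None, :]
    return positions < mask[:, :, None]


lengths = jnp.array([[1, 2, 3]])          # (batch_size=1, q_len=3)
print(expand_mask(lengths, kv_len=3))     # lower-triangular pattern of True values
```

With lengths 1, 2, 3 the integer form reproduces a causal pattern, which is one common reason to prefer it over materializing a full boolean mask.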

Returns

Output of multi-head attention, with shape (batch_size, q_len, num_heads, head_dim).
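Query scaling does not change the output shape, only the magnitude of the attention logits. The sketch below restates the head_dim expression from the parameter list as a hypothetical helper, scale_query. The precedence it uses (an explicit softmax_scale overrides head_dim) is an assumption, not something the parameter descriptions confirm.

```python
import math

import jax.numpy as jnp


def scale_query(query_state, head_dim=None, softmax_scale=None):
    # Assumed precedence: an explicit softmax_scale wins over head_dim.
    if softmax_scale is not None:
        return query_state * softmax_scale
    # Default taken from the head_dim description: divide by sqrt(head_dim),
    # falling back to the last axis of the query when head_dim is None.
    return query_state / math.sqrt(float(head_dim or query_state.shape[-1]))


q = jnp.ones((1, 2, 3, 16))               # (batch_size, q_len, num_heads, head_dim)
print(scale_query(q).shape)               # shape is unchanged by scaling
```

Passing head_dim is mainly useful when the last axis of query_state does not match the dimension the scaling should be based on, for example after padding.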

Raises

ValueError – If dropout is not in the range [0, 1], or if key is not provided during training when dropout > 0.
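The two error conditions can be written down as a small check. This is a sketch of the documented behaviour using a hypothetical function, check_dropout_args; it assumes "during training" means inference=False, and the error messages are invented for illustration.

```python
def check_dropout_args(dropout, inference, key):
    # Hypothetical restatement of the documented ValueError conditions.
    if not 0.0 <= dropout <= 1.0:
        raise ValueError(f"dropout must be in [0, 1], got {dropout}")
    # Assumption: "training" corresponds to inference=False.
    if not inference and dropout > 0.0 and key is None:
        raise ValueError("a PRNG key is required when dropout > 0 during training")


check_dropout_args(dropout=0.0, inference=False, key=None)   # fine: no dropout
check_dropout_args(dropout=0.1, inference=True, key=None)    # fine: inference mode
```

In practice this means a training call with dropout > 0 needs a PRNG key, for example one produced by jax.random.PRNGKey.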