easydel.kernels.cpu_ops.__init__
- easydel.kernels.cpu_ops.__init__.jax_flash_attention(query_state: Array, key_state: Array, value_state: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, *, dropout: float = 0.0, inference: bool = True, key: Optional[PRNGKey] = None, blocksize_q: Optional[int] = None, blocksize_k: Optional[int] = None, dtype: Optional[dtype] = None, precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset] = None, head_dim: Optional[int] = None, softmax_scale: Optional[float] = None) → Array
Computes multi-head attention using a FlashAttention implementation.
This implementation makes use of the FlashAttention algorithm for faster and more memory-efficient computation of attention. It is particularly beneficial for long sequences.
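The memory saving comes from processing keys and values in blocks while keeping a running softmax normalizer, so the full (q_len, kv_len) score matrix is never materialized. The following is a minimal sketch of that idea in plain JAX, not this function's actual implementation: `naive_attention` and `blockwise_attention` are hypothetical names, masking, bias, and dropout are left out, and the 1/sqrt(head_dim) scaling is assumed.

```python
import jax
import jax.numpy as jnp


def naive_attention(q, k, v):
    # q: (batch, q_len, heads, dim); k, v: (batch, kv_len, heads, dim)
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    scores = jnp.einsum("bqhd,bkhd->bhqk", q * scale, k)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


def blockwise_attention(q, k, v, blocksize_k=4):
    # Hypothetical reference: iterate over key/value blocks and maintain a
    # running row maximum and running normalizer (online softmax).
    batch, q_len, heads, dim = q.shape
    kv_len = k.shape[1]
    q = q * (1.0 / jnp.sqrt(dim))
    acc = jnp.zeros((batch, heads, q_len, dim), dtype=q.dtype)
    row_max = jnp.full((batch, heads, q_len), -jnp.inf, dtype=q.dtype)
    row_sum = jnp.zeros((batch, heads, q_len), dtype=q.dtype)
    for start in range(0, kv_len, blocksize_k):
        k_blk = k[:, start:start + blocksize_k]
        v_blk = v[:, start:start + blocksize_k]
        scores = jnp.einsum("bqhd,bkhd->bhqk", q, k_blk)
        new_max = jnp.maximum(row_max, scores.max(axis=-1))
        # Rescale what was accumulated under the previous maximum.
        correction = jnp.exp(row_max - new_max)
        probs = jnp.exp(scores - new_max[..., None])
        row_sum = row_sum * correction + probs.sum(axis=-1)
        acc = acc * correction[..., None] + jnp.einsum(
            "bhqk,bkhd->bhqd", probs, v_blk
        )
        row_max = new_max
    out = acc / row_sum[..., None]
    # Back to (batch, q_len, heads, dim).
    return jnp.transpose(out, (0, 2, 1, 3))
```

Both functions should agree up to floating-point error; the blockwise version only ever holds a (q_len, blocksize_k) slice of scores per head at a time.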
- Parameters
query_state – Query, shape (batch_size, q_len, num_heads, head_dim).
key_state – Key, shape (batch_size, kv_len, num_heads, head_dim).
value_state – Value, shape (batch_size, kv_len, num_heads, head_dim).
mask –
Optional attention mask. This can be any of the following:
- No mask (default): all attention weights are computed.
- Boolean mask (2D): shape (batch_size, q_len), with True for valid and False for masked positions.
- Integer mask (2D): shape (batch_size, q_len), where the value at each position indicates the length of the sequence to attend to.
- 4D mask: shape (batch_size, q_len, kv_len), with True for valid and False for masked positions.
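To make the integer form concrete, here is a small sketch of how per-position lengths could be expanded into a boolean mask of shape (batch_size, q_len, kv_len). The helper `lengths_to_boolean_mask` is hypothetical, and the reading that a length of n means "attend to the first n key positions" is an assumption, not something the function is confirmed to do internally.

```python
import jax.numpy as jnp


def lengths_to_boolean_mask(lengths, kv_len):
    # lengths: integer array of shape (batch_size, q_len).
    # Assumption: a value n means key positions 0..n-1 are valid.
    key_positions = jnp.arange(kv_len)
    return key_positions[None, None, :] < lengths[:, :, None]


# With lengths 1, 2, ..., q_len this yields a causal (lower-triangular) mask.
q_len = kv_len = 4
causal_lengths = jnp.arange(1, q_len + 1)[None, :]
causal_mask = lengths_to_boolean_mask(causal_lengths, kv_len)
```

Here `causal_mask` has shape (1, 4, 4), with True on and below the diagonal.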
bias – Optional attention bias.
dropout – Dropout rate.
inference – Whether to run in inference mode.
key – PRNG key for dropout.
blocksize_q – Block size for query processing.
blocksize_k – Block size for key/value processing.
dtype – Optional dtype for the output.
precision – Optional precision for matrix multiplication.
head_dim – Optional head dimension used to scale the query: query_state = query_state / math.sqrt(float(head_dim or query_state.shape[-1])).
softmax_scale – Optional softmax scale applied to the query: query_state = query_state * softmax_scale.
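The two scaling options can be compared directly. This sketch assumes softmax_scale multiplies the query as is; how the function behaves when both head_dim and softmax_scale are passed is not stated here, so the example treats them separately.

```python
import math

import jax.numpy as jnp

query_state = jnp.ones((1, 4, 2, 64))  # (batch_size, q_len, num_heads, head_dim)

# Default path: head_dim is None, so the last axis of query_state is used.
head_dim = None
scaled_by_head_dim = query_state / math.sqrt(
    float(head_dim or query_state.shape[-1])
)

# Explicit scale chosen to match the default 1 / sqrt(64).
softmax_scale = 1.0 / math.sqrt(64)
scaled_by_softmax_scale = query_state * softmax_scale
```

With these values both results are identical, since 1 / sqrt(64) = 0.125.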
- Returns
Output of multi-head attention, with shape (batch_size, q_len, num_heads, head_dim).
- Raises
ValueError – If dropout is not in the range [0, 1], or if key is not provided during training when dropout > 0.
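The documented error conditions amount to a check along these lines. `check_dropout_args` is a hypothetical helper written for illustration, and reading "during training" as `inference=False` is an assumption.

```python
def check_dropout_args(dropout, inference, key):
    # The dropout rate must be a valid probability.
    if not 0.0 <= dropout <= 1.0:
        raise ValueError(f"dropout must be in [0, 1], got {dropout}")
    # Assumption: "training" means inference=False; dropout then needs a PRNG key.
    if dropout > 0.0 and not inference and key is None:
        raise ValueError("a PRNG key is required when dropout > 0 during training")
```

Under this reading, passing `dropout=0.1` with `inference=True` and no key is accepted, because no dropout is applied in inference mode.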