easydel.layers.attention_operator.modules.scaled_dot_product

class easydel.layers.attention_operator.modules.scaled_dot_product.ScaledDotProductAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

An attention implementation that leverages jax.nn.dot_product_attention.

This class uses JAX’s optimized scaled dot-product attention (SDPA) primitive. The primitive’s implementation argument selects the backend: “xla” is portable across CPU, TPU, and GPU, and “cudnn” uses fused, Flash-Attention-style cuDNN kernels on supported NVIDIA GPUs, depending on the JAX version and hardware.

It wraps the computation in shard_map to handle sharding and dispatches by backend. CUDA GPUs use the “cudnn” implementation, while CPU, TPU, and (currently) ROCm backends use the “xla” implementation.

Registered under the names “sdpa”, “cudnn”, and “cuda_flash_attn2”.

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass. Delegates to forward_native (XLA implementation).

forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput

Computes attention using jax.nn.dot_product_attention with the “cudnn” implementation.

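This path is optimized for NVIDIA GPUs through cuDNN’s fused attention kernels. It applies sharding via shard_map. Note: the cuDNN implementation has stricter input requirements than the XLA path; for example, it generally requires float16 or bfloat16 inputs. Causal masking is disabled in generation mode (q_len = 1), because the single new query token may attend to every cached key.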
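The generation-mode rule can be illustrated with two small hypothetical helpers, neither of which is part of EasyDeL. The second helper builds a causal mask aligned to the end of the key sequence. Query i sits at absolute position i + (S - T), so with a single query (T = 1) every cached key is visible and the causal mask is all True.

```python
import jax.numpy as jnp


def effective_causal(causal: bool, q_len: int) -> bool:
    # Hypothetical helper mirroring the documented rule: no causal masking
    # when a single new query token is processed (generation mode).
    return causal and q_len != 1


def bottom_right_causal_mask(q_len: int, kv_len: int) -> jnp.ndarray:
    # Query i sits at absolute position i + (kv_len - q_len) and may attend
    # to every key at or before that position.
    offset = kv_len - q_len
    rows = jnp.arange(q_len)[:, None]
    cols = jnp.arange(kv_len)[None, :]
    return cols <= rows + offset


print(effective_causal(True, q_len=1))  # False
print(bool(bottom_right_causal_mask(1, 6).all()))  # True
```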

Parameters
  • q – Query tensor (B, T, H, D).

  • k – Key tensor (B, S, H_kv, D).

  • v – Value tensor (B, S, H_kv, D_v).

  • mask – Optional boolean attention mask (broadcastable to B, 1, T, S).

  • bias – Optional attention bias tensor (broadcastable to B, H, T, S).

  • init_bias – Optional zero-argument callable that creates a bias when both mask and bias are None.

  • causal – If True, applies causal masking within the primitive, unless in generation mode.

  • **ignore – Ignored keyword arguments.

Returns

An AttentionOutput object. Weights are not returned.

forward_gpu(*args, **kwargs) → AttentionOutput

GPU forward pass. Delegates to the CUDA-specific implementation.

forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput

Computes attention using jax.nn.dot_product_attention with the “xla” implementation.

This path is used by the CPU, TPU, and (currently) ROCm backends. It applies sharding via shard_map.

Parameters
  • q – Query tensor (B, T, H, D).

  • k – Key tensor (B, S, H_kv, D).

  • v – Value tensor (B, S, H_kv, D_v).

  • mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Passed directly to the primitive.

  • bias – Optional attention bias tensor (broadcastable to B, H, T, S). Passed directly to the primitive. If bias is provided, causal is forced to False.

  • init_bias – Optional zero-argument callable that creates a bias when both mask and bias are None.

  • causal – If True and bias is None, applies causal masking within the primitive.

  • **ignore – Ignored keyword arguments.

Returns

An AttentionOutput object. Attention weights are not returned, because jax.nn.dot_product_attention returns only the attention output.

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass. Currently delegates to forward_native.

forward_tpu(*args, **kwargs) → AttentionOutput

TPU forward pass. Delegates to forward_native (XLA implementation).

get_impl_metadata() → AttentionMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The AttentionMetadata provided during initialization.

classmethod get_impl_name() → Union[str, Tuple[str]]

Returns the registered name(s) for this implementation.

Returns

(“sdpa”, “cudnn”, “cuda_flash_attn2”).

Return type

A tuple of strings