easydel.layers.attention_operator.__init__
- class easydel.layers.attention_operator.__init__.AttentionImpl(metadata: AttentionMetadata)
Bases: BaseOperation
Abstract base class for specific attention implementations.
Inherits from BaseOperation to leverage backend-specific dispatching. Subclasses must implement the core attention logic (forward_native) and may provide optimized versions for TPU (forward_tpu), GPU (forward_gpu), and other backends. They must also declare their name (get_impl_name) and associated metadata (get_impl_metadata).
Provides common helper methods for attention processing like mask manipulation, head repeating (for GQA/MQA), and determining runtime mode.
- create_stable_sharding(state_ps: Optional[PartitionSpec] = None, preserved_indices: List[int] = None, clone_ps: Optional[PartitionSpec] = None, dep: Optional[Union[PartitionSpec, bool]] = True, tensor: Optional[Array] = None) → Optional[PartitionSpec]
Helper to create a PartitionSpec, potentially preserving only certain axes.
This might be used for ensuring intermediate tensors or states have compatible sharding, possibly replicating across axes not specified in preserved_indices.
- Parameters
state_ps – The base PartitionSpec to modify.
preserved_indices – A list of dimension indices whose partitioning should be kept from state_ps (or clone_ps if provided). Other dimensions will be set to None (replicated). If None, state_ps is returned.
clone_ps – An optional PartitionSpec to copy axis names from for the preserved indices, instead of using state_ps.
dep – A dependency flag or PartitionSpec. If None, returns None. Defaults to True. (The exact purpose might be context-specific, potentially for control flow).
- Returns
A new PartitionSpec with only specified axes partitioned, or None based on dep. Returns state_ps directly if preserved_indices is None.
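A minimal sketch of the preserved-indices rule described above, assuming a 4-D PartitionSpec and ignoring the tensor argument. The function name stable_sharding_sketch and the mesh axis names are hypothetical; this is an illustration, not EasyDeL's implementation.

```python
from typing import Optional, Sequence, Union

from jax.sharding import PartitionSpec


def stable_sharding_sketch(
    state_ps: Optional[PartitionSpec],
    preserved_indices: Optional[Sequence[int]] = None,
    clone_ps: Optional[PartitionSpec] = None,
    dep: Optional[Union[PartitionSpec, bool]] = True,
) -> Optional[PartitionSpec]:
    """Hypothetical re-statement of the documented behaviour (not EasyDeL's code)."""
    if dep is None:
        # Documented: a None dependency yields None.
        return None
    if state_ps is None or preserved_indices is None:
        # Documented for preserved_indices=None; the state_ps=None guard is an assumption.
        return state_ps
    # Axis names for kept dimensions come from clone_ps when given, else from state_ps.
    source = clone_ps if clone_ps is not None else state_ps
    axes = [source[i] if i in preserved_indices else None for i in range(len(state_ps))]
    return PartitionSpec(*axes)


spec = PartitionSpec("dp", "fsdp", "tp", "sp")  # illustrative axis names
print(stable_sharding_sketch(spec, preserved_indices=[0, 2]))
```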
- current_backend() → Literal['tpu', 'gpu', 'cpu']
Returns the current JAX default backend as a lowercase string literal.
- Returns
“tpu”, “gpu”, or “cpu”.
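How BaseOperation dispatches is not spelled out here. The sketch below assumes a simple forward_<backend> lookup keyed on jax.default_backend(), with forward_native as the fallback. ToyOperation and current_backend_sketch are hypothetical names used only for illustration.

```python
from typing import Literal

import jax
import jax.numpy as jnp


def current_backend_sketch() -> Literal["tpu", "gpu", "cpu"]:
    # Mirrors the documented behaviour: JAX's default backend as a lowercase string.
    return jax.default_backend().lower()


class ToyOperation:
    """Hypothetical stand-in for BaseOperation-style dispatch (not EasyDeL's code)."""

    def forward_native(self, x):
        return x * 2

    def forward_tpu(self, x):
        return self.forward_native(x)

    def forward_gpu(self, x):
        return self.forward_native(x)

    def forward_cpu(self, x):
        return self.forward_native(x)

    def __call__(self, x):
        # Pick forward_<backend> for the active backend, falling back to forward_native.
        method = getattr(self, f"forward_{current_backend_sketch()}", self.forward_native)
        return method(x)


print(current_backend_sketch(), ToyOperation()(jnp.arange(3)))
```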
- abstract get_impl_metadata() → AttentionMetadata
Returns the AttentionMetadata associated with this implementation instance.
- Returns
The AttentionMetadata instance passed during initialization.
- abstract classmethod get_impl_name() → Union[str, Tuple[str, ...]]
Returns the unique name(s) identifying this attention implementation.
Used by the AttentionRegistry. Can return a single string or a tuple/list of strings if the implementation has multiple aliases.
- Returns
A string or tuple/list of strings representing the implementation name(s).
- get_mode(q: Array, BTHD: bool = True) → Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']
Determines the runtime mode from the query shape.
If the query sequence-length dimension is 1, the call is treated as autoregressive generation ('__autoregressive__').
- Parameters
q – The query tensor.
BTHD – Boolean indicating tensor layout (True for B, T, H, D; False for B, H, T, D).
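Only the single-token check is documented, so this sketch stops there and does not try to tell '__prefill__', '__train__', and '__insert__' apart. is_single_token_query is a hypothetical helper that shows how the BTHD flag selects the sequence axis.

```python
import jax.numpy as jnp


def is_single_token_query(q, BTHD: bool = True) -> bool:
    """Hypothetical helper: True when the query holds exactly one position."""
    # The sequence axis is 1 for (B, T, H, D) and 2 for (B, H, T, D).
    seq_axis = 1 if BTHD else 2
    return q.shape[seq_axis] == 1


decode_q = jnp.zeros((4, 1, 8, 64))     # (B, T, H, D) with T == 1
prefill_q = jnp.zeros((4, 8, 128, 64))  # (B, H, T, D) with T == 128
print(is_single_token_query(decode_q), is_single_token_query(prefill_q, BTHD=False))
```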
- static repeat_kv_heads(k: Array, v: Array, num_reps: int) → Tuple[Array, Array]
Repeats Key and Value heads for Grouped Query Attention (GQA) or Multi-Query Attention (MQA).
Expands the head dimension of K and V tensors to match the number of query heads.
- Parameters
k – Key tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).
v – Value tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).
num_reps – The number of times to repeat each KV head (num_q_heads // num_kv_heads).
- Returns
A tuple (k_repeated, v_repeated) with shapes (batch, seq_len, num_q_heads, head_dim).
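A sketch of the repetition using jnp.repeat along the head axis, assuming the (batch, seq_len, heads, head_dim) shapes listed above. Whether EasyDeL uses jnp.repeat or a broadcast is an assumption. The grouping shown, where query head i reads KV head i // num_reps, is the usual GQA convention.

```python
from typing import Tuple

import jax
import jax.numpy as jnp


def repeat_kv_heads_sketch(k: jax.Array, v: jax.Array, num_reps: int) -> Tuple[jax.Array, jax.Array]:
    # Repeat every KV head num_reps times along the head axis (axis 2),
    # so KV head j is shared by query heads j * num_reps ... (j + 1) * num_reps - 1.
    return jnp.repeat(k, num_reps, axis=2), jnp.repeat(v, num_reps, axis=2)


batch, seq_len, num_q_heads, num_kv_heads, head_dim = 2, 5, 8, 2, 16
k = jnp.ones((batch, seq_len, num_kv_heads, head_dim))
v = jnp.ones((batch, seq_len, num_kv_heads, head_dim))
k_rep, v_rep = repeat_kv_heads_sketch(k, v, num_q_heads // num_kv_heads)
print(k_rep.shape)  # (2, 5, 8, 16)
```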
- class easydel.layers.attention_operator.__init__.AttentionMetadata(runtime_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType], runtime_softmax_dtype: Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]] = None, sequence_axis_name: str = <eformer.common_types._Empty object>, mesh: Optional[jax._src.mesh.Mesh] = <eformer.common_types._Empty object>, platform: easydel.infra.etils.EasyDeLPlatforms = <eformer.common_types._Empty object>, backend: easydel.infra.etils.EasyDeLBackends = <eformer.common_types._Empty object>, partition_axis: eformer.escale.partition.manager.PartitionAxis = <eformer.common_types._Empty object>, partition_manager: eformer.escale.partition.manager.PartitionManager = <eformer.common_types._Empty object>, base_config: Optional[easydel.infra.base_config.EasyDeLBaseConfig] = None, scan_ring_attention: bool = <eformer.common_types._Empty object>, softmax_scale: float = <eformer.common_types._Empty object>, dropout_prob: float = <eformer.common_types._Empty object>, blocksize_q: int = <eformer.common_types._Empty object>, blocksize_k: int = <eformer.common_types._Empty object>, blocksize_b: int = <eformer.common_types._Empty object>)
Bases: object
Holds configuration, context, and metadata for attention operations.
This class centralizes various parameters needed by different attention implementations, facilitating consistent behavior and configuration. It handles default values and can be initialized from an EasyDeLBaseConfig.
- runtime_dtype
The primary JAX dtype for computations (e.g., q, k, v).
- Type
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]
- runtime_softmax_dtype
Optional JAX dtype for the softmax computation, allowing for higher precision if needed (e.g., float32).
- Type
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]]
- sequence_axis_name
The name used for the sequence axis in JAX parallelism (axis_names for pjit).
- Type
str
- mesh
The JAX device mesh for distributed computation. Must be provided or inferred from context.
- Type
Optional[jax._src.mesh.Mesh]
- platform
The target hardware platform (e.g., TPU, GPU).
- Type
EasyDeLPlatforms
- backend
The specific JAX backend being used (e.g., TPU, CUDA, ROCm).
- Type
EasyDeLBackends
- partition_axis
Configuration for partitioning axes in distributed settings (an eformer.escale PartitionAxis).
- Type
PartitionAxis
- base_config
An optional reference to the base model configuration object for sourcing default values.
- Type
Optional[EasyDeLBaseConfig]
- scan_ring_attention
Boolean flag indicating whether to use ring attention via jax.lax.scan.
- Type
bool
- softmax_scale
The scaling factor applied before the softmax operation. Often 1 / sqrt(head_dim).
- Type
float
- dropout_prob
The dropout probability applied to attention weights.
- Type
float
- blocksize_q
Block size for the query sequence dimension in blockwise attention.
- Type
int
- blocksize_k
Block size for the key/value sequence dimension in blockwise attention.
- Type
int
- blocksize_b
Block size for the batch dimension in blockwise attention (often 1).
- Type
int
- backend: EasyDeLBackends = <eformer.common_types._Empty object>
- base_config: Optional[EasyDeLBaseConfig] = None
- blocksize_b: int = <eformer.common_types._Empty object>
- blocksize_k: int = <eformer.common_types._Empty object>
- blocksize_q: int = <eformer.common_types._Empty object>
- dropout_prob: float = <eformer.common_types._Empty object>
- classmethod from_config(config: EasyDeLBaseConfig, softmax_scale: float, dropout_prob: float = 0.0) → AttentionMetadata
Factory method to create AttentionMetadata from an EasyDeLBaseConfig.
- Parameters
config – The base configuration object (e.g., model config).
softmax_scale – The attention softmax scaling factor. Usually calculated based on head dimension.
dropout_prob – The attention dropout probability. Defaults to 0.0.
- Returns
An initialized AttentionMetadata instance.
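A worked example of the conventional scaling factor 1 / sqrt(head_dim) for head_dim = 128. Here config is only a placeholder for an EasyDeLBaseConfig, so the from_config call is left as a comment.

```python
import math

head_dim = 128
softmax_scale = 1.0 / math.sqrt(head_dim)  # conventional 1 / sqrt(head_dim)
print(round(softmax_scale, 4))  # 0.0884

# With a model config at hand (not constructed here), the documented factory would be called as:
# metadata = AttentionMetadata.from_config(config, softmax_scale=softmax_scale, dropout_prob=0.0)
```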
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- get_shardings(mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], BTHD: bool = True, qkv_mni_sharding: bool = False)
Generates JAX PartitionSpecs for attention tensors based on the runtime mode.
- Parameters
mode – The current runtime mode ('__autoregressive__', '__prefill__', '__train__', or '__insert__').
BTHD – Boolean indicating tensor layout. True for (Batch, Time, Head, Dim), False for (Batch, Head, Time, Dim).
- Returns
A tuple of PartitionSpecs for (query, key, value, bias, mask, attention_output).
- Return type
tuple
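Because the BTHD flag only moves the time and head axes, a BHTD spec can be derived by reordering a BTHD one. The helper bthd_to_bhtd_spec and the axis names below are hypothetical and are not EasyDeL's defaults.

```python
from jax.sharding import PartitionSpec


def bthd_to_bhtd_spec(spec: PartitionSpec) -> PartitionSpec:
    """Hypothetical helper: reorder a 4-D (B, T, H, D) spec into (B, H, T, D) order."""
    return PartitionSpec(spec[0], spec[2], spec[1], spec[3])


# Illustrative axis names: batch over ("dp", "fsdp"), time over "sp", heads over "tp".
q_spec_bthd = PartitionSpec(("dp", "fsdp"), "sp", "tp", None)
print(bthd_to_bhtd_spec(q_spec_bthd))
```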
- partition_axis: PartitionAxis = <eformer.common_types._Empty object>
- partition_manager: PartitionManager = <eformer.common_types._Empty object>
- platform: EasyDeLPlatforms = <eformer.common_types._Empty object>
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- scan_ring_attention: bool = <eformer.common_types._Empty object>
- sequence_axis_name: str = <eformer.common_types._Empty object>
- set_attrs_carefully(attr_name: str, default: Optional[Any], pickup_name: Optional[str] = None, use_base_config: bool = True)
Internal helper to set an attribute if it’s not already set (or is Ellipsis).
Optionally retrieves the value from self.base_config using pickup_name (or attr_name if pickup_name is None).
- Parameters
attr_name – The name of the attribute to set on self.
default – The default value to use if not found in base_config or if use_base_config is False.
pickup_name – The name of the attribute to look for in base_config. Defaults to attr_name.
use_base_config – Whether to attempt retrieving the value from base_config.
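A stdlib sketch of the lookup order described above: an explicit value wins, then base_config (read via pickup_name), then default. MetadataStub, BaseConfigStub, and the _EMPTY sentinel are hypothetical stand-ins; the real defaults are eformer.common_types._Empty objects, as the class signature shows.

```python
from dataclasses import dataclass
from typing import Any, Optional

_EMPTY = object()  # hypothetical "not set" sentinel


@dataclass
class BaseConfigStub:  # hypothetical stand-in for EasyDeLBaseConfig
    blocksize_q: int = 512


@dataclass
class MetadataStub:  # hypothetical stand-in for AttentionMetadata
    base_config: Optional[BaseConfigStub] = None
    blocksize_q: Any = _EMPTY
    dropout_prob: Any = _EMPTY

    def set_attrs_carefully(self, attr_name, default, pickup_name=None, use_base_config=True):
        current = getattr(self, attr_name, _EMPTY)
        if current is not _EMPTY and current is not Ellipsis:
            return  # an explicitly provided value is never overwritten
        value = default
        if use_base_config and self.base_config is not None:
            value = getattr(self.base_config, pickup_name or attr_name, default)
        setattr(self, attr_name, value)


meta = MetadataStub(base_config=BaseConfigStub())
meta.set_attrs_carefully("blocksize_q", 128)                          # taken from base_config
meta.set_attrs_carefully("dropout_prob", 0.0, use_base_config=False)  # falls back to default
print(meta.blocksize_q, meta.dropout_prob)
```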
- softmax_scale: float = <eformer.common_types._Empty object>
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- class easydel.layers.attention_operator.__init__.AttentionOutput(attention_weights: Optional[Array] = None, attention_outputs: Optional[Array] = None)
Bases: object
Container for the outputs of an attention operation.
- attention_weights
The attention probabilities, typically of shape (batch, num_heads, query_seq_len, key_value_seq_len). Optional.
- Type
Optional[jax.Array]
- attention_outputs
The final weighted sum of values, typically of shape (batch, query_seq_len, num_heads, head_dim) or (batch, num_heads, query_seq_len, head_dim). Optional.
- Type
Optional[jax.Array]
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- class easydel.layers.attention_operator.__init__.AttentionRegistry
Bases: object
Registry for discovering and managing different AttentionImpl classes.
Allows registering implementations using a decorator and retrieving or instantiating them by name.
- classmethod create(impl_name: str, metadata: AttentionMetadata) → AttentionImpl
Creates an instance of an attention implementation by name.
Retrieves the class associated with impl_name and initializes it with the provided metadata.
- Parameters
impl_name – The name of the implementation to instantiate.
metadata – The AttentionMetadata to pass to the implementation’s constructor.
- Returns
An initialized instance of the requested AttentionImpl subclass.
- Raises
ValueError – If no implementation is registered with impl_name.
- classmethod get(impl_name: str) → Type[AttentionImpl]
Retrieves an attention implementation class by its registered name.
- Parameters
impl_name – The name of the implementation to retrieve.
- Returns
The AttentionImpl subclass registered under the given name.
- Raises
ValueError – If no implementation is registered with that name.
- classmethod list_implementations() → List[str]
Returns a list of names of all registered attention implementations.
- Returns
A list of strings, where each string is a registered implementation name.
- classmethod register(impl_cls: Type[ICa]) → Type[ICa]
Class method decorator to register an AttentionImpl subclass.
The implementation is registered under the name(s) returned by its get_impl_name() class method.
Example:
```python
@AttentionRegistry.register
class FlashAttentionImpl(AttentionImpl):
    @classmethod
    def get_impl_name(cls) -> str:
        return "flash"

    # ... implementation ...
```
- Parameters
impl_cls – The AttentionImpl subclass to register.
- Returns
The registered class itself.
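A self-contained sketch of the registry pattern documented above, including multiple aliases returned from get_impl_name. ToyAttentionRegistry and ToySdpa are hypothetical; the real registry's internals may differ.

```python
from typing import Dict, List, Type


class ToyAttentionRegistry:
    """Hypothetical minimal registry mirroring the documented API shape."""

    _registry: Dict[str, type] = {}

    @classmethod
    def register(cls, impl_cls: Type) -> Type:
        names = impl_cls.get_impl_name()
        # Accept a single name or a tuple/list of aliases.
        for name in ((names,) if isinstance(names, str) else names):
            cls._registry[name] = impl_cls
        return impl_cls  # returning the class keeps it usable after decoration

    @classmethod
    def get(cls, impl_name: str) -> Type:
        if impl_name not in cls._registry:
            raise ValueError(f"No attention implementation registered as {impl_name!r}")
        return cls._registry[impl_name]

    @classmethod
    def create(cls, impl_name: str, metadata):
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> List[str]:
        return list(cls._registry)


@ToyAttentionRegistry.register
class ToySdpa:
    def __init__(self, metadata):
        self.metadata = metadata

    @classmethod
    def get_impl_name(cls):
        return ("sdpa", "cudnn")  # multiple aliases, as with ScaledDotProductAttn


impl = ToyAttentionRegistry.create("cudnn", metadata={"softmax_scale": 0.125})
print(type(impl).__name__, ToyAttentionRegistry.list_implementations())
```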
- class easydel.layers.attention_operator.__init__.AutoRegressiveDecodeAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
Attention implementation tailored for the autoregressive decoding step.
This class handles the attention mechanism when generating tokens one by one, attending to the previously generated sequence stored in a cache. It uses shard_map for distributed computation and supports different backends, including a potential Pallas-optimized version for TPUs. It assumes the query sequence length is 1.
- metadata
Configuration metadata for the attention mechanism.
- Type
AttentionMetadata
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass for autoregressive decoding attention.
Delegates to the native JAX/XLA implementation (forward_native).
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass for autoregressive decoding attention.
Delegates to the native JAX/XLA implementation (forward_native). Future optimizations might add CUDA-specific kernels here.
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass for autoregressive decoding attention.
Currently delegates to forward_cuda.
- forward_native(q: Array, k: Array, v: Array, cache_view: TransformerCacheView, **ignores) → AttentionOutput
Performs the native JAX/XLA forward pass for autoregressive decoding attention.
This implementation uses shard_map to distribute the computation and relies on standard JAX operations (einsum, softmax). It calculates attention weights between the single query token and the keys in the cache, applies masking based on the valid range defined in cache_view, computes the softmax, and finally computes the weighted sum of values.
- Parameters
q (Array) – Query tensor of shape (batch_size, 1, num_query_heads, head_dim).
k (Array) – Key tensor (from cache) of shape (batch_size, kv_sequence_length, num_kv_heads, head_dim).
v (Array) – Value tensor (from cache) of shape (batch_size, kv_sequence_length, num_kv_heads, head_dim).
cache_view (TransformerCacheView) – Contains metadata about the cache, specifically starts (start index for valid keys/values) and index (current index or length).
**ignores – Ignored keyword arguments.
- Returns
An AttentionOutput containing the attention outputs (attention_outputs) of shape (batch_size, 1, num_query_heads, head_dim). Attention weights are not returned.
- Return type
AttentionOutput
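A sketch of the computation forward_native describes (scores, masking to the valid cache range, softmax, weighted sum of values), without the shard_map wrapper. It assumes KV heads are already repeated to the query head count and that index is an exclusive end position; both are assumptions rather than documented behaviour, and decode_attention_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def decode_attention_sketch(q, k, v, starts, index, softmax_scale):
    """Single-token attention over a KV cache (hypothetical, not EasyDeL's kernel).

    q: (B, 1, H, D); k, v: (B, S, H, D) with KV heads already repeated to H.
    starts, index: (B,) ints; positions with starts <= s < index are treated as valid.
    """
    scores = jnp.einsum("bqhd,bshd->bhqs", q, k) * softmax_scale  # (B, H, 1, S)
    pos = jnp.arange(k.shape[1])
    valid = (pos[None, :] >= starts[:, None]) & (pos[None, :] < index[:, None])  # (B, S)
    scores = jnp.where(valid[:, None, None, :], scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqs,bshd->bqhd", weights, v)  # (B, 1, H, D)


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (1, 1, 2, 4))
k = jax.random.normal(key_k, (1, 6, 2, 4))
v = jax.random.normal(key_v, (1, 6, 2, 4))
# Only cache slot 3 is valid, so the output should equal v at position 3.
out = decode_attention_sketch(q, k, v, starts=jnp.array([3]), index=jnp.array([4]), softmax_scale=0.5)
print(out.shape)  # (1, 1, 2, 4)
```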
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass for autoregressive decoding attention.
Delegates to the native JAX/XLA implementation (forward_native). Future optimizations might add ROCm-specific kernels here.
- forward_tpu(*args, **kwargs) → AttentionOutput
TPU forward pass for autoregressive decoding attention.
Currently delegates to the native JAX/XLA implementation (forward_native). Consider using _forward_tpu for potential Pallas optimization if available/enabled.
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.
- class easydel.layers.attention_operator.__init__.FlashAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
An implementation of Flash Attention V2 using specialized JAX primitives.
This class leverages jax.experimental.pallas.ops.tpu.flash_attention for TPUs and a Triton kernel (triton_flash_attention) for GPUs (CUDA). It is registered under the name “flash_attn2”. CPU execution is not supported and will raise an error.
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native, which raises an error.
- Raises
NotImplementedError – Via forward_native.
- forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Performs Flash Attention V2 on GPU (CUDA) using a Triton kernel.
Handles optional mask/bias, KV head repetition (inside the kernel or before), and sharding based on metadata. Assumes triton_flash_attention handles KV head repetition internally if needed or expects broadcastable KV. Assumes BTHD input/output format for the Triton kernel.
- Parameters
q – Query tensor, expected shape (batch, q_seq_len, num_q_heads, head_dim).
k – Key tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
v – Value tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
mask – Optional boolean attention mask. Used by the kernel if bias is not provided. Shape typically (batch, 1, q_seq_len, kv_seq_len) or broadcastable.
bias – Optional attention bias tensor. Added by the kernel. Shape typically (batch, num_heads, q_seq_len, kv_seq_len) or broadcastable. Takes precedence over mask within the kernel logic if both are passed.
init_bias – Optional callable function to initialize bias if mask and bias are None.
causal – If True, instructs the kernel to apply a causal mask.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention outputs. Attention weights are typically not computed or returned by Flash Attention implementations.
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Delegates to the CUDA-specific implementation.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Native (CPU) forward pass for Flash Attention. Not implemented.
- Raises
NotImplementedError – Flash Attention is not supported on CPU via this implementation.
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Not implemented; delegates to forward_native, which raises an error.
- Raises
NotImplementedError – Via forward_native. ROCm support requires a specific kernel.
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Performs Flash Attention V2 on TPU using jax.experimental.pallas.ops.tpu.flash_attention.
Handles optional mask/bias, KV head repetition, and sharding based on metadata. Note: The Pallas implementation expects inputs in BHTD format.
- Parameters
q – Query tensor, expected shape (batch, q_seq_len, num_q_heads, head_dim).
k – Key tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
v – Value tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
mask – Optional boolean attention mask. Bias will be generated from this if provided. Shape typically (batch, 1, q_seq_len, kv_seq_len) or broadcastable.
bias – Optional attention bias tensor. Added directly to attention scores. Shape typically (batch, num_heads, q_seq_len, kv_seq_len) or broadcastable. Takes precedence over mask.
init_bias – Optional callable function to initialize bias if mask and bias are None.
causal – If True, applies a causal mask. Ignored if q_seq_len is 1 (generation).
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention outputs. Attention weights are typically not computed or returned by Flash Attention implementations.
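forward_tpu generates a bias from mask and feeds the Pallas kernel BHTD tensors. The sketch below shows one common way to do both: 0 for visible positions and the most negative finite value for masked ones. The exact bias values EasyDeL uses are an assumption, and mask_to_bias and bthd_to_bhtd are hypothetical helpers.

```python
import jax.numpy as jnp


def mask_to_bias(mask, dtype=jnp.float32):
    # True = attend -> 0.0; False = blocked -> most negative finite value of dtype.
    return jnp.where(mask, jnp.zeros((), dtype), jnp.finfo(dtype).min).astype(dtype)


def bthd_to_bhtd(x):
    # (batch, seq, heads, dim) -> (batch, heads, seq, dim)
    return jnp.transpose(x, (0, 2, 1, 3))


causal_mask = jnp.tril(jnp.ones((4, 4), dtype=bool))[None, None]  # (1, 1, T, S)
bias = mask_to_bias(causal_mask)
q_bhtd = bthd_to_bhtd(jnp.zeros((2, 4, 8, 64)))
print(bias.shape, q_bhtd.shape)  # (1, 1, 4, 4) (2, 8, 4, 64)
```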
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.
- class easydel.layers.attention_operator.__init__.PagedAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
Attention implementation using the Paged Attention mechanism with TPU Pallas kernels.
This class provides an attention mechanism suitable for scenarios where the Key-Value cache is managed in non-contiguous pages (Paged KV Cache). It leverages specialized kernels for efficient execution on TPUs, handling prefill and decode phases separately or in a mixed mode.
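To illustrate what "non-contiguous pages" means, the sketch below gathers a contiguous KV view from a page pool through a per-sequence block table. Paged kernels typically read pages in place rather than materializing such a view. The tensor layout and the names gather_paged_kv and block_table are hypothetical and do not describe PagedAttentionCacheView's actual fields.

```python
import jax.numpy as jnp


def gather_paged_kv(pages, block_table):
    """Hypothetical reference gather for a paged KV cache.

    pages: (num_pages, page_size, num_kv_heads, head_dim) physical page pool.
    block_table: (batch, pages_per_seq) page ids for each sequence, in logical order.
    Returns (batch, pages_per_seq * page_size, num_kv_heads, head_dim).
    """
    gathered = pages[block_table]  # (batch, pages_per_seq, page_size, H, D)
    b, p, s, h, d = gathered.shape
    return gathered.reshape(b, p * s, h, d)


num_pages, page_size, heads, head_dim = 6, 4, 2, 8
# Fill each page with its own id so the gather order is easy to check.
pages = jnp.arange(num_pages, dtype=jnp.float32)[:, None, None, None] * jnp.ones((1, page_size, heads, head_dim))
block_table = jnp.array([[5, 0], [2, 3]])  # sequence 0 lives in pages 5 then 0
contiguous = gather_paged_kv(pages, block_table)
print(contiguous.shape)  # (2, 8, 2, 8)
```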
- metadata
Configuration metadata for the attention mechanism. While this class uses AttentionMetadata, it primarily relies on the additional PagedAttentionMetadata passed during the forward call for paged-specific information.
- Type
AttentionMetadata
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass.
- Raises
NotImplementedError – Paged Attention currently relies on Pallas kernels for TPUs and has no CPU implementation.
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass.
- Raises
NotImplementedError – Paged Attention currently relies on Pallas for TPUs and does not have a specific CUDA implementation. (Future work might add this.)
- forward_gpu(*args, **kwargs) → AttentionOutput
Generic GPU forward pass.
- Raises
NotImplementedError – Paged Attention relies on specific kernels (currently Pallas for TPU) and does not have a generic GPU implementation.
- forward_native(*args, **kwargs) → AttentionOutput
Native (CPU) forward pass.
- Raises
NotImplementedError – Paged Attention requires specialized kernels and does not have a native CPU implementation.
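- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Not implemented for Paged Attention.
- Raises
NotImplementedError – Paged Attention currently relies on Pallas for TPUs and does not have a specific ROCm implementation.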
- forward_tpu(q: Array, k: Array, v: Array, cache_view: PagedAttentionCacheView, cache_metadata: PagedAttentionMetadata, **ignore) → AttentionOutput
TPU forward pass for Paged Attention.
Determines the execution mode (prefill, decode, or mixed) based on the provided cache_metadata and dispatches the computation to the corresponding internal TPU method (_prefill_tpu, _decode_tpu, _mixed_tpu).
- Parameters
q (Array) – Query tensor. Shape depends on mode (prefill/decode/mixed).
k (Array) – Key tensor (ignored).
v (Array) – Value tensor (ignored).
cache_view (PagedAttentionCacheView) – Contains the paged KV cache tensors.
cache_metadata (PagedAttentionMetadata) – Contains metadata describing the state and mode (prefill/decode/mixed) of the current batch.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput containing the computed attention outputs. Attention weights are typically not computed or returned in paged attention.
- Return type
AttentionOutput
- get_impl_metadata() → AttentionMetadata
Retrieves the metadata associated with this attention implementation instance.
- Returns
The metadata object provided during initialization.
- Return type
AttentionMetadata
- class easydel.layers.attention_operator.__init__.RingAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
Attention implementation using a ring-passing algorithm or a blockwise scan.
This implementation supports:
- Native (scan-based) blockwise attention via blockwise_attn.
- TPU-specific ring attention using the pallas_ring_attention kernel.
It is registered under the name “ring”.
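The internals of blockwise_attn are not documented here. The sketch below shows the standard online-softmax recurrence that scan-based blockwise attention builds on: each KV block updates a running row maximum, normalizer, and accumulator inside jax.lax.scan. Ring attention applies the same recurrence while KV blocks rotate between devices. blockwise_attention_sketch is hypothetical, single-head, and non-causal for brevity.

```python
import jax
import jax.numpy as jnp


def blockwise_attention_sketch(q, k, v, block_k: int):
    """Non-causal attention computed one KV block at a time (hypothetical sketch).

    q: (T, D); k, v: (S, D) with S divisible by block_k. Single head for brevity.
    """
    scale = q.shape[-1] ** -0.5
    k_blocks = k.reshape(-1, block_k, k.shape[-1])
    v_blocks = v.reshape(-1, block_k, v.shape[-1])

    def step(carry, kv):
        acc, row_max, row_sum = carry
        kb, vb = kv
        scores = (q @ kb.T) * scale                          # (T, block_k)
        new_max = jnp.maximum(row_max, scores.max(axis=-1))  # running max per query row
        correction = jnp.exp(row_max - new_max)              # rescales earlier partial sums
        p = jnp.exp(scores - new_max[:, None])
        acc = acc * correction[:, None] + p @ vb
        row_sum = row_sum * correction + p.sum(axis=-1)
        return (acc, new_max, row_sum), None

    init = (
        jnp.zeros((q.shape[0], v.shape[-1]), q.dtype),
        jnp.full((q.shape[0],), -jnp.inf, q.dtype),
        jnp.zeros((q.shape[0],), q.dtype),
    )
    (acc, _, row_sum), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / row_sum[:, None]


q = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (32, 16))
v = jax.random.normal(jax.random.PRNGKey(2), (32, 16))
out = blockwise_attention_sketch(q, k, v, block_k=8)
print(out.shape)  # (8, 16)
```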
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native (scan-based).
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass. Currently delegates to forward_native (scan-based).
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Currently delegates to forward_native (scan-based).
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) → AttentionOutput
Computes attention using the scan-based blockwise_attn function.
Handles optional mask/bias, KV head repetition, and sharding constraints.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S).
bias – Optional attention bias (broadcastable to B, H, T, S).
init_bias – Optional callable to initialize bias if mask/bias are None.
deterministic – If False, enables dropout. Requires dropout_rng.
dropout_rng – JAX PRNG key for dropout if deterministic is False.
causal – Apply causal mask if True.
**ignore – Ignored keyword arguments.
- Returns
AttentionOutput containing the attention result.
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Currently delegates to forward_native (scan-based).
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) → AttentionOutput
Computes Ring Attention on TPU using the pallas_ring_attention kernel.
Handles optional mask/bias, sharding, and passes configuration to the kernel.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S).
bias – Optional attention bias (broadcastable to B, H, T, S).
init_bias – Optional callable to initialize bias if mask/bias are None.
deterministic – If False, potentially enables dropout within the kernel (if supported).
dropout_rng – JAX PRNG key (may be used by the kernel if dropout is enabled).
causal – Apply causal mask if True. Passed to the kernel.
**ignore – Ignored keyword arguments.
- Returns
AttentionOutput containing the attention result.
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this instance.
- class easydel.layers.attention_operator.__init__.ScaledDotProductAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
An attention implementation that leverages jax.nn.dot_product_attention.
This class utilizes JAX’s optimized SDPA primitive, which can dispatch to different backend implementations (like XLA, cuDNN, or potentially Flash Attention emulation on CUDA depending on JAX version and hardware).
It handles sharding using shard_map and manages backend-specific dispatch (primarily distinguishing between CUDA/GPU and other backends like TPU/CPU).
Registered under the names “sdpa”, “cudnn”, and “cuda_flash_attn2”.
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native (XLA implementation).
- forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Computes attention using jax.nn.dot_product_attention with the “cudnn” implementation.
This is optimized for NVIDIA GPUs using cuDNN. It applies sharding via shard_map. Note: The cuDNN implementation might have specific requirements (e.g., dtype). Causal masking is disabled during generation mode (q_len=1) as it’s unnecessary.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D_v).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S).
bias – Optional attention bias tensor (broadcastable to B, H, T, S).
init_bias – Optional callable to initialize bias if mask/bias are None.
causal – If True, applies causal masking within the primitive, unless in generation mode.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object. Weights are not returned.
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Delegates to the CUDA-specific implementation.
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Computes attention using jax.nn.dot_product_attention with the “xla” implementation.
This is typically used for CPU and TPU backends. It applies sharding via shard_map.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D_v).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Passed directly to the primitive.
bias – Optional attention bias tensor (broadcastable to B, H, T, S). Passed directly to the primitive. If bias is provided, causal is forced to False.
init_bias – Optional callable to initialize bias if mask/bias are None.
causal – If True and bias is None, applies causal masking within the primitive.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object. Note that jax.nn.dot_product_attention typically does not return attention weights.
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Currently delegates to forward_native.
- forward_tpu(*args, **kwargs) → AttentionOutput
TPU forward pass. Delegates to forward_native (XLA implementation).
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.
- class easydel.layers.attention_operator.__init__.SplashAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
An attention implementation using the Pallas Splash Attention kernel for TPUs.
Splash Attention is an optimized attention mechanism designed for TPUs. This implementation provides a wrapper around the make_splash_mqa_single_device primitive.
Note
This implementation is primarily intended for TPUs.
- It falls back to VanillaAttn when any of the following holds:
  - The query sequence length is 1 (generation mode).
  - causal is False.
  - The query sequence length is not divisible by 128 (kernel constraint).
- Non-TPU forward methods (forward_native, forward_gpu, etc.) are not implemented and will raise NotImplementedError.
Registered under the name “splash”.
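The fallback rules above, restated as a predicate. splash_should_fall_back is a hypothetical helper, and the block size of 128 comes from the note.

```python
def splash_should_fall_back(q_seq_len: int, causal: bool, block: int = 128) -> bool:
    """Hypothetical restatement of the documented fallback rules (not EasyDeL's code)."""
    return q_seq_len == 1 or not causal or q_seq_len % block != 0


for q_len, causal in [(1, True), (512, False), (300, True), (1024, True)]:
    print(q_len, causal, "vanilla" if splash_should_fall_back(q_len, causal) else "splash")
```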
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Not implemented for Splash Attention.
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass. Not implemented for Splash Attention.
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Not implemented for Splash Attention.
- forward_native(*args, **kwargs) → AttentionOutput
Native (CPU) forward pass. Not implemented for Splash Attention.
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Not implemented for Splash Attention.
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, causal: bool = True, cache_view: Optional[TransformerCacheView] = None, **ignore) → AttentionOutput
Performs Splash Attention on TPU using the Pallas kernel.
Handles fallback logic, mask processing, block size configuration, and sharding via shard_map. Expects inputs in BTHD layout (as in the parameter shapes below) and transposes them to BHTD for the kernel.
- Parameters
q – Query tensor (B, T, Hq, D).
k – Key tensor (B, S, Hkv, D).
v – Value tensor (B, S, Hkv, Dv).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Used to generate segment IDs if provided.
causal – If True, applies causal masking via the kernel’s mask configuration. If False, falls back to VanillaAttn.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention outputs. Attention weights are not computed or returned by Splash Attention.
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.
- class easydel.layers.attention_operator.__init__.VanillaAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
A standard, non-optimized implementation of multi-head attention.
This implementation uses basic JAX operations like jnp.einsum and standard softmax. It serves as a reference implementation and a fallback for platforms where optimized kernels (like Flash Attention) are not available or desired. It supports features like attention bias, masking, dropout, and Grouped Query Attention (GQA)/Multi-Query Attention (MQA) via reshaping.
Registered under the name “vanilla”.
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native.
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass. Delegates to forward_native.
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Delegates to forward_native.
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = True, dropout_rng: Optional[PRNGKey] = None, **ignore) → AttentionOutput
Computes multi-head attention using standard JAX operations.
Supports GQA/MQA by reshaping the query tensor to match the number of key/value heads. Applies scaling, optional bias/mask, softmax (potentially in float32), and optional dropout.
- Parameters
q – Query tensor (B, T, H_q, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D_v).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Used if bias is not provided.
bias – Optional attention bias tensor (broadcastable to B, H_q, T, S). Takes precedence over mask.
init_bias – Optional callable to initialize bias if mask/bias are None.
deterministic – If True, disables dropout.
dropout_rng – JAX PRNG key for dropout. Required if deterministic is False and dropout_prob > 0.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention weights (if computed) and the final attention outputs.
- Raises
NotImplementedError – If the bias head dimension cannot be reshaped correctly to match the query head structure for GQA/MQA.
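A sketch of the reshape-based GQA path described above. Query heads are grouped as (H_kv, H_q // H_kv), so K and V never need to be repeated, and softmax runs in float32. Bias and dropout are omitted. vanilla_gqa_sketch is hypothetical, not EasyDeL's code.

```python
import jax
import jax.numpy as jnp


def vanilla_gqa_sketch(q, k, v, softmax_scale, mask=None):
    """Reference GQA attention by grouping query heads (hypothetical sketch).

    q: (B, T, Hq, D); k, v: (B, S, Hkv, D); mask: optional bool, broadcastable to (B, 1, T, S).
    """
    b, t, hq, d = q.shape
    hkv = k.shape[2]
    q = q.reshape(b, t, hkv, hq // hkv, d)                           # (B, T, Hkv, G, D)
    scores = jnp.einsum("btkgd,bskd->bkgts", q, k) * softmax_scale  # (B, Hkv, G, T, S)
    if mask is not None:
        # Insert a group axis so (B, 1, T, S) broadcasts against (B, Hkv, G, T, S).
        scores = jnp.where(mask[:, :, None], scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores.astype(jnp.float32), axis=-1).astype(v.dtype)
    out = jnp.einsum("bkgts,bskd->btkgd", weights, v)
    return out.reshape(b, t, hq, d)


q = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 8, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (2, 5, 2, 16))
v = jax.random.normal(jax.random.PRNGKey(2), (2, 5, 2, 16))
mask = jnp.ones((2, 1, 3, 5), dtype=bool).at[:, :, :, -1].set(False)  # hide the last key
out = vanilla_gqa_sketch(q, k, v, softmax_scale=0.25, mask=mask)
print(out.shape)  # (2, 3, 8, 16)
```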
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Delegates to forward_native.
- forward_tpu(*args, **kwargs) → AttentionOutput
TPU forward pass. Delegates to forward_native.
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.