easydel.layers.attention_operator._attention_impl#
- class easydel.layers.attention_operator._attention_impl.AttentionImpl(metadata: AttentionMetadata)[source]#
Bases: BaseOperation
Abstract base class for specific attention implementations.
Inherits from BaseOperation to leverage backend-specific dispatching. Subclasses must implement the core attention logic (forward_native) and potentially provide optimized versions for TPU (forward_tpu), GPU (forward_gpu), etc. They also need to declare their name and associated metadata.
Provides common helper methods for attention processing like mask manipulation, head repeating (for GQA/MQA), and determining runtime mode.
- create_stable_sharding(state_ps: Optional[PartitionSpec] = None, preserved_indices: List[int] = None, clone_ps: Optional[PartitionSpec] = None, dep: Optional[Union[PartitionSpec, bool]] = True, tensor: Optional[Array] = None) Optional[PartitionSpec][source]#
Helper to create a PartitionSpec, potentially preserving only certain axes.
This might be used for ensuring intermediate tensors or states have compatible sharding, possibly replicating across axes not specified in preserved_indices.
- Parameters
state_ps – The base PartitionSpec to modify.
preserved_indices – A list of dimension indices whose partitioning should be kept from state_ps (or clone_ps if provided). Other dimensions will be set to None (replicated). If None, state_ps is returned.
clone_ps – An optional PartitionSpec to copy axis names from for the preserved indices, instead of using state_ps.
dep – A dependency flag or PartitionSpec. If None, returns None. Defaults to True. (The exact purpose might be context-specific, potentially for control flow).
- Returns
A new PartitionSpec with only specified axes partitioned, or None based on dep. Returns state_ps directly if preserved_indices is None.
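The preservation rule described above can be illustrated with plain tuples standing in for PartitionSpec. This is only a sketch under stated assumptions: `keep_axes` is a hypothetical helper, not EasyDeL's implementation. It ignores the `dep` and `tensor` arguments and assumes `clone_ps` has the same length as `state_ps`.

```python
from typing import Optional, Sequence, Tuple

# A partition spec is modelled here as a tuple of axis names (or None for
# "replicated"); this is a stand-in for jax.sharding.PartitionSpec.
Spec = Tuple[Optional[str], ...]


def keep_axes(
    state_ps: Optional[Spec],
    preserved_indices: Optional[Sequence[int]] = None,
    clone_ps: Optional[Spec] = None,
) -> Optional[Spec]:
    """Hypothetical helper: keep partitioning only at preserved_indices."""
    if state_ps is None or preserved_indices is None:
        # Nothing to filter: hand back the input unchanged.
        return state_ps
    # Axis names come from clone_ps when given, otherwise from state_ps.
    source = clone_ps if clone_ps is not None else state_ps
    keep = set(preserved_indices)
    return tuple(
        source[i] if i in keep else None  # None means replicated
        for i in range(len(state_ps))
    )


spec = ("dp", "sp", "tp", None)
print(keep_axes(spec, [0, 2]))  # ('dp', None, 'tp', None)
print(keep_axes(spec))          # ('dp', 'sp', 'tp', None)
```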
- current_backend() Literal['tpu', 'gpu', 'cpu'][source]#
Returns the current JAX default backend as a lowercase string literal.
- Returns
"tpu", "gpu", or "cpu".
- abstract get_impl_metadata() AttentionMetadata[source]#
Returns the AttentionMetadata associated with this implementation instance.
- Returns
The AttentionMetadata instance passed during initialization.
- abstract classmethod get_impl_name() Union[str, Tuple[str, ...]][source]#
Returns the unique name(s) identifying this attention implementation.
Used by the AttentionRegistry. Can return a single string or a tuple/list of strings if the implementation has multiple aliases.
- Returns
A string or tuple/list of strings representing the implementation name(s).
- get_mode(q: Array, BTHD: bool = True) Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'][source]#
Determines the runtime mode based on the query shape. The result is one of the mode literals listed in the signature.
Assumes generation mode if the query sequence-length dimension is 1.
- Parameters
q – The query tensor.
BTHD – Boolean indicating tensor layout (True for B, T, H, D; False for B, H, T, D).
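The sequence-length check above depends on which axis holds time in each layout. The sketch below works on shape tuples only. `query_seq_len` and `looks_like_generation` are hypothetical names, and how the library distinguishes the remaining modes is not covered here.

```python
from typing import Tuple


def query_seq_len(q_shape: Tuple[int, ...], bthd: bool = True) -> int:
    """Return the sequence-length dimension of a 4-D query shape."""
    # BTHD layout: (batch, time, heads, dim) -> time is axis 1.
    # BHTD layout: (batch, heads, time, dim) -> time is axis 2.
    return q_shape[1] if bthd else q_shape[2]


def looks_like_generation(q_shape: Tuple[int, ...], bthd: bool = True) -> bool:
    """Hypothetical check: a single query position suggests decoding."""
    return query_seq_len(q_shape, bthd) == 1


print(looks_like_generation((2, 1, 8, 64)))                # True
print(looks_like_generation((2, 8, 128, 64), bthd=False))  # False
```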
- static repeat_kv_heads(k: Array, v: Array, num_reps: int) Tuple[Array, Array][source]#
Repeats Key and Value heads for Grouped Query Attention (GQA) or Multi-Query Attention (MQA).
Expands the head dimension of K and V tensors to match the number of query heads.
- Parameters
k – Key tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).
v – Value tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).
num_reps – The number of times to repeat each KV head (num_q_heads // num_kv_heads).
- Returns
A tuple (k_repeated, v_repeated) with shapes (batch, seq_len, num_q_heads, head_dim).
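The shape change described above can be checked with a small sketch. NumPy stands in for JAX arrays here (`jax.numpy.repeat` takes the same arguments). `repeat_kv` is a hypothetical name, and whether the library repeats or broadcasts internally is not stated in this reference.

```python
import numpy as np


def repeat_kv(k, v, num_reps):
    """Repeat each KV head num_reps times along the head axis (axis 2)."""
    return np.repeat(k, num_reps, axis=2), np.repeat(v, num_reps, axis=2)


batch, seq_len, num_kv_heads, head_dim = 2, 5, 2, 4
num_q_heads = 8
k = np.zeros((batch, seq_len, num_kv_heads, head_dim))
v = np.ones((batch, seq_len, num_kv_heads, head_dim))

# num_reps = num_q_heads // num_kv_heads = 4
k_rep, v_rep = repeat_kv(k, v, num_q_heads // num_kv_heads)
print(k_rep.shape, v_rep.shape)  # (2, 5, 8, 4) (2, 5, 8, 4)
```

With this form each KV head is repeated consecutively, so query heads 0 to 3 share the first KV head and query heads 4 to 7 share the second.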
- class easydel.layers.attention_operator._attention_impl.AttentionMetadata(runtime_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType], runtime_softmax_dtype: ~typing.Optional[~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType]] = None, sequence_axis_name: str = <eformer.common_types._Empty object>, mesh: ~typing.Optional[~jax._src.mesh.Mesh] = <eformer.common_types._Empty object>, platform: ~easydel.infra.etils.EasyDeLPlatforms = <eformer.common_types._Empty object>, backend: ~easydel.infra.etils.EasyDeLBackends = <eformer.common_types._Empty object>, partition_axis: ~eformer.escale.partition.manager.PartitionAxis = <eformer.common_types._Empty object>, partition_manager: ~eformer.escale.partition.manager.PartitionManager = <eformer.common_types._Empty object>, base_config: ~typing.Optional[~easydel.infra.base_config.EasyDeLBaseConfig] = None, scan_ring_attention: bool = <eformer.common_types._Empty object>, softmax_scale: float = <eformer.common_types._Empty object>, dropout_prob: float = <eformer.common_types._Empty object>, blocksize_q: int = <eformer.common_types._Empty object>, blocksize_k: int = <eformer.common_types._Empty object>, blocksize_b: int = <eformer.common_types._Empty object>)[source]#
Bases: object
Holds configuration, context, and metadata for attention operations.
This class centralizes various parameters needed by different attention implementations, facilitating consistent behavior and configuration. It handles default values and can be initialized from an EasyDeLBaseConfig.
- runtime_dtype#
The primary JAX dtype for computations (e.g., q, k, v).
- Type
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]
- runtime_softmax_dtype#
Optional JAX dtype for the softmax computation, allowing for higher precision if needed (e.g., float32).
- Type
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]]
- sequence_axis_name#
The name used for the sequence axis in JAX parallelism (axis_names for pjit).
- Type
str
- mesh#
The JAX device mesh for distributed computation. Must be provided or inferred from context.
- Type
Optional[jax._src.mesh.Mesh]
- platform#
The target hardware platform (e.g., TPU, GPU).
- backend#
The specific JAX backend being used (e.g., TPU, CUDA, ROCM).
- partition_axis#
Configuration for partitioning axes in distributed settings (a PartitionAxis from eformer.escale).
- base_config#
An optional reference to the base model configuration object for sourcing default values.
- Type
Optional[easydel.infra.base_config.EasyDeLBaseConfig]
- scan_ring_attention#
Boolean flag indicating whether to use ring attention via jax.lax.scan.
- Type
bool
- softmax_scale#
The scaling factor applied before the softmax operation. Often 1 / sqrt(head_dim).
- Type
float
- dropout_prob#
The dropout probability applied to attention weights.
- Type
float
- blocksize_q#
Block size for the query sequence dimension in blockwise attention.
- Type
int
- blocksize_k#
Block size for the key/value sequence dimension in blockwise attention.
- Type
int
- blocksize_b#
Block size for the batch dimension in blockwise attention (often 1).
- Type
int
- backend: EasyDeLBackends = <eformer.common_types._Empty object>#
- base_config: Optional[EasyDeLBaseConfig] = None#
- blocksize_b: int = <eformer.common_types._Empty object>#
- blocksize_k: int = <eformer.common_types._Empty object>#
- blocksize_q: int = <eformer.common_types._Empty object>#
- dropout_prob: float = <eformer.common_types._Empty object>#
- classmethod from_config(config: EasyDeLBaseConfig, softmax_scale: float, dropout_prob: float = 0.0) AttentionMetadata[source]#
Factory method to create AttentionMetadata from an EasyDeLBaseConfig.
- Parameters
config – The base configuration object (e.g., model config).
softmax_scale – The attention softmax scaling factor. Usually calculated based on head dimension.
dropout_prob – The attention dropout probability. Defaults to 0.0.
- Returns
An initialized AttentionMetadata instance.
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- get_shardings(mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], BTHD: bool = True, qkv_mni_sharding: bool = False)[source]#
Generates JAX PartitionSpecs for attention tensors based on runtime mode.
- Parameters
mode – The current runtime mode; one of the mode literals listed in the signature.
BTHD – Boolean indicating tensor layout. True for (Batch, Time, Head, Dim), False for (Batch, Head, Time, Dim).
- Returns
A tuple containing PartitionSpecs for (query, key, value, bias, mask, attention_output).
- partition_axis: PartitionAxis = <eformer.common_types._Empty object>#
- partition_manager: PartitionManager = <eformer.common_types._Empty object>#
- platform: EasyDeLPlatforms = <eformer.common_types._Empty object>#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- scan_ring_attention: bool = <eformer.common_types._Empty object>#
- sequence_axis_name: str = <eformer.common_types._Empty object>#
- set_attrs_carefully(attr_name: str, default: Optional[Any], pickup_name: Optional[str] = None, use_base_config: bool = True)[source]#
Internal helper to set an attribute if it’s not already set (or is Ellipsis).
Optionally retrieves the value from self.base_config using pickup_name (or attr_name if pickup_name is None).
- Parameters
attr_name – The name of the attribute to set on self.
default – The default value to use if not found in base_config or if use_base_config is False.
pickup_name – The name of the attribute to look for in base_config. Defaults to attr_name.
use_base_config – Whether to attempt retrieving the value from base_config.
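The fallback order described above (explicit value first, then the base config, then the default) can be sketched with a sentinel object. This is an illustration only: `Settings`, `set_if_unset`, `_UNSET`, and the `attn_dropout` field are hypothetical names, not part of EasyDeL.

```python
from types import SimpleNamespace
from typing import Any, Optional

_UNSET = object()  # hypothetical sentinel marking "no value supplied"


class Settings:
    """Toy holder whose fields may be filled lazily from a base config."""

    def __init__(self, base_config: Any = None, dropout_prob: Any = _UNSET):
        self.base_config = base_config
        self.dropout_prob = dropout_prob

    def set_if_unset(
        self,
        attr_name: str,
        default: Optional[Any],
        pickup_name: Optional[str] = None,
        use_base_config: bool = True,
    ) -> None:
        if getattr(self, attr_name, _UNSET) is not _UNSET:
            return  # an explicit value always wins
        value = default
        if use_base_config and self.base_config is not None:
            # Look under pickup_name in the base config, else attr_name.
            value = getattr(self.base_config, pickup_name or attr_name, default)
        setattr(self, attr_name, value)


cfg = SimpleNamespace(attn_dropout=0.1)
s = Settings(base_config=cfg)
s.set_if_unset("dropout_prob", default=0.0, pickup_name="attn_dropout")
print(s.dropout_prob)  # 0.1
```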
- softmax_scale: float = <eformer.common_types._Empty object>#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.attention_operator._attention_impl.AttentionOutput(attention_weights: Optional[Array] = None, attention_outputs: Optional[Array] = None)[source]#
Bases: object
Container for the outputs of an attention operation.
- attention_weights#
The attention probabilities, typically of shape (batch, num_heads, query_seq_len, key_value_seq_len). Optional.
- Type
Optional[jax.Array]
- attention_outputs#
The final weighted sum of values, typically of shape (batch, query_seq_len, num_heads, head_dim) or (batch, num_heads, query_seq_len, head_dim). Optional.
- Type
Optional[jax.Array]
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
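The shapes given above for attention_weights and attention_outputs can be reproduced with a plain reference computation. This is a sketch in NumPy for inputs in (batch, seq, heads, dim) layout. `reference_attention` is a hypothetical name and is not one of the library's attention implementations.

```python
import numpy as np


def reference_attention(q, k, v, softmax_scale):
    """Plain softmax attention for (batch, seq, heads, dim) inputs."""
    # scores: (batch, heads, q_len, kv_len)
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # outputs: (batch, q_len, heads, dim)
    outputs = np.einsum("bhqk,bkhd->bqhd", weights, v)
    return weights, outputs


rng = np.random.default_rng(0)
q = rng.normal(size=(2, 3, 4, 8))
k = rng.normal(size=(2, 5, 4, 8))
v = rng.normal(size=(2, 5, 4, 8))

# softmax_scale = 1 / sqrt(head_dim) with head_dim = 8
weights, outputs = reference_attention(q, k, v, softmax_scale=8 ** -0.5)
print(weights.shape)  # (2, 4, 3, 5)
print(outputs.shape)  # (2, 3, 4, 8)
```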
- class easydel.layers.attention_operator._attention_impl.AttentionRegistry[source]#
Bases: object
Registry for discovering and managing different AttentionImpl classes.
Allows registering implementations using a decorator and retrieving or instantiating them by name.
- classmethod create(impl_name: str, metadata: AttentionMetadata) AttentionImpl[source]#
Creates an instance of an attention implementation by name.
Retrieves the class associated with impl_name and initializes it with the provided metadata.
- Parameters
impl_name – The name of the implementation to instantiate.
metadata – The AttentionMetadata to pass to the implementation’s constructor.
- Returns
An initialized instance of the requested AttentionImpl subclass.
- Raises
ValueError – If no implementation is registered with impl_name.
- classmethod get(impl_name: str) Type[AttentionImpl][source]#
Retrieves an attention implementation class by its registered name.
- Parameters
impl_name – The name of the implementation to retrieve.
- Returns
The AttentionImpl subclass registered under the given name.
- Raises
ValueError – If no implementation is registered with that name.
- classmethod list_implementations() List[str][source]#
Returns a list of names of all registered attention implementations.
- Returns
A list of strings, where each string is a registered implementation name.
- classmethod register(impl_cls: Type[ICa]) Type[ICa][source]#
Class method decorator to register an AttentionImpl subclass.
The implementation is registered under the name(s) returned by its get_impl_name() class method.
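Example:
```python
from easydel.layers.attention_operator._attention_impl import (
    AttentionImpl,
    AttentionRegistry,
)


@AttentionRegistry.register
class FlashAttentionImpl(AttentionImpl):
    @classmethod
    def get_impl_name(cls) -> str:
        return "flash"

    # … implementation …
```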
- Parameters
impl_cls – The AttentionImpl subclass to register.
- Returns
The registered class itself.
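The register, get, create, and list_implementations behaviour documented above follows a common decorator-registry pattern. The sketch below is self-contained and does not use EasyDeL. `ToyRegistry`, `ToyFlash`, and the alias `flash_attn2` are hypothetical, and the real registry's internals may differ.

```python
from typing import Dict, List, Tuple, Union


class ToyRegistry:
    """Hypothetical name-to-class registry, not EasyDeL's implementation."""

    _impls: Dict[str, type] = {}

    @classmethod
    def register(cls, impl_cls: type) -> type:
        names: Union[str, Tuple[str, ...]] = impl_cls.get_impl_name()
        if isinstance(names, str):
            names = (names,)
        for name in names:  # one entry per alias
            cls._impls[name] = impl_cls
        return impl_cls  # returning the class keeps it usable as a decorator

    @classmethod
    def get(cls, impl_name: str) -> type:
        if impl_name not in cls._impls:
            raise ValueError(f"No implementation registered as {impl_name!r}")
        return cls._impls[impl_name]

    @classmethod
    def create(cls, impl_name: str, metadata):
        # Look up the class, then construct it with the metadata.
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> List[str]:
        return list(cls._impls)


@ToyRegistry.register
class ToyFlash:
    def __init__(self, metadata):
        self.metadata = metadata

    @classmethod
    def get_impl_name(cls):
        return ("flash", "flash_attn2")


impl = ToyRegistry.create("flash", metadata={"softmax_scale": 0.125})
print(type(impl).__name__)                 # ToyFlash
print(ToyRegistry.list_implementations())  # ['flash', 'flash_attn2']
```

Registering every alias under the same class means a lookup by any of the names returns the same implementation, and an unknown name fails early with a ValueError, matching the Raises entries above.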