easydel.layers.attention_operator._attention_impl

class easydel.layers.attention_operator._attention_impl.AttentionImpl(metadata: AttentionMetadata)

Bases: BaseOperation

Abstract base class for specific attention implementations.

Inherits from BaseOperation to use its backend-specific dispatching. Subclasses must implement the core attention logic (forward_native) and may provide optimized versions for TPU (forward_tpu), GPU (forward_gpu), and other backends. They must also declare their name (get_impl_name) and associated metadata (get_impl_metadata).

Provides common helper methods for attention processing, such as mask manipulation, head repetition for GQA/MQA (repeat_kv_heads), and runtime-mode detection (get_mode).

create_stable_sharding(state_ps: Optional[PartitionSpec] = None, preserved_indices: List[int] = None, clone_ps: Optional[PartitionSpec] = None, dep: Optional[Union[PartitionSpec, bool]] = True, tensor: Optional[Array] = None) → Optional[PartitionSpec]

Helper to create a PartitionSpec that keeps the partitioning of only selected axes.

It can be used to give intermediate tensors or states a compatible sharding: every axis not listed in preserved_indices is replicated.

Parameters
  • state_ps – The base PartitionSpec to modify.

  • preserved_indices – A list of dimension indices whose partitioning should be kept from state_ps (or clone_ps if provided). Other dimensions will be set to None (replicated). If None, state_ps is returned.

  • clone_ps – An optional PartitionSpec to copy axis names from for the preserved indices, instead of using state_ps.

  • dep – A dependency flag or PartitionSpec. If None, the method returns None. Defaults to True.

Returns

A new PartitionSpec with only specified axes partitioned, or None based on dep. Returns state_ps directly if preserved_indices is None.
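The preserve-or-replicate rule above can be sketched as a standalone function. This is a minimal illustration that covers only the documented behaviour of state_ps, preserved_indices, clone_ps, and dep. stable_sharding_sketch is a hypothetical name, not part of EasyDeL. The tensor argument is ignored, and returning None when state_ps is None is an assumption.

```python
from typing import List, Optional, Union

from jax.sharding import PartitionSpec


def stable_sharding_sketch(
    state_ps: Optional[PartitionSpec] = None,
    preserved_indices: Optional[List[int]] = None,
    clone_ps: Optional[PartitionSpec] = None,
    dep: Optional[Union[PartitionSpec, bool]] = True,
) -> Optional[PartitionSpec]:
    """Hypothetical re-implementation of the documented rules only."""
    if dep is None:
        return None  # documented: dep=None yields None
    if preserved_indices is None:
        return state_ps  # documented: no indices means state_ps is returned as-is
    if state_ps is None:
        return None  # assumption: nothing to derive a spec from
    source = clone_ps if clone_ps is not None else state_ps
    axes = [None] * len(state_ps)  # every axis starts out replicated
    for idx in preserved_indices:
        axes[idx] = source[idx]  # keep the partitioning of preserved axes
    return PartitionSpec(*axes)


spec = PartitionSpec("dp", "fsdp", "tp", None)  # hypothetical mesh axis names
print(stable_sharding_sketch(spec, [0, 2]))  # axes 1 and 3 become replicated
```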

current_backend() → Literal['tpu', 'gpu', 'cpu']

Returns the current JAX default backend as a lowercase string literal.

Returns

“tpu”, “gpu”, or “cpu”.
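The sketch below assumes the method is a thin wrapper over jax.default_backend(), which returns the platform name of JAX's default backend. current_backend_sketch is a hypothetical stand-in, not EasyDeL's code.

```python
from typing import Literal

import jax


def current_backend_sketch() -> Literal["tpu", "gpu", "cpu"]:
    # Assumption: lowercasing the default backend's platform name is enough.
    return jax.default_backend().lower()


print(current_backend_sketch())
```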

abstract get_impl_metadata() → AttentionMetadata

Returns the AttentionMetadata associated with this implementation instance.

Returns

The AttentionMetadata instance passed during initialization.

abstract classmethod get_impl_name() → Union[str, Tuple[str, ...]]

Returns the unique name(s) identifying this attention implementation.

Used by the AttentionRegistry. Can return a single string or a tuple/list of strings if the implementation has multiple aliases.

Returns

A string or tuple/list of strings representing the implementation name(s).

get_mode(q: Array, BTHD: bool = True) → Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']

Determines the runtime mode from the shape of the query tensor.

A query sequence length of 1 is treated as autoregressive generation ('__autoregressive__').

Parameters
  • q – The query tensor.

  • BTHD – Boolean indicating tensor layout (True for B, T, H, D; False for B, H, T, D).

Returns

One of '__autoregressive__', '__prefill__', '__train__', or '__insert__'.
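The one documented rule, "sequence length 1 means generation", depends on which axis holds the sequence. That axis depends on the BTHD flag. The sketch below shows only this check. is_generation_step is a hypothetical helper, and it does not try to reproduce how get_mode distinguishes '__prefill__', '__train__', and '__insert__'.

```python
import jax.numpy as jnp


def is_generation_step(q, BTHD: bool = True) -> bool:
    # The sequence (time) axis is 1 for (B, T, H, D) and 2 for (B, H, T, D).
    seq_axis = 1 if BTHD else 2
    return q.shape[seq_axis] == 1


decode_q = jnp.zeros((2, 1, 8, 64))     # one new token per sequence, BTHD layout
prefill_q = jnp.zeros((2, 128, 8, 64))  # a full prompt, BTHD layout
print(is_generation_step(decode_q), is_generation_step(prefill_q))  # True False
```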

static repeat_kv_heads(k: Array, v: Array, num_reps: int) → Tuple[Array, Array]

Repeats Key and Value heads for Grouped Query Attention (GQA) or Multi-Query Attention (MQA).

Expands the head dimension of K and V tensors to match the number of query heads.

Parameters
  • k – Key tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).

  • v – Value tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).

  • num_reps – The number of times to repeat each KV head (num_q_heads // num_kv_heads).

Returns

A tuple (k_repeated, v_repeated) with shapes (batch, seq_len, num_q_heads, head_dim).
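A minimal sketch of the head expansion, using jnp.repeat on the head axis of (batch, seq_len, num_kv_heads, head_dim) tensors. It assumes the common GQA grouping in which each KV head is copied num_reps times in a row, so query head h reads KV head h // num_reps. Whether EasyDeL repeats or tiles the heads is not stated here. repeat_kv_heads_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def repeat_kv_heads_sketch(k, v, num_reps: int):
    # k, v: (batch, seq_len, num_kv_heads, head_dim)
    # jnp.repeat places the num_reps copies of each head next to each other.
    k = jnp.repeat(k, num_reps, axis=2)
    v = jnp.repeat(v, num_reps, axis=2)
    return k, v  # (batch, seq_len, num_kv_heads * num_reps, head_dim)


k = jnp.ones((1, 16, 2, 64))
v = jnp.ones((1, 16, 2, 64))
k_rep, v_rep = repeat_kv_heads_sketch(k, v, num_reps=4)  # 8 query heads / 2 KV heads
print(k_rep.shape)  # (1, 16, 8, 64)
```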

class easydel.layers.attention_operator._attention_impl.AttentionMetadata(runtime_dtype: Union[str, type[Any], dtype, SupportsDType], runtime_softmax_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None, sequence_axis_name: str = <eformer.common_types._Empty object>, mesh: Optional[Mesh] = <eformer.common_types._Empty object>, platform: EasyDeLPlatforms = <eformer.common_types._Empty object>, backend: EasyDeLBackends = <eformer.common_types._Empty object>, partition_axis: PartitionAxis = <eformer.common_types._Empty object>, partition_manager: PartitionManager = <eformer.common_types._Empty object>, base_config: Optional[EasyDeLBaseConfig] = None, scan_ring_attention: bool = <eformer.common_types._Empty object>, softmax_scale: float = <eformer.common_types._Empty object>, dropout_prob: float = <eformer.common_types._Empty object>, blocksize_q: int = <eformer.common_types._Empty object>, blocksize_k: int = <eformer.common_types._Empty object>, blocksize_b: int = <eformer.common_types._Empty object>)

Bases: object

Holds configuration, context, and metadata for attention operations.

This class centralizes various parameters needed by different attention implementations, facilitating consistent behavior and configuration. It handles default values and can be initialized from an EasyDeLBaseConfig.

runtime_dtype

The primary JAX dtype for computations (e.g., q, k, v).

Type

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]

runtime_softmax_dtype

Optional JAX dtype for the softmax computation, allowing for higher precision if needed (e.g., float32).

Type

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]]

sequence_axis_name

The name of the mesh axis used for the sequence dimension in JAX parallelism (one of the mesh axis_names, as used with pjit).

Type

str

mesh

The JAX device mesh for distributed computation. Must be provided or inferred from context.

Type

Optional[jax._src.mesh.Mesh]

platform

The target hardware platform (e.g., TPU, GPU).

Type

easydel.infra.etils.EasyDeLPlatforms

backend

The specific JAX backend being used (e.g., TPU, CUDA, ROCm).

Type

easydel.infra.etils.EasyDeLBackends

partition_axis

Configuration for partitioning axes in distributed settings, provided by eformer.escale.

Type

eformer.escale.partition.manager.PartitionAxis

base_config

An optional reference to the base model configuration object for sourcing default values.

Type

Optional[easydel.infra.base_config.EasyDeLBaseConfig]

scan_ring_attention

Boolean flag indicating whether to use ring attention via jax.lax.scan.

Type

bool

softmax_scale

The scaling factor applied to the query-key scores before the softmax, typically 1 / sqrt(head_dim).

Type

float
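The sketch below shows how softmax_scale and runtime_softmax_dtype typically enter a scaled dot-product attention. The scores are multiplied by the scale, and the softmax runs in a higher-precision dtype before the result is cast back to the input dtype. This is a generic reference computation for BTHD tensors with dropout omitted, not EasyDeL's implementation. scaled_attention_sketch is a hypothetical name.

```python
import math

import jax
import jax.numpy as jnp


def scaled_attention_sketch(q, k, v, softmax_scale=None, softmax_dtype=jnp.float32):
    # q: (batch, q_len, heads, head_dim); k, v: (batch, kv_len, heads, head_dim)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])  # the usual default
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
    # Softmax in softmax_dtype (e.g. float32), then cast back to the input dtype.
    weights = jax.nn.softmax(scores.astype(softmax_dtype), axis=-1).astype(q.dtype)
    out = jnp.einsum("bhqk,bkhd->bqhd", weights, v)  # (batch, q_len, heads, head_dim)
    return out, weights  # weights: (batch, heads, q_len, kv_len)


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (2, 5, 4, 8), dtype=jnp.bfloat16)
k = jax.random.normal(key_k, (2, 7, 4, 8), dtype=jnp.bfloat16)
v = jax.random.normal(key_v, (2, 7, 4, 8), dtype=jnp.bfloat16)
out, weights = scaled_attention_sketch(q, k, v)
print(out.shape, weights.shape, out.dtype)
```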

dropout_prob

The dropout probability applied to attention weights.

Type

float

blocksize_q

Block size for the query sequence dimension in blockwise attention.

Type

int

blocksize_k

Block size for the key/value sequence dimension in blockwise attention.

Type

int

blocksize_b

Block size for the batch dimension in blockwise attention (often 1).

Type

int
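To show what a key/value block size means, the sketch below walks the key/value sequence blocksize_k rows at a time. It uses the online-softmax technique: running row maxima and denominators, as in FlashAttention-style kernels. The result matches full attention without ever materializing the whole score matrix. It is a single-head illustration under that assumption, not EasyDeL's kernel. blockwise_attention_sketch is a hypothetical name. blocksize_q and blocksize_b would tile the query and batch axes in a similar spirit.

```python
import jax
import jax.numpy as jnp


def blockwise_attention_sketch(q, k, v, blocksize_k: int, softmax_scale: float):
    # Single head: q (q_len, d), k and v (kv_len, d). A shorter final block
    # is handled naturally by slicing.
    m = jnp.full((q.shape[0],), -jnp.inf)       # running max of scores per query row
    l = jnp.zeros((q.shape[0],))                # running softmax denominator
    acc = jnp.zeros((q.shape[0], v.shape[-1]))  # running unnormalized output
    for start in range(0, k.shape[0], blocksize_k):
        k_blk = k[start:start + blocksize_k]
        v_blk = v[start:start + blocksize_k]
        s = (q @ k_blk.T) * softmax_scale       # (q_len, block) scores
        m_new = jnp.maximum(m, s.max(axis=-1))
        correction = jnp.exp(m - m_new)         # rescales statistics of earlier blocks
        p = jnp.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ v_blk
        m = m_new
    return acc / l[:, None]


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (6, 16))
k = jax.random.normal(kk, (10, 16))
v = jax.random.normal(kv, (10, 16))
scale = 16 ** -0.5
blockwise = blockwise_attention_sketch(q, k, v, blocksize_k=4, softmax_scale=scale)
reference = jax.nn.softmax((q @ k.T) * scale, axis=-1) @ v
print(jnp.max(jnp.abs(blockwise - reference)))  # only float32 round-off remains
```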

backend: EasyDeLBackends = <eformer.common_types._Empty object>
base_config: Optional[EasyDeLBaseConfig] = None
blocksize_b: int = <eformer.common_types._Empty object>
blocksize_k: int = <eformer.common_types._Empty object>
blocksize_q: int = <eformer.common_types._Empty object>
dropout_prob: float = <eformer.common_types._Empty object>
classmethod from_config(config: EasyDeLBaseConfig, softmax_scale: float, dropout_prob: float = 0.0) → AttentionMetadata

Factory method to create AttentionMetadata from an EasyDeLBaseConfig.

Parameters
  • config – The base configuration object (e.g., model config).

  • softmax_scale – The attention softmax scaling factor, usually computed from the head dimension (e.g., 1 / sqrt(head_dim)).

  • dropout_prob – The attention dropout probability. Defaults to 0.0.

Returns

An initialized AttentionMetadata instance.

classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

get_shardings(mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], BTHD: bool = True, qkv_mni_sharding: bool = False)

Generates JAX PartitionSpecs for the attention tensors based on the runtime mode.

Parameters
  • mode – The current runtime mode: '__autoregressive__', '__prefill__', '__train__', or '__insert__'.

  • BTHD – Boolean indicating tensor layout. True for (Batch, Time, Head, Dim), False for (Batch, Head, Time, Dim).

Returns

A tuple of PartitionSpecs for (query, key, value, bias, mask, attention_output).
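The BTHD flag decides where the time and head axes sit in each spec. The sketch below shows only that swap for a 4-D query/key/value tensor. It is hypothetical: the mesh axis names ("dp", "fsdp", "sp", "tp") are placeholders, and leaving head_dim replicated is an assumption. The real method takes its axis names from partition_axis and also varies the specs by mode, which is not modeled here.

```python
from jax.sharding import PartitionSpec


def qkv_spec_sketch(batch_axis, seq_axis, head_axis, BTHD: bool = True) -> PartitionSpec:
    # head_dim (the last axis) is left replicated (None) in this sketch.
    if BTHD:
        return PartitionSpec(batch_axis, seq_axis, head_axis, None)  # (B, T, H, D)
    return PartitionSpec(batch_axis, head_axis, seq_axis, None)      # (B, H, T, D)


print(qkv_spec_sketch(("dp", "fsdp"), "sp", "tp"))
print(qkv_spec_sketch(("dp", "fsdp"), "sp", "tp", BTHD=False))
```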

mesh: Optional[Mesh] = <eformer.common_types._Empty object>
partition_axis: PartitionAxis = <eformer.common_types._Empty object>
partition_manager: PartitionManager = <eformer.common_types._Empty object>
platform: EasyDeLPlatforms = <eformer.common_types._Empty object>
replace(**kwargs)

Creates a new instance with specified fields replaced.

runtime_dtype: Union[str, type[Any], dtype, SupportsDType]
runtime_softmax_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None
scan_ring_attention: bool = <eformer.common_types._Empty object>
sequence_axis_name: str = <eformer.common_types._Empty object>
set_attrs_carefully(attr_name: str, default: Optional[Any], pickup_name: Optional[str] = None, use_base_config: bool = True)

Internal helper to set an attribute if it’s not already set (or is Ellipsis).

Optionally retrieves the value from self.base_config using pickup_name (or attr_name if pickup_name is None).

Parameters
  • attr_name – The name of the attribute to set on self.

  • default – The default value to use if not found in base_config or if use_base_config is False.

  • pickup_name – The name of the attribute to look for in base_config. Defaults to attr_name.

  • use_base_config – Whether to attempt retrieving the value from base_config.
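The fallback order described above can be sketched as follows. An explicitly set value wins. Otherwise the value is read from base_config (under pickup_name if given), and otherwise default is used. This is a standalone, hypothetical sketch: MetadataSketch and _EMPTY stand in for AttentionMetadata and eformer's _Empty sentinel. It assumes that "not already set" means the sentinel or Ellipsis, and that a field missing from base_config falls back to default.

```python
from types import SimpleNamespace
from typing import Any, Optional

_EMPTY = object()  # hypothetical stand-in for eformer's _Empty sentinel


class MetadataSketch:
    def __init__(self, base_config=None, softmax_scale=_EMPTY, blocksize_q=_EMPTY):
        self.base_config = base_config
        self.softmax_scale = softmax_scale
        self.blocksize_q = blocksize_q

    def set_attrs_carefully(
        self,
        attr_name: str,
        default: Optional[Any],
        pickup_name: Optional[str] = None,
        use_base_config: bool = True,
    ) -> None:
        current = getattr(self, attr_name, _EMPTY)
        if current is not _EMPTY and current is not Ellipsis:
            return  # an explicitly set value wins; leave it untouched
        value = default
        if use_base_config and self.base_config is not None:
            # Assumption: fall back to `default` if the config lacks the field.
            value = getattr(self.base_config, pickup_name or attr_name, default)
        setattr(self, attr_name, value)


cfg = SimpleNamespace(blocksize_q=512, attn_scale=0.125)  # hypothetical config fields
meta = MetadataSketch(base_config=cfg, softmax_scale=0.5)
meta.set_attrs_carefully("softmax_scale", 1.0)  # already set: stays 0.5
meta.set_attrs_carefully("blocksize_q", 128)    # picked up from cfg: 512
print(meta.softmax_scale, meta.blocksize_q)
```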

softmax_scale: float = <eformer.common_types._Empty object>
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.layers.attention_operator._attention_impl.AttentionOutput(attention_weights: Optional[Array] = None, attention_outputs: Optional[Array] = None)

Bases: object

Container for the outputs of an attention operation.

attention_weights

The attention probabilities, typically of shape (batch, num_heads, query_seq_len, key_value_seq_len). Optional.

Type

Optional[jax.Array]

attention_outputs

The final weighted sum of values, typically of shape (batch, query_seq_len, num_heads, head_dim) or (batch, num_heads, query_seq_len, head_dim). Optional.

Type

Optional[jax.Array]

attention_outputs: Optional[Array] = None
attention_weights: Optional[Array] = None
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.layers.attention_operator._attention_impl.AttentionRegistry

Bases: object

Registry for discovering and managing different AttentionImpl classes.

Allows registering implementations using a decorator and retrieving or instantiating them by name.

classmethod create(impl_name: str, metadata: AttentionMetadata) → AttentionImpl

Creates an instance of an attention implementation by name.

Retrieves the class associated with impl_name and initializes it with the provided metadata.

Parameters
  • impl_name – The name of the implementation to instantiate.

  • metadata – The AttentionMetadata to pass to the implementation’s constructor.

Returns

An initialized instance of the requested AttentionImpl subclass.

Raises

ValueError – If no implementation is registered with impl_name.

classmethod get(impl_name: str) → Type[AttentionImpl]

Retrieves an attention implementation class by its registered name.

Parameters

impl_name – The name of the implementation to retrieve.

Returns

The AttentionImpl subclass registered under the given name.

Raises

ValueError – If no implementation is registered with that name.

classmethod list_implementations() → List[str]

Returns a list of names of all registered attention implementations.

Returns

A list of strings, where each string is a registered implementation name.

classmethod register(impl_cls: Type[ICa]) → Type[ICa]

Class method decorator to register an AttentionImpl subclass.

The implementation is registered under the name(s) returned by its get_impl_name() class method.

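Example:

```python
from easydel.layers.attention_operator._attention_impl import (
    AttentionImpl,
    AttentionRegistry,
)


@AttentionRegistry.register
class FlashAttentionImpl(AttentionImpl):
    @classmethod
    def get_impl_name(cls) -> str:
        return "flash"

    # … implementation …
```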

Parameters

impl_cls – The AttentionImpl subclass to register.

Returns

The registered class itself.
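The registry behaviour documented above fits in a few lines. register files a class under every name returned by get_impl_name(). get raises ValueError for unknown names, create instantiates with the given metadata, and list_implementations returns the registered names. The sketch below is a simplified standalone re-implementation for illustration, not EasyDeL's source. ImplBase, RegistrySketch, and VanillaImpl are hypothetical, and silently overwriting a duplicate name is an assumption.

```python
from typing import Any, Dict, List, Tuple, Type, Union


class ImplBase:
    """Hypothetical stand-in for AttentionImpl."""

    def __init__(self, metadata: Any):
        self.metadata = metadata

    @classmethod
    def get_impl_name(cls) -> Union[str, Tuple[str, ...]]:
        raise NotImplementedError


class RegistrySketch:
    _registry: Dict[str, Type[ImplBase]] = {}

    @classmethod
    def register(cls, impl_cls: Type[ImplBase]) -> Type[ImplBase]:
        names = impl_cls.get_impl_name()
        if isinstance(names, str):
            names = (names,)  # normalize a single name to a tuple of aliases
        for name in names:
            cls._registry[name] = impl_cls  # assumption: duplicates overwrite
        return impl_cls  # returning the class lets register work as a decorator

    @classmethod
    def get(cls, impl_name: str) -> Type[ImplBase]:
        if impl_name not in cls._registry:
            raise ValueError(f"No attention implementation registered as {impl_name!r}")
        return cls._registry[impl_name]

    @classmethod
    def create(cls, impl_name: str, metadata: Any) -> ImplBase:
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> List[str]:
        return list(cls._registry)


@RegistrySketch.register
class VanillaImpl(ImplBase):
    @classmethod
    def get_impl_name(cls) -> Tuple[str, ...]:
        return ("vanilla", "default")


print(RegistrySketch.list_implementations())  # ['vanilla', 'default']
```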