easydel.layers.operations._operation_meta


class easydel.layers.operations._operation_meta.AttnShardingRules(query3d: jax.sharding.PartitionSpec, query: jax.sharding.PartitionSpec, key: jax.sharding.PartitionSpec, value: jax.sharding.PartitionSpec, bias: jax.sharding.PartitionSpec, mask: jax.sharding.PartitionSpec, output: jax.sharding.PartitionSpec, q_segment_ids: jax.sharding.PartitionSpec, kv_segment_ids: jax.sharding.PartitionSpec, softmax_aux: jax.sharding.PartitionSpec | None)[source]#

Bases: NamedTuple

Named tuple containing JAX PartitionSpecs for all attention tensors.

query3d#

Sharding for a 3D query tensor of shape [b, h, d].

Type

jax.sharding.PartitionSpec

query#

Sharding for query tensor.

Type

jax.sharding.PartitionSpec

key#

Sharding for key tensor.

Type

jax.sharding.PartitionSpec

value#

Sharding for value tensor.

Type

jax.sharding.PartitionSpec

bias#

Sharding for attention bias tensor.

Type

jax.sharding.PartitionSpec

mask#

Sharding for attention mask tensor.

Type

jax.sharding.PartitionSpec

output#

Sharding for attention output tensor.

Type

jax.sharding.PartitionSpec

q_segment_ids#

Sharding for query segment IDs (for packed sequences).

Type

jax.sharding.PartitionSpec

kv_segment_ids#

Sharding for key/value segment IDs (for packed sequences).

Type

jax.sharding.PartitionSpec

softmax_aux#

Optional sharding for 2D softmax auxiliary outputs (e.g., LSE, max).

Type

jax.sharding.PartitionSpec | None

bias: PartitionSpec#

Alias for field number 4

key: PartitionSpec#

Alias for field number 2

kv_segment_ids: PartitionSpec#

Alias for field number 8

mask: PartitionSpec#

Alias for field number 5

output: PartitionSpec#

Alias for field number 6

q_segment_ids: PartitionSpec#

Alias for field number 7

query: PartitionSpec#

Alias for field number 1

query3d: PartitionSpec#

Alias for field number 0

softmax_aux: jax.sharding.PartitionSpec | None#

Alias for field number 9

value: PartitionSpec#

Alias for field number 3
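Because AttnShardingRules is a NamedTuple, each field can be read by name or by position, which is what the "Alias for field number N" entries describe. The sketch below uses a hypothetical stand-in class, ShardingRulesSketch, with the same ten field names. Plain tuples take the place of jax.sharding.PartitionSpec so that it runs without JAX. The axis names ("dp", "sp", "tp") are illustrative assumptions, not EasyDeL defaults.

```python
from typing import NamedTuple, Optional, Tuple

# Stand-in for jax.sharding.PartitionSpec: a tuple of mesh-axis names (or None).
Spec = Tuple[Optional[str], ...]


class ShardingRulesSketch(NamedTuple):
    """Hypothetical mirror of the AttnShardingRules field order."""

    query3d: Spec
    query: Spec
    key: Spec
    value: Spec
    bias: Spec
    mask: Spec
    output: Spec
    q_segment_ids: Spec
    kv_segment_ids: Spec
    softmax_aux: Optional[Spec]


rules = ShardingRulesSketch(
    query3d=("dp", "tp", None),
    query=("dp", "sp", "tp", None),
    key=("dp", "sp", "tp", None),
    value=("dp", "sp", "tp", None),
    bias=("dp", "tp", None, None),
    mask=("dp", None, None, None),
    output=("dp", "sp", "tp", None),
    q_segment_ids=("dp", "sp"),
    kv_segment_ids=("dp", "sp"),
    softmax_aux=None,
)

# Positional access matches the documented field numbers.
assert rules[0] == rules.query3d
assert rules[9] is rules.softmax_aux

# NamedTuples are immutable; _replace returns a modified copy.
with_aux = rules._replace(softmax_aux=("dp", "tp"))
```

Using _replace rather than mutation leaves the original rules object intact, which is convenient when one set of specs is shared across several call sites.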

class easydel.layers.operations._operation_meta.OperationMetadata(runtime_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType], runtime_softmax_dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = None, sequence_axis_name: str = <eformer.common_types._Empty object>, platform: enum.Enum | str = <eformer.common_types._Empty object>, backend: enum.Enum | str = <eformer.common_types._Empty object>, partition_axis: eformer.escale.partition.manager.PartitionAxis = <eformer.common_types._Empty object>, partition_manager: eformer.escale.partition.manager.PartitionManager = <eformer.common_types._Empty object>, base_config: object | None = None, operation_configs: dict[str, object] | None = None, _stored_mesh: jax._src.mesh.Mesh | None = <eformer.common_types._Empty object>)[source]#

Bases: object

Holds configuration, context, and metadata for attention operations.

This class centralizes various parameters needed by different attention implementations, facilitating consistent behavior and configuration. It handles default values and can be initialized from an EasyDeLBaseConfig.

runtime_dtype#

The primary JAX dtype for computations (e.g., for the q, k, and v tensors).

Type

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]

runtime_softmax_dtype#

Optional JAX dtype for the softmax computation, allowing for higher precision if needed (e.g., float32).

Type

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]]

sequence_axis_name#

The name used for the sequence axis in JAX parallelism (sharding_axis_names for pjit).

Type

str

mesh#

The JAX device mesh for distributed computation. Must be provided or inferred from context.

platform#

The target hardware platform (e.g., TPU, GPU).

Type

enum.Enum | str

backend#

The specific JAX backend being used (e.g., TPU, CUDA, ROCM).

Type

enum.Enum | str

partition_axis#

Configuration for partitioning axes in distributed settings (a PartitionAxis from eformer.escale).

Type

eformer.escale.partition.manager.PartitionAxis

base_config#

An optional reference to the base model configuration object for sourcing default values.

Type

object | None

scan_ring_attention#

Boolean flag indicating whether to use ring attention via jax.lax.scan.

softmax_scale#

The scaling factor applied before the softmax operation. Often 1 / sqrt(head_dim).
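As a worked illustration of the common default, the sketch below computes 1 / sqrt(head_dim) and applies it to raw attention scores before a softmax. The helper names default_softmax_scale and scaled_softmax are hypothetical and use only the standard library; they are not part of EasyDeL.

```python
import math


def default_softmax_scale(head_dim: int) -> float:
    """Common default: 1 / sqrt(head_dim)."""
    return 1.0 / math.sqrt(head_dim)


def scaled_softmax(scores: list[float], scale: float) -> list[float]:
    """Scale raw scores, then apply a numerically stable softmax."""
    scaled = [s * scale for s in scores]
    peak = max(scaled)  # subtract the max to avoid overflow in exp
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


scale = default_softmax_scale(64)  # head_dim = 64 gives 1 / 8
weights = scaled_softmax([8.0, 16.0, 24.0], scale)
```

Without the scale, the gaps between scores grow with head_dim and the softmax tends toward a one-hot distribution; the scale keeps the score magnitudes roughly independent of head size.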

dropout_prob#

The dropout probability applied to attention weights.

blocksize_q#

Block size for the query sequence dimension in blockwise attention.

blocksize_k#

Block size for the key/value sequence dimension in blockwise attention.

blocksize_b#

Block size for the batch dimension in blockwise attention (often 1).
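To illustrate how block sizes partition the work, the sketch below enumerates the (start, stop) ranges that a blockwise computation would visit along the query and key/value sequence dimensions. block_ranges is a hypothetical helper and the sizes are example values; how EasyDeL's kernels actually tile their inputs is not specified here.

```python
def block_ranges(length: int, block: int) -> list[tuple[int, int]]:
    """Split range(length) into consecutive chunks of at most `block` items."""
    if block <= 0:
        raise ValueError("block size must be positive")
    return [(start, min(start + block, length)) for start in range(0, length, block)]


q_blocks = block_ranges(length=1024, block=512)  # query sequence, cf. blocksize_q
k_blocks = block_ranges(length=1024, block=256)  # key/value sequence, cf. blocksize_k

# In unmasked blockwise attention, each query block is paired with every key/value block.
tiles = [(q, k) for q in q_blocks for k in k_blocks]
```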

backend: enum.Enum | str = <eformer.common_types._Empty object>#
base_config: object | None = None#
classmethod from_config(config: object) OperationMetadata[source]#

Factory method to create OperationMetadata from an EasyDeLBaseConfig.

Parameters
  • config – The base configuration object (e.g., model config).

Returns

An initialized OperationMetadata instance.

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

get_operation_config(impl_name: str) object | None[source]#

Get ejkernel config for a specific operation by its registered name.

Parameters

impl_name – The operation implementation name (must match OperationRegistry). Valid names:
  • “flash_attn2”: Flash attention 2 implementation

  • “ring”: Ring attention

  • “blocksparse”: Block sparse attention

  • “ragged_page_attention_v2”: Ragged page attention v2

  • “ragged_page_attention_v3”: Ragged page attention v3

  • “sdpa”: Scaled dot product attention

  • “vanilla”: Vanilla attention

Returns

The operation config if set, otherwise None (which enables ejkernel autotune).

Example

>>> cfg = metadata.get_operation_config("flash_attn2")
>>> if cfg is not None:
...     # Use explicit config
...     flash_attention(..., cfg=cfg)
... else:
...     # Use autotune
...     flash_attention(..., cfg=None)
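The lookup-or-None behaviour described above can be sketched with a plain dictionary. MetadataSketch and FlashConfigSketch are hypothetical stand-ins, not EasyDeL or ejkernel classes. The only assumption taken from the text is that a missing entry yields None, which callers treat as a request for autotuning.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FlashConfigSketch:
    """Hypothetical kernel config; the field names are illustrative only."""

    block_q: int = 128
    block_k: int = 128


@dataclass
class MetadataSketch:
    """Stand-in holding per-operation configs keyed by registered name."""

    operation_configs: Optional[dict] = None

    def get_operation_config(self, impl_name: str) -> Optional[object]:
        # No dict at all, or no entry for this name: return None.
        if self.operation_configs is None:
            return None
        return self.operation_configs.get(impl_name)


metadata = MetadataSketch(operation_configs={"flash_attn2": FlashConfigSketch(block_q=256)})

explicit = metadata.get_operation_config("flash_attn2")  # explicit config
fallback = metadata.get_operation_config("ring")         # None, so the caller autotunes
```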
get_shardings(mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], layout: Literal['bthd', 'bhtd', 'thd'] = 'bthd', qkv_mni_sharding: bool = False, softmax_aux: jax.Array | None = None) AttnShardingRules[source]#

Generates JAX PartitionSpecs for attention tensors based on runtime mode.

Parameters
  • mode – Runtime mode used for partition resolution: one of “__autoregressive__”, “__prefill__”, “__train__”, or “__insert__”.

  • layout – Tensor layout format: “bthd” (batch, time, heads, dim), “bhtd” (batch, heads, time, dim), or “thd” (time, heads, dim).

  • qkv_mni_sharding – If True, use HEAD/HEAD_DIM for K/V instead of KV_HEAD/KV_HEAD_DIM. Useful for multi-head attention (MHA) vs grouped-query attention (GQA/MQA).

  • softmax_aux – Optional softmax auxiliary array. If given, create sharding for 2D softmax auxiliary outputs (e.g., log-sum-exp, max values) with shape [batch, num_heads].

Returns

Named tuple containing PartitionSpecs for all attention tensors:
  • query3d, query, key, value: Main attention tensors

  • bias: Attention bias tensor

  • mask: Attention mask tensor

  • output: Attention output tensor

  • q_segment_ids: Query segment IDs (for packed sequences)

  • kv_segment_ids: Key/value segment IDs (for packed sequences)

  • softmax_aux: Optional 2D softmax auxiliary output sharding

Return type

AttnShardingRules
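One way to picture the layout argument is as an ordering of logical axes. The sketch below uses a hypothetical helper, axes_for_layout, that maps each layout string to an ordered tuple of logical axis names. The names are illustrative assumptions, and the sketch does not attempt the mode-dependent resolution to mesh axes that the real method performs.

```python
# Logical axis behind each layout letter (names are illustrative assumptions).
_AXIS_BY_LETTER = {"b": "batch", "t": "time", "h": "heads", "d": "head_dim"}


def axes_for_layout(layout: str) -> tuple[str, ...]:
    """Return logical axis names in the order given by the layout string."""
    if layout not in ("bthd", "bhtd", "thd"):
        raise ValueError(f"unsupported layout: {layout!r}")
    return tuple(_AXIS_BY_LETTER[letter] for letter in layout)


bthd_axes = axes_for_layout("bthd")
bhtd_axes = axes_for_layout("bhtd")
thd_axes = axes_for_layout("thd")  # no batch axis in this layout
```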

property mesh: jax._src.mesh.Mesh | None#

Get current mesh from base_config if available, otherwise return stored mesh.

operation_configs: dict[str, object] | None = None#
partition_axis: PartitionAxis = <eformer.common_types._Empty object>#
partition_manager: PartitionManager = <eformer.common_types._Empty object>#
platform: enum.Enum | str = <eformer.common_types._Empty object>#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

runtime_dtype: Union[str, type[Any], dtype, SupportsDType]#
runtime_softmax_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None#
sequence_axis_name: str = <eformer.common_types._Empty object>#
set_attrs_carefully(attr_name: str, default: Any | None, pickup_name: str | None = None, use_base_config: bool = True) None[source]#

Internal helper to set an attribute if it’s not already set (or is Ellipsis).

Optionally retrieves the value from self.base_config using pickup_name (or attr_name if pickup_name is None).

Parameters
  • attr_name – The name of the attribute to set on self.

  • default – The default value to use if not found in base_config or if use_base_config is False.

  • pickup_name – The name of the attribute to look for in base_config. Defaults to attr_name.

  • use_base_config – Whether to attempt retrieving the value from base_config.
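The resolution order described above (keep an explicit value, else read from base_config, else fall back to the default) can be sketched as follows. _UNSET, HolderSketch, and the config attribute name attn_dropout are hypothetical; the actual sentinel and edge-case handling in EasyDeL may differ.

```python
from types import SimpleNamespace
from typing import Any, Optional

_UNSET = object()  # hypothetical sentinel meaning "no value supplied"


class HolderSketch:
    """Stand-in object with a base_config and one optional attribute."""

    def __init__(self, base_config: Any = None, dropout_prob: Any = _UNSET):
        self.base_config = base_config
        self.dropout_prob = dropout_prob

    def set_attrs_carefully(
        self,
        attr_name: str,
        default: Optional[Any],
        pickup_name: Optional[str] = None,
        use_base_config: bool = True,
    ) -> None:
        current = getattr(self, attr_name, _UNSET)
        if current is not _UNSET and current is not Ellipsis:
            return  # an explicit value wins
        value = default
        if use_base_config and self.base_config is not None:
            # Look up pickup_name (or attr_name) on the base config.
            value = getattr(self.base_config, pickup_name or attr_name, default)
        setattr(self, attr_name, value)


config = SimpleNamespace(attn_dropout=0.1)  # hypothetical config attribute

from_config = HolderSketch(base_config=config)
from_config.set_attrs_carefully("dropout_prob", 0.0, pickup_name="attn_dropout")

explicit = HolderSketch(base_config=config, dropout_prob=0.5)
explicit.set_attrs_carefully("dropout_prob", 0.0, pickup_name="attn_dropout")

no_config = HolderSketch()
no_config.set_attrs_carefully("dropout_prob", 0.0)
```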

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
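
A round trip through these helpers might look like the sketch below. PyTreeSketch is a hypothetical dataclass that imitates the four serialization methods and replace using the standard library. It is not the base class OperationMetadata inherits from, whose handling of dtypes, enums, and nested objects is not covered here.

```python
import dataclasses
import json
from typing import Any, Optional


@dataclasses.dataclass(frozen=True)
class PyTreeSketch:
    """Stand-in with string-valued fields so JSON can represent them directly."""

    runtime_dtype: str
    runtime_softmax_dtype: Optional[str] = None

    def to_dict(self) -> dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "PyTreeSketch":
        return cls(**data)

    def to_json(self, **kwargs: Any) -> str:
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "PyTreeSketch":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs: Any) -> "PyTreeSketch":
        return dataclasses.replace(self, **kwargs)


meta = PyTreeSketch(runtime_dtype="bfloat16")
upgraded = meta.replace(runtime_softmax_dtype="float32")  # meta is unchanged
restored = PyTreeSketch.from_json(upgraded.to_json(indent=2))
```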