easydel.layers.attention_operator._attention_impl

class easydel.layers.attention_operator._attention_impl.AttentionImpl(metadata: AttentionMetadata)[source]

Bases: ABC

create_stable_sharding(state_ps: Optional[PartitionSpec] = None, preserved_indices: List[int] = None, clone_ps: Optional[PartitionSpec] = None, dep: Optional[Union[PartitionSpec, bool]] = True) → Optional[PartitionSpec][source]
current_backend() → Literal['tpu', 'gpu', 'cpu'][source]
abstract forward_cpu(*args, **kwargs) → AttentionOutput[source]
abstract forward_cuda(*args, **kwargs) → AttentionOutput[source]
abstract forward_gpu(*args, **kwargs) → AttentionOutput[source]
abstract forward_native(*args, **kwargs) → AttentionOutput[source]
abstract forward_rocm(*args, **kwargs) → AttentionOutput[source]
abstract forward_tpu(*args, **kwargs) → AttentionOutput[source]
abstract get_impl_metadata() → AttentionMetadata[source]
abstract classmethod get_impl_name() → Union[str, Tuple[str]][source]
get_runtime_type(q: Array, BTHD: bool = True) → RuntimeType[source]
static repeat_kv_heads(k: Array, v: Array, num_reps: int) → Tuple[Array, Array][source]

Repeats the key (k) and value (v) heads `num_reps` times so that their head count matches the number of query (q) heads.

class easydel.layers.attention_operator._attention_impl.AttentionMetadata(runtime_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType], runtime_softmax_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, NoneType] = None, sequence_axis_name: str = Ellipsis, mesh: Optional[jax._src.mesh.Mesh] = Ellipsis, platform: easydel.infra.etils.EasyDeLPlatforms = Ellipsis, backend: easydel.infra.etils.EasyDeLBackends = Ellipsis, partition_axis: eformer.escale.partition.constraints.PartitionAxis = Ellipsis, base_config: Optional[easydel.infra.base_config.EasyDeLBaseConfig] = None, scan_ring_attention: bool = Ellipsis, softmax_scale: float = Ellipsis, dropout_prob: float = Ellipsis, blocksize_q: int = Ellipsis, blocksize_k: int = Ellipsis, blocksize_b: int = Ellipsis)[source]

Bases: object

backend: EasyDeLBackends = Ellipsis
base_config: Optional[EasyDeLBaseConfig] = None
blocksize_b: int = Ellipsis
blocksize_k: int = Ellipsis
blocksize_q: int = Ellipsis
dropout_prob: float = Ellipsis
classmethod from_config(config: EasyDeLBaseConfig, softmax_scale: float, dropout_prob: float = 0.0)[source]
get_partition_specs(mode: RuntimeType, BTHD: bool = True)[source]
mesh: Optional[Mesh] = Ellipsis
partition_axis: PartitionAxis = Ellipsis
platform: EasyDeLPlatforms = Ellipsis
runtime_dtype: Union[str, type[Any], dtype, SupportsDType]
runtime_softmax_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None
scan_ring_attention: bool = Ellipsis
sequence_axis_name: str = Ellipsis
set_attrs_carefully(attr_name: str, default: Optional[Any], pickup_name: Optional[str] = None, use_base_config: bool = True)[source]
softmax_scale: float = Ellipsis
class easydel.layers.attention_operator._attention_impl.AttentionOutput(attention_weights: Optional[jax.Array] = None, attention_outputs: Optional[jax.Array] = None)[source]

Bases: object

attention_outputs: Optional[Array] = None
attention_weights: Optional[Array] = None
class easydel.layers.attention_operator._attention_impl.AttentionRegistry[source]

Bases: object

Registry for attention implementations.

classmethod create(impl_name: str, metadata: AttentionMetadata) → AttentionImpl[source]

Create an instance of an attention implementation by name.

classmethod get(impl_name: str) → Type[AttentionImpl][source]

Get an attention implementation by name.

classmethod list_implementations() → List[str][source]

List all registered attention implementations.

classmethod register(impl_cls: Type[AttentionImpl]) → Type[AttentionImpl][source]

Decorator to register an attention implementation.

Example usage:

    @AttentionRegistry.register
    class CustomAttention(AttentionImpl):
        ...

class easydel.layers.attention_operator._attention_impl.RuntimeType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]

Bases: Enum

generation = 'generation'
normal = 'normal'