easydel.layers.attention_operator._attention_impl#
- class easydel.layers.attention_operator._attention_impl.AttentionImpl(metadata: AttentionMetadata)[source]#
Bases:
ABC
- create_stable_sharding(state_ps: Optional[PartitionSpec] = None, preserved_indices: List[int] = None, clone_ps: Optional[PartitionSpec] = None, dep: Optional[Union[PartitionSpec, bool]] = True) Optional[PartitionSpec][source]#
- abstract forward_cpu(*args, **kwargs) AttentionOutput[source]#
- abstract forward_cuda(*args, **kwargs) AttentionOutput[source]#
- abstract forward_gpu(*args, **kwargs) AttentionOutput[source]#
- abstract forward_native(*args, **kwargs) AttentionOutput[source]#
- abstract forward_rocm(*args, **kwargs) AttentionOutput[source]#
- abstract forward_tpu(*args, **kwargs) AttentionOutput[source]#
- abstract get_impl_metadata() AttentionMetadata[source]#
- get_runtime_type(q: Array, BTHD: bool = True) RuntimeType[source]#
- class easydel.layers.attention_operator._attention_impl.AttentionMetadata(runtime_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType], runtime_softmax_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, NoneType] = None, sequence_axis_name: str = Ellipsis, mesh: Optional[jax._src.mesh.Mesh] = Ellipsis, platform: easydel.infra.etils.EasyDeLPlatforms = Ellipsis, backend: easydel.infra.etils.EasyDeLBackends = Ellipsis, partition_axis: eformer.escale.partition.constraints.PartitionAxis = Ellipsis, base_config: Optional[easydel.infra.base_config.EasyDeLBaseConfig] = None, scan_ring_attention: bool = Ellipsis, softmax_scale: float = Ellipsis, dropout_prob: float = Ellipsis, blocksize_q: int = Ellipsis, blocksize_k: int = Ellipsis, blocksize_b: int = Ellipsis)[source]#
Bases:
object
- backend: EasyDeLBackends = Ellipsis#
- base_config: Optional[EasyDeLBaseConfig] = None#
- blocksize_b: int = Ellipsis#
- blocksize_k: int = Ellipsis#
- blocksize_q: int = Ellipsis#
- dropout_prob: float = Ellipsis#
- classmethod from_config(config: EasyDeLBaseConfig, softmax_scale: float, dropout_prob: float = 0.0)[source]#
- get_partition_specs(mode: RuntimeType, BTHD: bool = True)[source]#
- partition_axis: PartitionAxis = Ellipsis#
- platform: EasyDeLPlatforms = Ellipsis#
- scan_ring_attention: bool = Ellipsis#
- sequence_axis_name: str = Ellipsis#
- set_attrs_carefully(attr_name: str, default: Optional[Any], pickup_name: Optional[str] = None, use_base_config: bool = True)[source]#
- softmax_scale: float = Ellipsis#
- class easydel.layers.attention_operator._attention_impl.AttentionOutput(attention_weights: Optional[jax.Array] = None, attention_outputs: Optional[jax.Array] = None)[source]#
Bases:
object
- class easydel.layers.attention_operator._attention_impl.AttentionRegistry[source]#
Bases:
object
Registry for attention implementations.
- classmethod create(impl_name: str, metadata: AttentionMetadata) AttentionImpl[source]#
Create an instance of an attention implementation by name.
- classmethod get(impl_name: str) Type[AttentionImpl][source]#
Get an attention implementation by name.
- classmethod list_implementations() List[str][source]#
List all registered attention implementations.
- classmethod register(impl_cls: Type[AttentionImpl]) Type[AttentionImpl][source]#
Decorator to register an attention implementation.
Example usage:
@AttentionRegistry.register
class CustomAttention(AttentionImpl):
    …