easydel.layers.attention_operator.__init__
- class easydel.layers.attention_operator.__init__.AttentionImpl(metadata: AttentionMetadata)
Bases: ABC
- create_stable_sharding(state_ps: Optional[PartitionSpec] = None, preserved_indices: List[int] = None, clone_ps: Optional[PartitionSpec] = None, dep: Optional[Union[PartitionSpec, bool]] = True) → Optional[PartitionSpec]
- abstract forward_cpu(*args, **kwargs) → AttentionOutput
- abstract forward_cuda(*args, **kwargs) → AttentionOutput
- abstract forward_gpu(*args, **kwargs) → AttentionOutput
- abstract forward_native(*args, **kwargs) → AttentionOutput
- abstract forward_rocm(*args, **kwargs) → AttentionOutput
- abstract forward_tpu(*args, **kwargs) → AttentionOutput
- abstract get_impl_metadata() → AttentionMetadata
- get_runtime_type(q: Array, BTHD: bool = True) → RuntimeType
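The per-platform `forward_*` methods suggest that a concrete operator selects one of them at call time according to the active backend. The following is a minimal sketch of that pattern in plain Python. The `Backend` enum, `ToyAttentionImpl`, the `__call__` dispatcher, and the fallback to `forward_native` are all hypothetical illustrations and are not taken from EasyDeL; in particular, the listing above marks every `forward_*` method as abstract, while this sketch keeps only `forward_native` abstract for brevity.

```python
from abc import ABC, abstractmethod
from enum import Enum


class Backend(Enum):  # hypothetical stand-in for a backend selector
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


class ToyAttentionImpl(ABC):
    """Hypothetical base class with one forward_* method per backend."""

    def __init__(self, backend: Backend):
        self.backend = backend

    @abstractmethod
    def forward_native(self, *args, **kwargs):
        """Portable reference path that every subclass must provide."""

    # Assumption: platform paths fall back to the native path by default.
    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        # Look up the method whose suffix matches the configured backend.
        method = getattr(self, f"forward_{self.backend.value}")
        return method(*args, **kwargs)


class EchoAttention(ToyAttentionImpl):
    """Toy subclass that only reports which path handled the call."""

    def forward_native(self, q):
        return ("native", q)

    def forward_tpu(self, q):
        return ("tpu", q)


cpu_result = EchoAttention(Backend.CPU)(1)
tpu_result = EchoAttention(Backend.TPU)(1)
```

Here the CPU instance falls through to the native path, while the TPU instance uses its specialised override.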
- class easydel.layers.attention_operator.__init__.AttentionMetadata(runtime_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType], runtime_softmax_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, NoneType] = None, sequence_axis_name: str = Ellipsis, mesh: Optional[jax._src.mesh.Mesh] = Ellipsis, platform: easydel.infra.etils.EasyDeLPlatforms = Ellipsis, backend: easydel.infra.etils.EasyDeLBackends = Ellipsis, partition_axis: eformer.escale.partition.constraints.PartitionAxis = Ellipsis, base_config: Optional[easydel.infra.base_config.EasyDeLBaseConfig] = None, scan_ring_attention: bool = Ellipsis, softmax_scale: float = Ellipsis, dropout_prob: float = Ellipsis, blocksize_q: int = Ellipsis, blocksize_k: int = Ellipsis, blocksize_b: int = Ellipsis)
Bases: object
- backend: EasyDeLBackends = Ellipsis
- base_config: Optional[EasyDeLBaseConfig] = None
- blocksize_b: int = Ellipsis
- blocksize_k: int = Ellipsis
- blocksize_q: int = Ellipsis
- dropout_prob: float = Ellipsis
- classmethod from_config(config: EasyDeLBaseConfig, softmax_scale: float, dropout_prob: float = 0.0)
- get_partition_specs(mode: RuntimeType, BTHD: bool = True)
- partition_axis: PartitionAxis = Ellipsis
- platform: EasyDeLPlatforms = Ellipsis
- scan_ring_attention: bool = Ellipsis
- sequence_axis_name: str = Ellipsis
- set_attrs_carefully(attr_name: str, default: Optional[Any], pickup_name: Optional[str] = None, use_base_config: bool = True)
- softmax_scale: float = Ellipsis
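Most fields above default to `Ellipsis`, which reads as a "not set" sentinel, and the signature of `set_attrs_carefully` suggests that unset fields are filled from `base_config` or from a default. The sketch below illustrates that idea under those assumptions only. `ToyMetadata`, `_resolve`, the `SimpleNamespace` config, and the fallback values `0.0` and `128` are hypothetical and do not reproduce EasyDeL's behaviour.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Optional


@dataclass
class ToyMetadata:
    """Hypothetical metadata object; `...` marks a field the caller left unset."""

    softmax_scale: float = ...
    dropout_prob: float = ...
    blocksize_q: int = ...
    base_config: Optional[Any] = None

    def __post_init__(self):
        # Fallback values here are illustrative assumptions.
        self._resolve("dropout_prob", default=0.0)
        self._resolve("blocksize_q", default=128)

    def _resolve(self, attr_name: str, default: Any, pickup_name: Optional[str] = None):
        # A value passed explicitly by the caller always wins.
        if getattr(self, attr_name) is not ...:
            return
        # Otherwise read it from the base config, falling back to the default.
        source_name = pickup_name or attr_name
        value = getattr(self.base_config, source_name, default)
        setattr(self, attr_name, value)


config = SimpleNamespace(blocksize_q=512)  # hypothetical config object
meta = ToyMetadata(softmax_scale=0.125, base_config=config)
explicit = ToyMetadata(softmax_scale=0.125, blocksize_q=64, base_config=config)
```

In this sketch `meta.blocksize_q` is picked up from the config, `meta.dropout_prob` falls back to the default, and `explicit.blocksize_q` keeps the value the caller passed.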
- class easydel.layers.attention_operator.__init__.AttentionOutput(attention_weights: Optional[jax.Array] = None, attention_outputs: Optional[jax.Array] = None)
Bases: object
- class easydel.layers.attention_operator.__init__.AttentionRegistry
Bases: object
Registry for attention implementations.
- classmethod create(impl_name: str, metadata: AttentionMetadata) → AttentionImpl
Create an instance of an attention implementation by name.
- classmethod get(impl_name: str) → Type[AttentionImpl]
Get an attention implementation by name.
- classmethod list_implementations() → List[str]
List all registered attention implementations.
- classmethod register(impl_cls: Type[AttentionImpl]) → Type[AttentionImpl]
Decorator to register an attention implementation.
Example usage:
    @AttentionRegistry.register
    class CustomAttention(AttentionImpl):
        …
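The four registry methods above follow the usual decorator-based registry pattern, which can be sketched as below. This is a self-contained illustration and not EasyDeL's implementation: `ToyImpl`, `ToyRegistry`, the class-level `name` attribute used as the lookup key, and the plain `dict` used as metadata are all assumptions. How the real registry derives an implementation's name is not stated in this listing.

```python
from typing import Dict, List, Type


class ToyImpl:
    """Hypothetical base class for registered implementations."""

    name = "base"

    def __init__(self, metadata: dict):
        self.metadata = metadata


class ToyRegistry:
    """Hypothetical registry mapping a name to an implementation class."""

    _registry: Dict[str, Type[ToyImpl]] = {}

    @classmethod
    def register(cls, impl_cls: Type[ToyImpl]) -> Type[ToyImpl]:
        # Store the class under its name and return it unchanged,
        # so the method can be used as a class decorator.
        cls._registry[impl_cls.name] = impl_cls
        return impl_cls

    @classmethod
    def get(cls, impl_name: str) -> Type[ToyImpl]:
        if impl_name not in cls._registry:
            raise KeyError(f"unknown implementation: {impl_name!r}")
        return cls._registry[impl_name]

    @classmethod
    def create(cls, impl_name: str, metadata: dict) -> ToyImpl:
        # Look up the class, then instantiate it with the given metadata.
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> List[str]:
        return sorted(cls._registry)


@ToyRegistry.register
class CustomAttention(ToyImpl):
    name = "custom"


impl = ToyRegistry.create("custom", {"softmax_scale": 0.125})
```

Because `register` returns the class it receives, decorating a class leaves it usable under its own name as well as through the registry.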
- class easydel.layers.attention_operator.__init__.FlashAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
- forward_cpu(*args, **kwargs) → AttentionOutput
- forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
- forward_gpu(*args, **kwargs) → AttentionOutput
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
- forward_rocm(*args, **kwargs) → AttentionOutput
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
- get_impl_metadata() → AttentionMetadata
- class easydel.layers.attention_operator.__init__.RingAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
- forward_cpu(*args, **kwargs) → AttentionOutput
- forward_cuda(*args, **kwargs) → AttentionOutput
- forward_gpu(*args, **kwargs) → AttentionOutput
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) → AttentionOutput
- forward_rocm(*args, **kwargs) → AttentionOutput
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) → AttentionOutput
- get_impl_metadata() → AttentionMetadata
- class easydel.layers.attention_operator.__init__.ScaledDotProductAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
- forward_cpu(*args, **kwargs) → AttentionOutput
- forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
- forward_gpu(*args, **kwargs) → AttentionOutput
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False) → AttentionOutput
- forward_rocm(*args, **kwargs) → AttentionOutput
- forward_tpu(*args, **kwargs) → AttentionOutput
- get_impl_metadata() → AttentionMetadata
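For orientation, the standard computation that a scaled dot-product attention operator performs on `q`, `k`, and `v` with a `causal` flag can be written in a few lines of pure Python. This is a textbook single-head reference on nested lists, not EasyDeL's code: the function name `attention`, the `[seq, dim]` list layout, and passing `softmax_scale` as a direct argument are assumptions made for illustration, and `mask`, `bias`, and dropout are left out.

```python
import math
from typing import List

Matrix = List[List[float]]


def attention(q: Matrix, k: Matrix, v: Matrix, softmax_scale: float, causal: bool = False) -> Matrix:
    """Single-head attention over [seq, dim] lists: softmax(scale * q @ k^T) @ v."""
    out = []
    for i, q_row in enumerate(q):
        # Scaled dot product of query i with every key.
        scores = [softmax_scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        if causal:
            # Query i may only attend to keys at positions <= i
            # (assumes queries and keys cover the same positions).
            scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]
        # Numerically stable softmax over the key axis.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value rows.
        out.append([sum(w * v_row[d] for w, v_row in zip(weights, v)) for d in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
scale = 1 / math.sqrt(2)  # the common 1/sqrt(head_dim) choice
causal_out = attention(q, k, v, softmax_scale=scale, causal=True)
full_out = attention(q, k, v, softmax_scale=scale, causal=False)
```

With `causal=True` the first query can only see the first key, so `causal_out[0]` equals the first value row `[1.0, 2.0]`; without the causal restriction it becomes a blend of both value rows.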
- class easydel.layers.attention_operator.__init__.SplashAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
- forward_cpu(*args, **kwargs) → AttentionOutput
- forward_cuda(*args, **kwargs) → AttentionOutput
- forward_gpu(*args, **kwargs) → AttentionOutput
- forward_native(*args, **kwargs) → AttentionOutput
- forward_rocm(*args, **kwargs) → AttentionOutput
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, **ignore) → AttentionOutput
- get_impl_metadata() → AttentionMetadata
- class easydel.layers.attention_operator.__init__.VanillaAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
- forward_cpu(*args, **kwargs) → AttentionOutput
- forward_cuda(*args, **kwargs) → AttentionOutput
- forward_gpu(*args, **kwargs) → AttentionOutput
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, **ignore) → AttentionOutput
- forward_rocm(*args, **kwargs) → AttentionOutput
- forward_tpu(*args, **kwargs) → AttentionOutput
- get_impl_metadata() → AttentionMetadata