easydel.layers.operations._base_operation

Defines the base class for operations within EasyDeL that may have backend-specific implementations (CPU, GPU, TPU).

class easydel.layers.operations._base_operation.BaseOperation

Bases: ABC

Abstract Base Class for defining operations with potential backend-specific implementations.

This class provides a structure for defining a core operation (forward_native) while allowing optional, optimized implementations for the hardware backends supported by JAX (TPU, GPU via CUDA or ROCm, and CPU).

The __call__ method acts as a dispatcher: it detects the current JAX default backend and executes the corresponding forward_… method. If a backend-specific method (e.g., forward_tpu) is not overridden in a subclass, its default implementation is used, which ultimately calls forward_native (forward_cuda and forward_rocm first delegate to forward_gpu).

Subclasses MUST implement the forward_native method. They CAN optionally override forward_tpu, forward_gpu, forward_cpu, forward_rocm, forward_cuda, or forward_tt to provide backend-specific optimizations.
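
The dispatch pattern described above can be sketched in plain Python. This is a hypothetical stand-in, not EasyDeL's source: the names DispatchingOperation, TaggedDouble and the injected backend_fn are illustrative assumptions. The real class queries JAX for the default backend; here the backend query is passed in so the sketch runs without an accelerator (in a JAX program, jax.default_backend could be supplied).

```python
from abc import ABC, abstractmethod
from typing import Any, Callable


class DispatchingOperation(ABC):
    """Hypothetical sketch of BaseOperation-style dispatch (not EasyDeL's source)."""

    def __init__(self, backend_fn: Callable[[], str]) -> None:
        # Assumption: the backend query is injected so the sketch runs without JAX;
        # in a JAX program, jax.default_backend could be passed here.
        self._backend_fn = backend_fn

    @abstractmethod
    def forward_native(self, *args: Any, **kwargs: Any) -> Any:
        """Backend-agnostic implementation that every subclass must provide."""

    # Backend hooks fall back to the native implementation unless overridden.
    def forward_tpu(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward_native(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        backend = self._backend_fn().lower()
        handlers = {
            "tpu": self.forward_tpu,
            "gpu": self.forward_gpu,
            "cpu": self.forward_cpu,
        }
        # Assumption: an unrecognized backend string falls back to forward_native.
        handler = handlers.get(backend, self.forward_native)
        return handler(*args, **kwargs)


class TaggedDouble(DispatchingOperation):
    """Toy operation: doubles its input and reports which path ran."""

    def forward_native(self, x: float) -> tuple[str, float]:
        return ("native", 2 * x)

    def forward_tpu(self, x: float) -> tuple[str, float]:
        # Stand-in for a TPU-optimized kernel.
        return ("tpu", 2 * x)


print(TaggedDouble(lambda: "cpu")(3))  # ('native', 6)
print(TaggedDouble(lambda: "tpu")(3))  # ('tpu', 6)
```

In the sketch, the CPU call reaches forward_native through the default hook, while the TPU call picks up the override. The documented class also exposes forward_cuda, forward_rocm and forward_tt hooks, which this sketch omits for brevity.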

property EasyDeLBackends
current_backend() → Literal['tpu', 'gpu', 'cpu']

Returns the current JAX default backend as a lowercase string literal.

Returns

“tpu”, “gpu”, or “cpu”.

forward_cpu(*args, **kwargs) → Any

CPU-specific implementation of the operation.

Defaults to calling forward_native. Subclasses can override this for CPU-specific optimizations (though often forward_native is sufficient).

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for CPU.

forward_cuda(*args, **kwargs) → Any

CUDA (NVIDIA GPU)-specific implementation of the operation.

Defaults to calling forward_gpu. Subclasses can override this for optimizations specific to the CUDA platform, if necessary.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for CUDA GPUs.

forward_gpu(*args, **kwargs) → Any

Generic GPU-specific implementation of the operation.

Defaults to calling forward_native. forward_cuda and forward_rocm delegate to this method unless they are themselves overridden, so subclasses can override it once for optimizations shared by all GPUs.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for GPUs.
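
The default delegation chain for GPU platforms (forward_cuda and forward_rocm → forward_gpu → forward_native) can be illustrated with a small, self-contained sketch. The class names GpuChainSketch, NativeOnly and GpuTuned are hypothetical; the sketch only reproduces the defaults documented on this page, not EasyDeL's actual code.

```python
from abc import ABC, abstractmethod
from typing import Any


class GpuChainSketch(ABC):
    """Hypothetical reproduction of the documented GPU fallback defaults."""

    @abstractmethod
    def forward_native(self, *args: Any, **kwargs: Any) -> Any: ...

    def forward_gpu(self, *args: Any, **kwargs: Any) -> Any:
        # Generic GPU path: defaults to the native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args: Any, **kwargs: Any) -> Any:
        # CUDA path: defaults to the generic GPU path.
        return self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args: Any, **kwargs: Any) -> Any:
        # ROCm path: defaults to the generic GPU path.
        return self.forward_gpu(*args, **kwargs)


class NativeOnly(GpuChainSketch):
    def forward_native(self, x: int) -> str:
        return f"native({x})"


class GpuTuned(NativeOnly):
    # Overriding forward_gpu once covers both CUDA and ROCm.
    def forward_gpu(self, x: int) -> str:
        return f"gpu({x})"


print(NativeOnly().forward_cuda(1))  # native(1)
print(GpuTuned().forward_cuda(1))    # gpu(1)
print(GpuTuned().forward_rocm(2))    # gpu(2)
```

A subclass would still override forward_cuda or forward_rocm directly when a vendor-specific kernel is needed; otherwise a single forward_gpu override serves both platforms.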

abstract forward_native(*args, **kwargs) → Any

The core, backend-agnostic implementation of the operation.

This method MUST be implemented by any concrete subclass of BaseOperation. It serves as the default implementation if no backend-specific override is available or applicable.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation. Type depends on the specific operation.
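
Because forward_native is declared abstract, Python's abc machinery refuses to instantiate a subclass that omits it. The sketch below shows this with hypothetical classes (AbstractOpSketch, MissingNative, Square). Note that BaseOperation also declares get_impl_metadata and get_impl_name abstract, so a concrete EasyDeL subclass would need to implement those as well; the sketch leaves them out.

```python
from abc import ABC, abstractmethod
from typing import Any


class AbstractOpSketch(ABC):
    """Hypothetical minimal base mirroring the abstract forward_native contract."""

    @abstractmethod
    def forward_native(self, *args: Any, **kwargs: Any) -> Any: ...


class MissingNative(AbstractOpSketch):
    pass  # forward_native is not implemented, so this class stays abstract


class Square(AbstractOpSketch):
    def forward_native(self, x: float) -> float:
        return x * x


try:
    MissingNative()
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)  # cannot instantiate: TypeError

print(Square().forward_native(4))  # 16
```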

forward_rocm(*args, **kwargs) → Any

ROCm (AMD GPU)-specific implementation of the operation.

Defaults to calling forward_gpu. Subclasses can override this for optimizations specific to the ROCm platform, if necessary.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for ROCm GPUs.

forward_tpu(*args, **kwargs) → Any

TPU-specific implementation of the operation.

Defaults to calling forward_native. Subclasses can override this for TPU-specific optimizations.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for TPU.

forward_tt(*args, **kwargs) → Any

TT-specific implementation of the operation.

Defaults to calling forward_native. Subclasses can override this for TT-specific optimizations.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for TT.

abstract get_impl_metadata() → OperationMetadata

Returns the OperationMetadata associated with this implementation instance.

Returns

The OperationMetadata instance passed during initialization.

abstract classmethod get_impl_name() → str | tuple[str, ...]

Returns the unique name(s) identifying this operation implementation.

Used by the OperationRegistry. Can return a single string or a tuple/list of strings if the implementation has multiple aliases.

Returns

A string or tuple/list of strings representing the implementation name(s).

metadata: easydel.layers.operations._operation_meta.OperationMetadata | None = None

class easydel.layers.operations._base_operation.OperationRegistry

Bases: object

Registry for discovering and managing different OperationImpl classes.

Allows registering implementations using a decorator and retrieving or instantiating them by name.
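
The registry behaviour described here, and detailed in the methods that follow, can be approximated by a class-level dictionary keyed by name. This is a minimal sketch under stated assumptions: RegistrySketch, VanillaImpl, the "vanilla"/"default" aliases and the metadata dict are hypothetical, and passing metadata as the first positional constructor argument is an assumption rather than EasyDeL's confirmed behaviour.

```python
from typing import Any, TypeVar

T = TypeVar("T", bound=type)


class RegistrySketch:
    """Hypothetical name-to-class registry mirroring OperationRegistry's documented API."""

    _registry: dict[str, type] = {}

    @classmethod
    def register(cls, impl_cls: T) -> T:
        names = impl_cls.get_impl_name()
        # get_impl_name() may return a single name or a tuple/list of aliases.
        if isinstance(names, str):
            names = (names,)
        for name in names:
            cls._registry[name] = impl_cls
        return impl_cls  # returning the class lets register work as a decorator

    @classmethod
    def get(cls, impl_name: str) -> type:
        try:
            return cls._registry[impl_name]
        except KeyError:
            raise ValueError(f"No implementation registered under {impl_name!r}") from None

    @classmethod
    def create(cls, impl_name: str, metadata: Any) -> Any:
        # Assumption: metadata is the constructor's first positional argument.
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> list[str]:
        return list(cls._registry)


@RegistrySketch.register
class VanillaImpl:
    def __init__(self, metadata: Any) -> None:
        self.metadata = metadata

    @classmethod
    def get_impl_name(cls) -> tuple[str, ...]:
        return ("vanilla", "default")  # hypothetical aliases


op = RegistrySketch.create("default", metadata={"dtype": "float32"})
print(type(op).__name__, op.metadata)         # VanillaImpl {'dtype': 'float32'}
print(RegistrySketch.list_implementations())  # ['vanilla', 'default']
```

Registering every alias against the same class means lookups by any alias resolve to one implementation, and an unknown name surfaces as a ValueError, matching the Raises entries documented for create and get.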

classmethod create(impl_name: str, metadata: OperationMetadata) → BaseOperation

Creates an instance of an operation implementation by name.

Retrieves the class associated with impl_name and initializes it with the provided metadata.

Parameters
  • impl_name – The name of the implementation to instantiate.

  • metadata – The OperationMetadata to pass to the implementation’s constructor.

Returns

An initialized instance of the requested OperationImpl subclass.

Raises

ValueError – If no implementation is registered with impl_name.

classmethod get(impl_name: str) → type[easydel.layers.operations._base_operation.BaseOperation]

Retrieves an operation implementation class by its registered name.

Parameters

impl_name – The name of the implementation to retrieve.

Returns

The OperationImpl subclass registered under the given name.

Raises

ValueError – If no implementation is registered with that name.

classmethod list_implementations() → list[str]

Returns a list of names of all registered operation implementations.

Returns

A list of strings, where each string is a registered implementation name.

classmethod register(impl_cls: type[ICa]) → type[ICa]

Class method decorator to register an OperationImpl subclass.

The implementation is registered under the name(s) returned by its get_impl_name() class method.

Example:

```python
@OperationRegistry.register
class FlashOperationImpl(OperationImpl):
    @classmethod
    def get_impl_name(cls) -> str:
        return "flash"

    # … implementation …
```

Parameters

impl_cls – The OperationImpl subclass to register.

Returns

The registered class itself.