easydel.layers.operations._base_operation
Defines the base class for operations within EasyDeL that may have backend-specific implementations (CPU, GPU, TPU).
- class easydel.layers.operations._base_operation.BaseOperation
Bases: ABC
Abstract base class for defining operations with potential backend-specific implementations.
This class provides a structure for defining a core operation (forward_native) and allowing optional, optimized implementations for different hardware backends supported by JAX (TPU, GPU - CUDA/ROCm, CPU).
The __call__ method acts as a dispatcher, detecting the current JAX default backend and executing the corresponding forward_… method. If a specific backend implementation (e.g., forward_tpu) is not overridden in a subclass, it defaults to calling forward_native.
Subclasses MUST implement the forward_native method. They CAN optionally override forward_tpu, forward_gpu, forward_cpu, forward_rocm, or forward_cuda to provide backend-specific optimizations.
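The dispatch behaviour described above can be illustrated with a small stand-in. This is a minimal sketch, not EasyDeL's code: `SketchOperation`, its `backend` constructor argument, and `Double` are hypothetical names, and the backend is passed in explicitly rather than detected from JAX so the example runs without an accelerator.

```python
from abc import ABC, abstractmethod
from typing import Any


class SketchOperation(ABC):
    """Hypothetical stand-in that mimics the dispatch pattern described above."""

    def __init__(self, backend: str = "cpu") -> None:
        # Assumption: the real class detects the backend from JAX instead.
        self.backend = backend

    @abstractmethod
    def forward_native(self, *args: Any, **kwargs: Any) -> Any:
        """Backend-agnostic implementation; subclasses must provide it."""

    # Each backend hook falls back to forward_native unless overridden.
    def forward_tpu(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward_native(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Pick the hook that matches the backend; unknown names use forward_native.
        handlers = {
            "tpu": self.forward_tpu,
            "gpu": self.forward_gpu,
            "cpu": self.forward_cpu,
        }
        handler = handlers.get(self.backend, self.forward_native)
        return handler(*args, **kwargs)


class Double(SketchOperation):
    def forward_native(self, x: int) -> tuple[str, int]:
        return ("native", 2 * x)

    def forward_tpu(self, x: int) -> tuple[str, int]:
        # Stands in for a TPU-tuned path; the tag shows which hook ran.
        return ("tpu", x + x)


cpu_result = Double("cpu")(3)  # no CPU override, so forward_native runs
tpu_result = Double("tpu")(3)  # the TPU override runs
```

Only `forward_tpu` is overridden here, so every other backend ends up in `forward_native`, which is the fallback the class description promises.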
- property EasyDeLBackends
- current_backend() -> Literal['tpu', 'gpu', 'cpu']
Returns the current JAX default backend as a lowercase string literal.
- Returns
"tpu", "gpu", or "cpu".
- forward_cpu(*args, **kwargs) -> Any
CPU-specific implementation of the operation.
Defaults to calling forward_native. Subclasses can override this for CPU-specific optimizations (though often forward_native is sufficient).
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation, potentially optimized for CPU.
- forward_cuda(*args, **kwargs) -> Any
CUDA (NVIDIA GPU)-specific implementation of the operation.
Defaults to calling forward_gpu. Subclasses can override this for optimizations specific to the CUDA platform, if necessary.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation, potentially optimized for CUDA GPUs.
- forward_gpu(*args, **kwargs) -> Any
Generic GPU-specific implementation of the operation.
Defaults to calling forward_native. This method serves as the base for CUDA and ROCm backends unless they are specifically overridden. Subclasses can override this for general GPU optimizations.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation, potentially optimized for GPUs.
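The GPU fallback chain (CUDA and ROCm defer to the generic GPU method, which in turn defers to the native one) can be sketched as follows. The classes `ChainSketch`, `GpuTuned`, and `RocmTuned` are hypothetical illustrations of the documented defaults, not part of EasyDeL.

```python
class ChainSketch:
    """Hypothetical stand-in reproducing the documented default chain."""

    def forward_native(self, x):
        return ("native", x)

    def forward_gpu(self, x):
        # Generic GPU path: defaults to the native implementation.
        return self.forward_native(x)

    def forward_cuda(self, x):
        # CUDA path: defaults to the generic GPU implementation.
        return self.forward_gpu(x)

    def forward_rocm(self, x):
        # ROCm path: defaults to the generic GPU implementation.
        return self.forward_gpu(x)


class GpuTuned(ChainSketch):
    def forward_gpu(self, x):
        # One override serves both CUDA and ROCm callers.
        return ("gpu", x)


class RocmTuned(GpuTuned):
    def forward_rocm(self, x):
        # A platform-specific override only affects ROCm callers.
        return ("rocm", x)


base_cuda = ChainSketch().forward_cuda(1)   # falls through to native
tuned_cuda = GpuTuned().forward_cuda(1)     # picks up the GPU override
tuned_rocm = RocmTuned().forward_rocm(1)    # picks up the ROCm override
```

Overriding `forward_gpu` alone is therefore enough when one kernel works on both vendors; `forward_cuda` or `forward_rocm` only need overriding when the platforms diverge.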
- abstract forward_native(*args, **kwargs) -> Any
The core, backend-agnostic implementation of the operation.
This method MUST be implemented by any concrete subclass of BaseOperation. It serves as the default implementation if no backend-specific override is available or applicable.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation. Type depends on the specific operation.
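Because the base class derives from `ABC`, Python itself enforces the "MUST implement" rule at instantiation time. The sketch below shows that standard `abc` behaviour with hypothetical stand-in classes; the real `BaseOperation` also declares `get_impl_metadata` and `get_impl_name` abstract, which this sketch leaves out for brevity.

```python
from abc import ABC, abstractmethod


class AbstractSketch(ABC):
    """Hypothetical stand-in with a single abstract method."""

    @abstractmethod
    def forward_native(self, *args, **kwargs):
        """Subclasses must provide the backend-agnostic implementation."""


class Incomplete(AbstractSketch):
    # forward_native is not implemented, so this class stays abstract.
    pass


class Complete(AbstractSketch):
    def forward_native(self, x):
        return x + 1


try:
    Incomplete()
    instantiated = True
except TypeError:
    # abc refuses to instantiate a class with unimplemented abstract methods.
    instantiated = False

result = Complete().forward_native(1)
```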
- forward_rocm(*args, **kwargs) -> Any
ROCm (AMD GPU)-specific implementation of the operation.
Defaults to calling forward_gpu. Subclasses can override this for optimizations specific to the ROCm platform, if necessary.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation, potentially optimized for ROCm GPUs.
- forward_tpu(*args, **kwargs) -> Any
TPU-specific implementation of the operation.
Defaults to calling forward_native. Subclasses can override this for TPU-specific optimizations.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation, potentially optimized for TPU.
- forward_tt(*args, **kwargs) -> Any
TT-specific implementation of the operation.
Defaults to calling forward_native. Subclasses can override this for TT-specific optimizations.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation, potentially optimized for TT.
- abstract get_impl_metadata() -> OperationMetadata
Returns the OperationMetadata associated with this implementation instance.
- Returns
The OperationMetadata instance passed during initialization.
- abstract classmethod get_impl_name() -> str | tuple[str, ...]
Returns the unique name(s) identifying this operation implementation.
Used by the OperationRegistry. Can return a single string or a tuple/list of strings if the implementation has multiple aliases.
- Returns
A string or tuple/list of strings representing the implementation name(s).
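Since the return value may be a single string or a sequence of aliases, code that consumes it usually normalizes it first. The helper below is a hypothetical sketch of that normalization (`normalize_impl_names` is not an EasyDeL function, and `"flash_alias"` is a made-up alias).

```python
from __future__ import annotations


def normalize_impl_names(names: str | tuple[str, ...] | list[str]) -> tuple[str, ...]:
    """Return the implementation name(s) as a tuple, whatever form was given."""
    if isinstance(names, str):
        # A bare string is a single name, not a sequence of characters.
        return (names,)
    return tuple(names)


single = normalize_impl_names("flash")
aliases = normalize_impl_names(["flash", "flash_alias"])
```

The explicit `isinstance(names, str)` check matters because iterating over a string would otherwise register one name per character.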
- metadata: easydel.layers.operations._operation_meta.OperationMetadata | None = None
- class easydel.layers.operations._base_operation.OperationRegistry
Bases: object
Registry for discovering and managing different OperationImpl classes.
Allows registering implementations using a decorator and retrieving or instantiating them by name.
- classmethod create(impl_name: str, metadata: OperationMetadata) -> BaseOperation
Creates an instance of an operation implementation by name.
Retrieves the class associated with impl_name and initializes it with the provided metadata.
- Parameters
impl_name – The name of the implementation to instantiate.
metadata – The OperationMetadata to pass to the implementation’s constructor.
- Returns
An initialized instance of the requested OperationImpl subclass.
- Raises
ValueError – If no implementation is registered with impl_name.
- classmethod get(impl_name: str) -> type[easydel.layers.operations._base_operation.BaseOperation]
Retrieves an operation implementation class by its registered name.
- Parameters
impl_name – The name of the implementation to retrieve.
- Returns
The OperationImpl subclass registered under the given name.
- Raises
ValueError – If no implementation is registered with that name.
- classmethod list_implementations() -> list[str]
Returns a list of names of all registered operation implementations.
- Returns
A list of strings, where each string is a registered implementation name.
- classmethod register(impl_cls: type[ICa]) -> type[ICa]
Class method decorator to register an OperationImpl subclass.
The implementation is registered under the name(s) returned by its get_impl_name() class method.
Example:
```python
@OperationRegistry.register
class FlashOperationImpl(OperationImpl):
    @classmethod
    def get_impl_name(cls) -> str:
        return "flash"

    # … implementation …
```
- Parameters
impl_cls – The OperationImpl subclass to register.
- Returns
The registered class itself.
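To show how the four registry methods fit together, here is a self-contained sketch of a name-to-class registry with the same method names. It is an assumption-laden illustration, not EasyDeL's implementation: `SketchRegistry`, `FlashSketch`, the `"flash_alias"` alias, and the dictionary used as metadata are all hypothetical.

```python
from __future__ import annotations


class SketchRegistry:
    """Hypothetical stand-in for a name-to-class registry."""

    _registry: dict[str, type] = {}

    @classmethod
    def register(cls, impl_cls: type) -> type:
        names = impl_cls.get_impl_name()
        if isinstance(names, str):
            names = (names,)
        # Every alias points at the same class.
        for name in names:
            cls._registry[name] = impl_cls
        # Returning the class lets register be used as a decorator.
        return impl_cls

    @classmethod
    def get(cls, impl_name: str) -> type:
        if impl_name not in cls._registry:
            raise ValueError(f"No implementation registered as {impl_name!r}")
        return cls._registry[impl_name]

    @classmethod
    def create(cls, impl_name: str, metadata: object) -> object:
        # Look the class up, then construct it with the given metadata.
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> list[str]:
        return list(cls._registry)


@SketchRegistry.register
class FlashSketch:
    def __init__(self, metadata: object) -> None:
        self.metadata = metadata

    @classmethod
    def get_impl_name(cls) -> tuple[str, ...]:
        return ("flash", "flash_alias")


# Assumption: a plain dict stands in for an OperationMetadata instance.
op = SketchRegistry.create("flash", metadata={"dtype": "bfloat16"})
```

Registering at class-definition time means an implementation becomes available as soon as its module is imported, and callers can select it by a plain string.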