easydel.layers.ops._base_operation
Defines the base class for operations within EasyDeL that may have backend-specific implementations (CPU, GPU, TPU).
- class easydel.layers.ops._base_operation.BaseOperation
Bases: ABC
Abstract base class for defining operations with optional backend-specific implementations.
This class provides a structure for defining a core operation (forward_native) while allowing optional, optimized implementations for the hardware backends supported by JAX: TPU, GPU (CUDA or ROCm), and CPU.
The __call__ method acts as a dispatcher: it detects the current JAX default backend and executes the corresponding forward_… method. If a subclass does not override a backend-specific method, that method falls back to forward_native. The exceptions are forward_cuda and forward_rocm, which first fall back to forward_gpu; forward_gpu in turn defaults to forward_native.
Subclasses MUST implement forward_native. They MAY override forward_tpu, forward_gpu, forward_cpu, forward_rocm, or forward_cuda to provide backend-specific optimizations.
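The dispatch rules above can be sketched as follows. This is a minimal, hypothetical re-creation of the documented behavior, not EasyDeL's actual source: `BaseOperationSketch` and `Scale` are illustrative names, and mapping `"cuda"`/`"rocm"` platform strings is an assumption, since this page does not say how the real class tells the two GPU vendors apart.

```python
from abc import ABC, abstractmethod
from typing import Any

import jax
import jax.numpy as jnp


class BaseOperationSketch(ABC):
    """Hypothetical re-creation of the documented dispatch pattern (not EasyDeL's source)."""

    @abstractmethod
    def forward_native(self, *args, **kwargs) -> Any:
        """Backend-agnostic implementation; every concrete subclass must provide it."""

    # Backend hooks: each falls back as documented unless a subclass overrides it.
    def forward_tpu(self, *args, **kwargs) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs) -> Any:
        return self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs) -> Any:
        return self.forward_gpu(*args, **kwargs)

    def __call__(self, *args, **kwargs) -> Any:
        # Platform name of JAX's default backend, e.g. "cpu", "gpu" or "tpu".
        backend = jax.default_backend()
        # Assumption: vendor-specific names are mapped too, in case a backend
        # reports them; any unrecognized name falls back to forward_native.
        dispatch = {
            "tpu": self.forward_tpu,
            "gpu": self.forward_gpu,
            "cuda": self.forward_cuda,
            "rocm": self.forward_rocm,
            "cpu": self.forward_cpu,
        }
        return dispatch.get(backend, self.forward_native)(*args, **kwargs)


class Scale(BaseOperationSketch):
    """Toy operation: multiply an array by a constant factor."""

    def __init__(self, factor: float):
        self.factor = factor

    def forward_native(self, x):
        return x * self.factor


op = Scale(2.0)
y = op(jnp.arange(3.0))  # routed according to jax.default_backend()
```

In this sketch, a backend that reports only `"gpu"` is routed to forward_gpu, so forward_cuda or forward_rocm overrides would not be reached; the real class may identify the GPU vendor by other means.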
- forward_cpu(*args, **kwargs) → Any
CPU-specific implementation of the operation.
Defaults to calling forward_native. Subclasses can override this for CPU-specific optimizations, although forward_native is often sufficient.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation, potentially optimized for CPU.
- forward_cuda(*args, **kwargs) → Any
CUDA (NVIDIA GPU)-specific implementation of the operation.
Defaults to calling forward_gpu. Subclasses can override this for optimizations specific to the CUDA platform, if necessary.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation, potentially optimized for CUDA GPUs.
- forward_gpu(*args, **kwargs) → Any
Generic GPU-specific implementation of the operation.
Defaults to calling forward_native. It is also the fallback for forward_cuda and forward_rocm unless those methods are overridden, so an override here applies to both NVIDIA and AMD GPUs. Subclasses can override this for general GPU optimizations.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation, potentially optimized for GPUs.
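To illustrate the GPU fallback chain (forward_cuda and forward_rocm → forward_gpu → forward_native), the hypothetical stand-in below reproduces only those documented defaults. The class names `GPUChainBase`, `NativeOnly`, and `Tagged` are illustrative, and the methods are called directly rather than through a dispatcher.

```python
from abc import ABC, abstractmethod


class GPUChainBase(ABC):
    """Hypothetical stand-in reproducing only the documented GPU fallbacks."""

    @abstractmethod
    def forward_native(self, *args, **kwargs): ...

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)


class NativeOnly(GPUChainBase):
    """Overrides nothing but forward_native."""

    def forward_native(self):
        return "native"


class Tagged(GPUChainBase):
    """Returns the name of the implementation that actually ran."""

    def forward_native(self):
        return "native"

    def forward_gpu(self):
        # Shared GPU path: used by CUDA and ROCm unless they override it.
        return "gpu"

    def forward_cuda(self):
        # NVIDIA-only override; ROCm still goes through forward_gpu.
        return "cuda"


print(NativeOnly().forward_cuda(), NativeOnly().forward_rocm())  # native native
print(Tagged().forward_cuda(), Tagged().forward_rocm())  # cuda gpu
```

Overriding forward_gpu is therefore the single place to add an optimization shared by both GPU vendors, while forward_cuda or forward_rocm only needs overriding for truly vendor-specific code.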
- abstract forward_native(*args, **kwargs) → Any
The core, backend-agnostic implementation of the operation.
This method MUST be implemented by any concrete subclass of BaseOperation. It serves as the default implementation if no backend-specific override is available or applicable.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation. Type depends on the specific operation.
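A concrete subclass only has to supply forward_native. The sketch below uses a hypothetical stand-in base class (`OperationBase`, not EasyDeL's BaseOperation) and an illustrative RMS-normalization operation; it also shows that a subclass missing forward_native cannot be instantiated.

```python
from abc import ABC, abstractmethod
from typing import Any

import jax.numpy as jnp


class OperationBase(ABC):
    """Hypothetical stand-in for BaseOperation; only the abstract contract is reproduced."""

    @abstractmethod
    def forward_native(self, *args, **kwargs) -> Any: ...


class RMSNorm(OperationBase):
    """Illustrative operation: RMS normalization over the last axis."""

    def __init__(self, eps: float = 1e-6):
        self.eps = eps

    def forward_native(self, x: jnp.ndarray) -> jnp.ndarray:
        # Plain jax.numpy code runs unchanged on CPU, GPU and TPU,
        # which is what makes it usable as the universal fallback.
        rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + self.eps)
        return x / rms


class Incomplete(OperationBase):
    """Forgets to implement forward_native."""


out = RMSNorm(eps=0.0).forward_native(jnp.array([[3.0, 4.0]]))

# ABC refuses to instantiate a class with unimplemented abstract methods.
try:
    Incomplete()
    instantiable = True
except TypeError:
    instantiable = False
```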
- forward_rocm(*args, **kwargs) → Any
ROCm (AMD GPU)-specific implementation of the operation.
Defaults to calling forward_gpu. Subclasses can override this for optimizations specific to the ROCm platform, if necessary.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation, potentially optimized for ROCm GPUs.
- forward_tpu(*args, **kwargs) → Any
TPU-specific implementation of the operation.
Defaults to calling forward_native. Subclasses can override this for TPU-specific optimizations.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation, potentially optimized for TPU.
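As an illustration of a TPU-specific override, the hypothetical `Matmul` operation below requests `jax.lax.Precision.HIGHEST` on TPU, where float32 matrix multiplies may otherwise run with reduced-precision (bfloat16) passes. `OperationBase` is a stand-in that reproduces only the documented forward_tpu fallback; whether a given EasyDeL operation makes this choice is not stated here.

```python
from abc import ABC, abstractmethod
from typing import Any

import jax
import jax.numpy as jnp


class OperationBase(ABC):
    """Hypothetical stand-in for BaseOperation; only the TPU fallback is reproduced."""

    @abstractmethod
    def forward_native(self, *args, **kwargs) -> Any: ...

    def forward_tpu(self, *args, **kwargs) -> Any:
        return self.forward_native(*args, **kwargs)


class Matmul(OperationBase):
    """Illustrative operation with a TPU-specific override."""

    def forward_native(self, a, b):
        # Uses the backend's default matmul precision.
        return jnp.matmul(a, b)

    def forward_tpu(self, a, b):
        # Illustrative override: ask for full float32 accuracy on TPU.
        return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)


a = jnp.eye(2)
b = jnp.array([[1.0, 2.0], [3.0, 4.0]])
out_native = Matmul().forward_native(a, b)
out_tpu = Matmul().forward_tpu(a, b)  # callable directly on any backend
```

The override trades speed for accuracy; an operation that tolerates bfloat16 matmuls would simply keep the default fallback to forward_native.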