easydel.layers.ops._base_operation

Defines the base class for operations within EasyDeL that may have backend-specific implementations (CPU, GPU, TPU).

class easydel.layers.ops._base_operation.BaseOperation

Bases: ABC

Abstract base class for operations that may have backend-specific implementations.

This class separates a core, backend-agnostic implementation (forward_native) from optional, optimized implementations for the hardware backends supported by JAX: TPU, GPU (CUDA or ROCm), and CPU.

The __call__ method acts as a dispatcher: it detects the current JAX default backend and calls the matching forward_… method. A backend method that a subclass does not override falls back to a default: forward_tpu, forward_cpu, and forward_gpu call forward_native, while forward_cuda and forward_rocm call forward_gpu (and therefore reach forward_native unless forward_gpu is overridden).

Subclasses MUST implement forward_native. They MAY override forward_tpu, forward_gpu, forward_cpu, forward_rocm, or forward_cuda to provide backend-specific optimizations.
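The dispatch-and-fallback behavior described above can be sketched in plain Python. This is a hypothetical stand-in, not EasyDeL's source: the class name OperationSketch, the helpers detect_backend and run_on, and the toy Scale operation are illustrative assumptions. detect_backend uses jax.default_backend() when JAX is installed (it returns names such as "cpu", "gpu", or "tpu") and otherwise assumes "cpu". How the real class chooses between forward_cuda and forward_rocm is not documented here, so the sketch simply routes a plain "gpu" result to forward_gpu.

```python
from abc import ABC, abstractmethod
from typing import Any


def detect_backend() -> str:
    """Hypothetical helper: JAX's default backend name, or "cpu" without JAX."""
    try:
        import jax
    except ImportError:
        return "cpu"
    return jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"


class OperationSketch(ABC):
    """Hypothetical stand-in for BaseOperation's dispatch and fallback chain."""

    @abstractmethod
    def forward_native(self, *args: Any, **kwargs: Any) -> Any:
        """Backend-agnostic implementation; every concrete subclass provides it."""

    def forward_tpu(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward_gpu(*args, **kwargs)  # CUDA -> generic GPU path

    def forward_rocm(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward_gpu(*args, **kwargs)  # ROCm -> generic GPU path

    def run_on(self, backend: str, *args: Any, **kwargs: Any) -> Any:
        # Call forward_<backend>; unknown backend names fall back to forward_native.
        hook = getattr(self, f"forward_{backend}", self.forward_native)
        return hook(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Dispatcher: detect the current backend, then run the matching hook.
        return self.run_on(detect_backend(), *args, **kwargs)


class Scale(OperationSketch):
    """Toy operation that only implements forward_native."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def forward_native(self, x: float) -> float:
        return x * self.factor


op = Scale(2.0)
print(op(3.0))  # 6.0 on any backend, since no backend hook is overridden
```

Because Scale overrides nothing but forward_native, every backend name, including an unrecognized one, produces the same result.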

forward_cpu(*args, **kwargs) → Any

CPU-specific implementation of the operation.

Defaults to calling forward_native. Subclasses can override this for CPU-specific optimizations, although forward_native is often sufficient on CPU.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for CPU.
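A CPU override can be illustrated in isolation with a reduced version of a hypothetical OperationSketch stand-in (not EasyDeL's source). Here run_on is an assumed helper that calls forward_&lt;backend&gt; directly instead of detecting the backend, and the toy Square operation tags its result with the hook that produced it; its forward_cpu body is only a placeholder for a real CPU-specific path.

```python
from abc import ABC, abstractmethod


class OperationSketch(ABC):
    """Reduced, hypothetical stand-in for BaseOperation's fallback chain."""

    @abstractmethod
    def forward_native(self, *args, **kwargs): ...

    def forward_tpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def run_on(self, backend, *args, **kwargs):
        # Hypothetical helper: call forward_<backend> without detecting the backend.
        return getattr(self, f"forward_{backend}")(*args, **kwargs)


class Square(OperationSketch):
    def forward_native(self, x):
        return ("native", x * x)

    def forward_cpu(self, x):
        # Placeholder for a CPU-specific path; only the tag differs here.
        return ("cpu", x * x)


op = Square()
print(op.run_on("cpu", 3))  # ('cpu', 9)
print(op.run_on("tpu", 3))  # ('native', 9)
```

Only the CPU hook changes; every other backend still reaches forward_native.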

forward_cuda(*args, **kwargs) → Any

CUDA (NVIDIA GPU)-specific implementation of the operation.

Defaults to calling forward_gpu. Subclasses can override this for optimizations specific to the CUDA platform, if necessary.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for CUDA GPUs.
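Because forward_cuda and forward_rocm are separate hooks, overriding only the CUDA one leaves ROCm on the default path. The sketch below uses the same reduced, hypothetical OperationSketch stand-in and assumed run_on helper; the CudaOnly operation and its forward_cuda body are illustrative placeholders, not a real kernel.

```python
from abc import ABC, abstractmethod


class OperationSketch(ABC):
    """Reduced, hypothetical stand-in for BaseOperation's fallback chain."""

    @abstractmethod
    def forward_native(self, *args, **kwargs): ...

    def forward_tpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def run_on(self, backend, *args, **kwargs):
        # Hypothetical helper: call forward_<backend> without detecting the backend.
        return getattr(self, f"forward_{backend}")(*args, **kwargs)


class CudaOnly(OperationSketch):
    def forward_native(self, x):
        return ("native", x + 1)

    def forward_cuda(self, x):
        # Placeholder for a CUDA-specific kernel; only the tag differs here.
        return ("cuda", x + 1)


op = CudaOnly()
print(op.run_on("cuda", 1))  # ('cuda', 2)
print(op.run_on("rocm", 1))  # ('native', 2): rocm -> gpu -> native
```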

forward_gpu(*args, **kwargs) → Any

Generic GPU-specific implementation of the operation.

Defaults to calling forward_native. Unless a subclass overrides them, forward_cuda and forward_rocm both delegate to this method, so overriding forward_gpu applies to both GPU platforms. Subclasses can override this for general GPU optimizations.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for GPUs.
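The shared-GPU behavior can be checked with the same reduced, hypothetical OperationSketch stand-in and assumed run_on helper: a single forward_gpu override is picked up by both the CUDA and ROCm hooks, while TPU and CPU still use forward_native. GpuShared is an illustrative toy, not an EasyDeL operation.

```python
from abc import ABC, abstractmethod


class OperationSketch(ABC):
    """Reduced, hypothetical stand-in for BaseOperation's fallback chain."""

    @abstractmethod
    def forward_native(self, *args, **kwargs): ...

    def forward_tpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def run_on(self, backend, *args, **kwargs):
        # Hypothetical helper: call forward_<backend> without detecting the backend.
        return getattr(self, f"forward_{backend}")(*args, **kwargs)


class GpuShared(OperationSketch):
    def forward_native(self, x):
        return ("native", 2 * x)

    def forward_gpu(self, x):
        # Placeholder for a GPU path shared by CUDA and ROCm.
        return ("gpu", 2 * x)


op = GpuShared()
print(op.run_on("cuda", 5))  # ('gpu', 10)
print(op.run_on("rocm", 5))  # ('gpu', 10)
print(op.run_on("cpu", 5))   # ('native', 10)
```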

abstract forward_native(*args, **kwargs) → Any

The core, backend-agnostic implementation of the operation.

This method MUST be implemented by any concrete subclass of BaseOperation. It serves as the default implementation if no backend-specific override is available or applicable.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation. Type depends on the specific operation.
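Since BaseOperation derives from ABC, the MUST in this requirement is enforced by Python itself: a subclass that leaves an abstract method unimplemented cannot be instantiated. The sketch below demonstrates this with the same reduced, hypothetical OperationSketch stand-in and assumed run_on helper; MissingNative and Identity are illustrative toys.

```python
from abc import ABC, abstractmethod


class OperationSketch(ABC):
    """Reduced, hypothetical stand-in for BaseOperation's fallback chain."""

    @abstractmethod
    def forward_native(self, *args, **kwargs): ...

    def forward_tpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def run_on(self, backend, *args, **kwargs):
        # Hypothetical helper: call forward_<backend> without detecting the backend.
        return getattr(self, f"forward_{backend}")(*args, **kwargs)


class MissingNative(OperationSketch):
    # Overrides a backend hook but never implements forward_native.
    def forward_tpu(self, x):
        return x


class Identity(OperationSketch):
    def forward_native(self, x):
        return x


try:
    MissingNative()
except TypeError as err:
    # ABC refuses to instantiate a class with an unimplemented abstract method.
    print("TypeError:", err)

print(Identity().run_on("cpu", 5))  # 5
```

Overriding a backend hook such as forward_tpu does not satisfy the requirement; only forward_native is abstract.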

forward_rocm(*args, **kwargs) → Any

ROCm (AMD GPU)-specific implementation of the operation.

Defaults to calling forward_gpu. Subclasses can override this for optimizations specific to the ROCm platform, if necessary.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for ROCm GPUs.
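An override does not have to handle every input. The sketch below, again built on the reduced, hypothetical OperationSketch stand-in and assumed run_on helper, supposes a ROCm path that only supports integer inputs and defers everything else to the inherited forward_rocm through super(), which continues down the default chain (forward_gpu, then forward_native). Whether EasyDeL operations use this pattern is not stated here; it is ordinary Python method resolution.

```python
from abc import ABC, abstractmethod


class OperationSketch(ABC):
    """Reduced, hypothetical stand-in for BaseOperation's fallback chain."""

    @abstractmethod
    def forward_native(self, *args, **kwargs): ...

    def forward_tpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def run_on(self, backend, *args, **kwargs):
        # Hypothetical helper: call forward_<backend> without detecting the backend.
        return getattr(self, f"forward_{backend}")(*args, **kwargs)


class IntFastPath(OperationSketch):
    def forward_native(self, x):
        return ("native", x * x)

    def forward_rocm(self, x):
        if isinstance(x, int):
            # Placeholder for a ROCm-specific path that only handles ints.
            return ("rocm", x * x)
        # Anything else: defer to the default chain (forward_gpu -> forward_native).
        return super().forward_rocm(x)


op = IntFastPath()
print(op.run_on("rocm", 3))    # ('rocm', 9)
print(op.run_on("rocm", 1.5))  # ('native', 2.25)
```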

forward_tpu(*args, **kwargs) → Any

TPU-specific implementation of the operation.

Defaults to calling forward_native. Subclasses can override this for TPU-specific optimizations.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for TPU.
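To summarize the fallback rules, the sketch below overrides both forward_tpu and forward_gpu on the same reduced, hypothetical OperationSketch stand-in and prints which hook each backend name reaches through the assumed run_on helper. Mixed and its method bodies are illustrative placeholders.

```python
from abc import ABC, abstractmethod


class OperationSketch(ABC):
    """Reduced, hypothetical stand-in for BaseOperation's fallback chain."""

    @abstractmethod
    def forward_native(self, *args, **kwargs): ...

    def forward_tpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def run_on(self, backend, *args, **kwargs):
        # Hypothetical helper: call forward_<backend> without detecting the backend.
        return getattr(self, f"forward_{backend}")(*args, **kwargs)


class Mixed(OperationSketch):
    def forward_native(self, x):
        return ("native", -x)

    def forward_tpu(self, x):
        return ("tpu", -x)  # placeholder for a TPU-specific path

    def forward_gpu(self, x):
        return ("gpu", -x)  # placeholder for a GPU path shared by CUDA and ROCm


op = Mixed()
for backend in ("tpu", "cpu", "gpu", "cuda", "rocm"):
    print(backend, "->", op.run_on(backend, 4))
```

The loop prints ('tpu', -4) for tpu, ('native', -4) for cpu, and ('gpu', -4) for gpu, cuda, and rocm.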