easydel.layers.ops.__init__

class easydel.layers.ops.__init__.BaseOperation

Bases: ABC

Abstract Base Class for defining operations with potential backend-specific implementations.

This class provides a structure for defining a core operation (forward_native) while allowing optional, optimized implementations for the hardware backends supported by JAX: TPU, GPU (CUDA or ROCm), and CPU.

The __call__ method acts as a dispatcher: it detects the current JAX default backend and executes the corresponding forward_… method. If a backend-specific implementation (e.g., forward_tpu) is not overridden in a subclass, the call falls back to forward_native. forward_cuda and forward_rocm first fall back to forward_gpu, which in turn defaults to forward_native.

Subclasses MUST implement the forward_native method. They MAY override forward_tpu, forward_gpu, forward_cpu, forward_rocm, or forward_cuda to provide backend-specific optimizations.
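The dispatch-and-fallback behaviour described above can be sketched in plain Python. MiniOperation and Square below are hypothetical names written for illustration; this is not EasyDeL's BaseOperation source. In particular, the dispatch helper and its mapping from backend names to methods are assumptions, and how EasyDeL tells CUDA apart from ROCm is not documented here, so the sketch simply accepts either name.

```python
from abc import ABC, abstractmethod
from typing import Any

import jax
import jax.numpy as jnp


class MiniOperation(ABC):
    """Hypothetical, simplified illustration of the BaseOperation dispatch pattern."""

    @abstractmethod
    def forward_native(self, *args, **kwargs) -> Any:
        """Backend-agnostic implementation; every concrete subclass must provide it."""

    # Optional backend hooks, each falling back as described above.
    def forward_tpu(self, *args, **kwargs) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs) -> Any:
        return self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs) -> Any:
        return self.forward_gpu(*args, **kwargs)

    def dispatch(self, backend: str, *args, **kwargs) -> Any:
        # Assumed mapping from platform name to hook; unknown names use forward_native.
        hooks = {
            "tpu": self.forward_tpu,
            "gpu": self.forward_gpu,
            "cuda": self.forward_cuda,
            "rocm": self.forward_rocm,
            "cpu": self.forward_cpu,
        }
        return hooks.get(backend, self.forward_native)(*args, **kwargs)

    def __call__(self, *args, **kwargs) -> Any:
        # jax.default_backend() returns the current platform name, e.g. "cpu" or "tpu".
        return self.dispatch(jax.default_backend(), *args, **kwargs)


class Square(MiniOperation):
    """Example op: only forward_native is implemented, so every backend uses it."""

    def forward_native(self, x):
        return x * x


op = Square()
result = op(jnp.arange(3.0))  # squares elementwise on whichever backend JAX is using
```

Keeping the backend lookup in a separate method makes the routing easy to exercise on a single machine, since any backend name can be passed to dispatch directly.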

forward_cpu(*args, **kwargs) → Any

CPU-specific implementation of the operation.

Defaults to calling forward_native. Subclasses can override this for CPU-specific optimizations (though often forward_native is sufficient).

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for CPU.

forward_cuda(*args, **kwargs) → Any

CUDA (NVIDIA GPU)-specific implementation of the operation.

Defaults to calling forward_gpu. Subclasses can override this for optimizations specific to the CUDA platform, if necessary.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for CUDA GPUs.

forward_gpu(*args, **kwargs) → Any

Generic GPU-specific implementation of the operation.

Defaults to calling forward_native. Unless they are themselves overridden, forward_cuda and forward_rocm delegate to this method, so it serves as the shared base for both GPU platforms. Subclasses can override it for general GPU optimizations.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for GPUs.
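The GPU fallback chain (forward_cuda and forward_rocm delegate to forward_gpu, which delegates to forward_native) can be illustrated with a stripped-down class that keeps only those four methods. GpuChainOp and FusedScale are hypothetical names for this sketch, and the "GPU-tuned" body is only a placeholder that computes the same result as the native path.

```python
from abc import ABC, abstractmethod
from typing import Any

import jax.numpy as jnp


class GpuChainOp(ABC):
    """Hypothetical mirror of BaseOperation's GPU-related fallback methods."""

    @abstractmethod
    def forward_native(self, *args, **kwargs) -> Any: ...

    def forward_gpu(self, *args, **kwargs) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs) -> Any:
        return self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs) -> Any:
        return self.forward_gpu(*args, **kwargs)


class FusedScale(GpuChainOp):
    """Overrides forward_gpu once; the CUDA and ROCm hooks both inherit it."""

    def __init__(self):
        self.calls = []  # records which implementation actually ran

    def forward_native(self, x, factor):
        self.calls.append("native")
        return x * factor

    def forward_gpu(self, x, factor):
        # Placeholder for a GPU-tuned kernel; here it computes the same result.
        self.calls.append("gpu")
        return jnp.multiply(x, factor)


op = FusedScale()
op.forward_cuda(jnp.ones(2), 3.0)
op.forward_rocm(jnp.ones(2), 3.0)
print(op.calls)  # ['gpu', 'gpu']
```

Overriding forward_gpu alone is therefore enough to cover both GPU vendors; forward_cuda or forward_rocm only need their own override when one platform requires different code.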

abstract forward_native(*args, **kwargs) → Any

The core, backend-agnostic implementation of the operation.

This method MUST be implemented by any concrete subclass of BaseOperation. It serves as the default implementation if no backend-specific override is available or applicable.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation. Type depends on the specific operation.
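Because forward_native is declared abstract, Python's abc machinery refuses to instantiate a subclass that does not define it. The toy classes below (ToyBase, Incomplete, Relu) are hypothetical and reproduce only that one rule; they are not part of EasyDeL.

```python
from abc import ABC, abstractmethod
from typing import Any

import jax
import jax.numpy as jnp


class ToyBase(ABC):
    @abstractmethod
    def forward_native(self, *args, **kwargs) -> Any:
        """Must be implemented by concrete subclasses."""


class Incomplete(ToyBase):
    pass  # forward_native is missing, so this class stays abstract


class Relu(ToyBase):
    def forward_native(self, x: jax.Array) -> jax.Array:
        # Plain jax.numpy code runs unchanged on CPU, GPU, or TPU.
        return jnp.maximum(x, 0.0)


try:
    Incomplete()
except TypeError as err:
    print("cannot instantiate:", err)

print(Relu().forward_native(jnp.array([-1.0, 2.0])))
```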

forward_rocm(*args, **kwargs) → Any

ROCm (AMD GPU)-specific implementation of the operation.

Defaults to calling forward_gpu. Subclasses can override this for optimizations specific to the ROCm platform, if necessary.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for ROCm GPUs.

forward_tpu(*args, **kwargs) → Any

TPU-specific implementation of the operation.

Defaults to calling forward_native. Subclasses can override this for TPU-specific optimizations.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for TPU.

class easydel.layers.ops.__init__.RecurrentGLA

Bases: BaseOperation

forward_native(query: jax.Array, key: jax.Array, value: jax.Array, gk: jax.Array, scale: float = -1.0, initial_state: Optional[jax.Array] = None, chunk_size: int = 0, dtype: numpy.dtype = jax.numpy.float32, output_final_state: bool = False) → Tuple[Array, Optional[Array]]

The core, backend-agnostic implementation of the operation.

This method MUST be implemented by any concrete subclass of BaseOperation. It serves as the default implementation if no backend-specific override is available or applicable.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation. Type depends on the specific operation.
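For intuition about what RecurrentGLA computes, the following is a minimal, unoptimized sketch of the recurrent form of Gated Linear Attention: a per-head state S_t is decayed row-wise by exp(gk_t) along the key dimension, updated with the outer product k_t^T v_t, and read out as o_t = (scale * q_t) S_t. It reuses the parameter names from the signature above, but recurrent_gla_reference is a hypothetical function, not EasyDeL's implementation, and several details are assumptions: the [batch, heads, seq_len, dim] layout, gk being a log-space gate with the same shape as key, a non-positive scale meaning 1/sqrt(key_dim), and a [batch, heads, key_dim, value_dim] state. chunk_size and dtype are omitted, and the state is accumulated in float32.

```python
from typing import Optional, Tuple

import jax
import jax.numpy as jnp


def recurrent_gla_reference(
    query: jax.Array,  # [batch, heads, seq_len, key_dim] (assumed layout)
    key: jax.Array,    # [batch, heads, seq_len, key_dim]
    value: jax.Array,  # [batch, heads, seq_len, value_dim]
    gk: jax.Array,     # [batch, heads, seq_len, key_dim], log-space gate (assumed)
    scale: float = -1.0,
    initial_state: Optional[jax.Array] = None,  # [batch, heads, key_dim, value_dim]
    output_final_state: bool = False,
) -> Tuple[jax.Array, Optional[jax.Array]]:
    b, h, _, dk = query.shape
    dv = value.shape[-1]
    if scale <= 0:  # assumption: a non-positive scale means "use 1/sqrt(key_dim)"
        scale = dk ** -0.5
    if initial_state is None:
        state = jnp.zeros((b, h, dk, dv), jnp.float32)
    else:
        state = initial_state.astype(jnp.float32)

    def step(state, inputs):
        q_t, k_t, v_t, g_t = inputs  # each [batch, heads, dim]
        # Decay each key row of the state, then add the rank-1 update k_t^T v_t.
        state = state * jnp.exp(g_t)[..., :, None] + k_t[..., :, None] * v_t[..., None, :]
        # Read out o_t = (scale * q_t) @ S_t.
        o_t = jnp.einsum("bhk,bhkv->bhv", q_t * scale, state)
        return state, o_t

    # Move the time axis to the front so lax.scan iterates over it.
    xs = tuple(jnp.moveaxis(x.astype(jnp.float32), 2, 0) for x in (query, key, value, gk))
    final_state, outputs = jax.lax.scan(step, state, xs)
    outputs = jnp.moveaxis(outputs, 0, 2).astype(query.dtype)
    return outputs, (final_state if output_final_state else None)
```

With gk set to zero the gate never decays, and the output reduces to causal (unnormalized) linear attention; a very negative gk erases the state at every step, so each position only sees its own key and value. Both limits are handy sanity checks when comparing against an optimized kernel.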