Source code for easydel.layers.operations._base_operation

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Defines the base class for operations within EasyDeL that may have
backend-specific implementations (CPU, GPU, TPU).
"""

import functools
import typing as tp
from abc import ABC, abstractmethod

import jax
from eformer.loggings import get_logger

from easydel.utils.helpers import check_bool_flag

from ._operation_meta import OperationMetadata

logger = get_logger("EasyDeL-BaseOperation")


[docs]class BaseOperation(ABC): """ Abstract Base Class for defining operations with potential backend-specific implementations. This class provides a structure for defining a core operation (`forward_native`) and allowing optional, optimized implementations for different hardware backends supported by JAX (TPU, GPU - CUDA/ROCm, CPU). The `__call__` method acts as a dispatcher, detecting the current JAX default backend and executing the corresponding `forward_...` method. If a specific backend implementation (e.g., `forward_tpu`) is not overridden in a subclass, it defaults to calling `forward_native`. Subclasses MUST implement the `forward_native` method. They CAN optionally override `forward_tpu`, `forward_gpu`, `forward_cpu`, `forward_rocm`, or `forward_cuda` to provide backend-specific optimizations. """ metadata: OperationMetadata | None = None
[docs] @classmethod @abstractmethod def get_impl_name(cls) -> str | tuple[str, ...]: """ Returns the unique name(s) identifying this attention implementation. Used by the `OperationRegistry`. Can return a single string or a tuple/list of strings if the implementation has multiple aliases. Returns: A string or tuple/list of strings representing the implementation name(s). """
[docs] @abstractmethod def get_impl_metadata(self) -> OperationMetadata: """ Returns the `OperationMetadata` associated with this implementation instance. Returns: The `OperationMetadata` instance passed during initialization. """
[docs] def current_backend(self) -> tp.Literal["tpu", "gpu", "cpu"]: """ Returns the current JAX default backend as a lowercase string literal. Returns: "tpu", "gpu", or "cpu". """ return jax.default_backend()
[docs] @abstractmethod def forward_native(self, *args, **kwargs) -> tp.Any: """ The core, backend-agnostic implementation of the operation. This method MUST be implemented by any concrete subclass of `BaseOperation`. It serves as the default implementation if no backend-specific override is available or applicable. Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation. Type depends on the specific operation. """
[docs] def forward_tpu(self, *args, **kwargs) -> tp.Any: """ TPU-specific implementation of the operation. Defaults to calling `forward_native`. Subclasses can override this for TPU-specific optimizations. Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation, potentially optimized for TPU. """ return self.forward_native(*args, **kwargs)
[docs] def forward_tt(self, *args, **kwargs) -> tp.Any: """ TT-specific implementation of the operation. Defaults to calling `forward_native`. Subclasses can override this for TT-specific optimizations. Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation, potentially optimized for TT. """ return self.forward_native(*args, **kwargs)
[docs] def forward_cpu(self, *args, **kwargs) -> tp.Any: """ CPU-specific implementation of the operation. Defaults to calling `forward_native`. Subclasses can override this for CPU-specific optimizations (though often `forward_native` is sufficient). Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation, potentially optimized for CPU. """ return self.forward_native(*args, **kwargs)
[docs] def forward_gpu(self, *args, **kwargs) -> tp.Any: """ Generic GPU-specific implementation of the operation. Defaults to calling `forward_native`. This method serves as the base for CUDA and ROCm backends unless they are specifically overridden. Subclasses can override this for general GPU optimizations. Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation, potentially optimized for GPUs. """ return self.forward_native(*args, **kwargs)
[docs] def forward_rocm(self, *args, **kwargs) -> tp.Any: """ ROCm (AMD GPU)-specific implementation of the operation. Defaults to calling `forward_gpu`. Subclasses can override this for optimizations specific to the ROCm platform, if necessary. Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation, potentially optimized for ROCm GPUs. """ return self.forward_gpu(*args, **kwargs)
[docs] def forward_cuda(self, *args, **kwargs) -> tp.Any: """ CUDA (NVIDIA GPU)-specific implementation of the operation. Defaults to calling `forward_gpu`. Subclasses can override this for optimizations specific to the CUDA platform, if necessary. Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation, potentially optimized for CUDA GPUs. """ return self.forward_gpu(*args, **kwargs)
def __call__(self, *args, **kwargs) -> tp.Any: """ Executes the appropriate forward method based on the detected JAX backend. This method determines the current `jax.default_backend()` and dispatches the call to the corresponding `forward_...` method (e.g., `forward_tpu` if the backend is TPU). It logs which execution path is taken at the DEBUG level. If the backend is not explicitly recognized, it falls back to `forward_native`. Args: *args: Positional arguments to pass to the forward method. **kwargs: Keyword arguments to pass to the forward method. Returns: The result returned by the executed forward method. """ if check_bool_flag("FORCE_NATIVE_RUNTIME", False): return self.forward_native(*args, **kwargs) match self.metadata.backend: case self.EasyDeLBackends.TPU: logger.debug("Calling into TPU exec") return self.forward_tpu(*args, **kwargs) case self.EasyDeLBackends.GPU: logger.debug("Calling into GPU exec") return self.forward_gpu(*args, **kwargs) case self.EasyDeLBackends.TT: logger.debug("Calling into TT exec") return self.forward_tt(*args, **kwargs) case self.EasyDeLBackends.CPU: logger.debug("Calling into CPU exec") return self.forward_native(*args, **kwargs) case _: raise RuntimeError(f"unknown backend at OperationImpl! {self.metadata.backend}") @functools.cached_property def EasyDeLBackends(self): from easydel.infra.etils import EasyDeLBackends return EasyDeLBackends
# Type variable used by `OperationRegistry.register` so the decorator returns
# the exact subclass type it was given. The string passed to `TypeVar` must
# equal the name it is assigned to, otherwise static type checkers reject it.
_I = tp.TypeVar("_I", bound=BaseOperation)
class OperationRegistry:
    """
    Name-based registry of `BaseOperation` subclasses.

    Implementations are added with the `register` decorator and can later be
    looked up (`get`), instantiated (`create`) or enumerated
    (`list_implementations`) by their registered name.
    """

    # Shared, class-level mapping: registered name -> implementation class.
    _registry: tp.ClassVar[dict[str, type[BaseOperation]]] = {}

    @classmethod
    def register(cls, impl_cls: type[_I]) -> type[_I]:
        """
        Decorator that records a `BaseOperation` subclass in the registry.

        The class is stored under every name returned by its
        `get_impl_name()` class method. Registering a name that already
        exists logs a warning and replaces the previous entry.

        Example:
            ```python
            @OperationRegistry.register
            class FlashOperationImpl(BaseOperation):
                @classmethod
                def get_impl_name(cls) -> str:
                    return "flash"

                # ... implementation ...
            ```

        Args:
            impl_cls: The `BaseOperation` subclass to register.

        Returns:
            `impl_cls` itself, unchanged, so the method works as a decorator.
        """
        declared = impl_cls.get_impl_name()
        # A bare string is a single name; a list/tuple is a set of aliases.
        aliases = declared if isinstance(declared, list | tuple) else [declared]

        for impl_name in aliases:
            if impl_name in cls._registry:
                logger.warning(f"Operation implementation '{impl_name}' already registered. Overwriting.")
            cls._registry[impl_name] = impl_cls
            logger.debug(f"Registered attention implementation: {impl_name}")

        return impl_cls

    @classmethod
    def get(cls, impl_name: str) -> type[BaseOperation]:
        """
        Look up an implementation class by its registered name.

        Args:
            impl_name: The name the implementation was registered under.

        Returns:
            The `BaseOperation` subclass registered under `impl_name`.

        Raises:
            ValueError: If nothing is registered under `impl_name`; the
                message lists the names that are available.
        """
        if impl_name in cls._registry:
            return cls._registry[impl_name]

        available_impls = list(cls._registry.keys())
        raise ValueError(
            f"Operation implementation '{impl_name}' not found. Available implementations: {available_impls}"
        )

    @classmethod
    def create(cls, impl_name: str, metadata: OperationMetadata) -> BaseOperation:
        """
        Instantiate a registered implementation by name.

        Resolves the class via `get` and calls it with `metadata` as its only
        positional argument.

        Args:
            impl_name: The name of the implementation to instantiate.
            metadata: The `OperationMetadata` passed to the implementation's
                constructor.

        Returns:
            A new instance of the requested `BaseOperation` subclass.

        Raises:
            ValueError: If nothing is registered under `impl_name`.
        """
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> list[str]:
        """
        Return the names of all registered implementations.

        Returns:
            A new list with one entry per registered name, in registration order.
        """
        return list(cls._registry)