# Source code for easydel.layers.ops._base_operation

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Defines the base class for operations within EasyDeL that may have
backend-specific implementations (CPU, GPU, TPU).
"""

import typing as tp
from abc import ABC, abstractmethod

import jax

from easydel.infra.etils import EasyDeLBackends
from easydel.utils.helpers import get_logger

# Module-level logger; BaseOperation.__call__ uses it to report, at DEBUG level,
# which backend-specific execution path was taken for each call.
logger = get_logger("EasyDeL-BaseOperation")


[docs]class BaseOperation(ABC): """ Abstract Base Class for defining operations with potential backend-specific implementations. This class provides a structure for defining a core operation (`forward_native`) and allowing optional, optimized implementations for different hardware backends supported by JAX (TPU, GPU - CUDA/ROCm, CPU). The `__call__` method acts as a dispatcher, detecting the current JAX default backend and executing the corresponding `forward_...` method. If a specific backend implementation (e.g., `forward_tpu`) is not overridden in a subclass, it defaults to calling `forward_native`. Subclasses MUST implement the `forward_native` method. They CAN optionally override `forward_tpu`, `forward_gpu`, `forward_cpu`, `forward_rocm`, or `forward_cuda` to provide backend-specific optimizations. """
[docs] @abstractmethod def forward_native(self, *args, **kwargs) -> tp.Any: """ The core, backend-agnostic implementation of the operation. This method MUST be implemented by any concrete subclass of `BaseOperation`. It serves as the default implementation if no backend-specific override is available or applicable. Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation. Type depends on the specific operation. """
[docs] def forward_tpu(self, *args, **kwargs) -> tp.Any: """ TPU-specific implementation of the operation. Defaults to calling `forward_native`. Subclasses can override this for TPU-specific optimizations. Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation, potentially optimized for TPU. """ return self.forward_native(*args, **kwargs)
[docs] def forward_cpu(self, *args, **kwargs) -> tp.Any: """ CPU-specific implementation of the operation. Defaults to calling `forward_native`. Subclasses can override this for CPU-specific optimizations (though often `forward_native` is sufficient). Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation, potentially optimized for CPU. """ return self.forward_native(*args, **kwargs)
[docs] def forward_gpu(self, *args, **kwargs) -> tp.Any: """ Generic GPU-specific implementation of the operation. Defaults to calling `forward_native`. This method serves as the base for CUDA and ROCm backends unless they are specifically overridden. Subclasses can override this for general GPU optimizations. Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation, potentially optimized for GPUs. """ return self.forward_native(*args, **kwargs)
[docs] def forward_rocm(self, *args, **kwargs) -> tp.Any: """ ROCm (AMD GPU)-specific implementation of the operation. Defaults to calling `forward_gpu`. Subclasses can override this for optimizations specific to the ROCm platform, if necessary. Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation, potentially optimized for ROCm GPUs. """ return self.forward_gpu(*args, **kwargs)
[docs] def forward_cuda(self, *args, **kwargs) -> tp.Any: """ CUDA (NVIDIA GPU)-specific implementation of the operation. Defaults to calling `forward_gpu`. Subclasses can override this for optimizations specific to the CUDA platform, if necessary. Args: *args: Positional arguments for the operation. **kwargs: Keyword arguments for the operation. Returns: The result of the operation, potentially optimized for CUDA GPUs. """ return self.forward_gpu(*args, **kwargs)
def __call__(self, *args, **kwargs) -> tp.Any: """ Executes the appropriate forward method based on the detected JAX backend. This method determines the current `jax.default_backend()` and dispatches the call to the corresponding `forward_...` method (e.g., `forward_tpu` if the backend is TPU). It logs which execution path is taken at the DEBUG level. If the backend is not explicitly recognized, it falls back to `forward_native`. Args: *args: Positional arguments to pass to the forward method. **kwargs: Keyword arguments to pass to the forward method. Returns: The result returned by the executed forward method. """ backend = jax.default_backend() match backend: case EasyDeLBackends.TPU: logger.debug("Calling into TPU execution path.") return self.forward_tpu(*args, **kwargs) case EasyDeLBackends.GPU: # Note: JAX identifies GPU generically. We rely on subclasses # potentially overriding forward_cuda/forward_rocm if needed, # but the primary dispatch here is to forward_gpu. logger.debug("Calling into GPU execution path.") return self.forward_gpu(*args, **kwargs) case EasyDeLBackends.CPU: logger.debug("Calling into CPU execution path.") # CPU often uses the native implementation directly return self.forward_cpu(*args, **kwargs) case _: # Fallback for unknown or non-standard backends logger.debug( f"Calling into Native execution path (Unknown or default backend: {backend})." ) return self.forward_native(*args, **kwargs)