Source code for easydel.utils.compiling_utils

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Compilation utilities for JAX function optimization.

Provides enhanced JIT compilation with persistent caching to disk,
reducing compilation overhead across script runs.

Functions:
    ejit: Enhanced JIT with persistent caching
    save_compiled_fn: Save compiled function to disk
    load_compiled_fn: Load compiled function from disk
    load_cached_functions: Load multiple cached functions
    smart_compile: Smart compilation with auto-caching
    hash_fn: Generate a hash from an object's attribute values

Constants:
    RECOMPILE_FORCE: Force recompilation flag
    ECACHE_COMPILES: Enable compilation caching
    CACHE_DIR: Cache directory path
    COMPILE_FUNC_DIR: Compiled functions directory
    COMPILED_CACHE: In-memory cache of compiled functions

Key Features:
    - Persistent disk caching of compiled functions
    - Automatic cache invalidation on changes
    - Hardware-specific signatures
    - Two-level caching (memory + disk)
    - Graceful fallback on errors

Example:
    >>> from easydel.utils.compiling_utils import ejit
    >>>
    >>> @ejit
    ... def optimized_fn(x, y):
    ...     return x @ y + x.T @ y.T
    >>>
    >>> # First call compiles and caches
    >>> result = optimized_fn(a, b)
    >>> # Next run loads from cache
    >>> result = optimized_fn(a, b)
"""

from __future__ import annotations

import hashlib
import os
import pickle
import typing as tp
import warnings
from functools import wraps
from pathlib import Path

import jax
import numpy as np
from ejkernel.callib import ejit
from jax._src.interpreters import pxla
from jax.experimental.serialize_executable import deserialize_and_load, serialize

from .helpers import check_bool_flag, get_cache_dir

if tp.TYPE_CHECKING:
    from jax._src.stages import Compiled, Lowered

# Re-bind the imported ``ejit`` so it is importable from this module as well.
ejit = ejit
P = tp.ParamSpec("P")
R = tp.TypeVar("R")

# RECOMPILE_FORCE makes ``smart_compile`` ignore executables already on disk;
# ECACHE_COMPILES gates whether freshly compiled executables are written to disk.
RECOMPILE_FORCE = check_bool_flag("EASYDEL_RECOMPILE_FORCE", False)
ECACHE_COMPILES = check_bool_flag("EASYDEL_CACHE_COMPILES", False)

CACHE_DIR = get_cache_dir()

# An override taken from the environment is a plain string. The rest of this
# module uses ``COMPILE_FUNC_DIR / name`` and ``COMPILE_FUNC_DIR.exists()``,
# which a ``str`` does not support, so the override is normalised to a ``Path``.
# As before, only the default location is created eagerly; an overridden
# location is created on demand when something is saved into it.
_compile_func_dir_override = os.getenv("COMPILE_FUNC_DIR")
if _compile_func_dir_override:
    COMPILE_FUNC_DIR = Path(_compile_func_dir_override)
else:
    COMPILE_FUNC_DIR = CACHE_DIR / "ejit_compiled_functions"
    COMPILE_FUNC_DIR.mkdir(parents=True, exist_ok=True)

# File names used inside each per-function cache directory.
COMPILED_FILE_NAME = "compiled.executable"
SIGNATURE_FILE_NAME = "compiled.signature"
# In-memory cache of loaded executables, keyed by cache-directory name.
COMPILED_CACHE: dict[str, Compiled] = {}


def _get_hardware_signature() -> str:
    """Describe the JAX devices visible to this process as a string.

    Returns:
        The ``str()`` of the list returned by ``jax.devices()``.
    """
    visible_devices = jax.devices()
    return str(visible_devices)


def _get_leaf_signature(leaf: tp.Any) -> tp.Hashable:
    """Build a hashable description of a single PyTree leaf.

    Args:
        leaf: One leaf taken from a flattened PyTree.

    Returns:
        For JAX or NumPy arrays, a tuple of ``(shape, canonical dtype name)``,
        extended with ``repr(leaf.sharding)`` when the leaf exposes a
        ``sharding`` attribute. For anything else, the leaf's type.
    """
    # Non-array leaves are identified by their type only, not by their value.
    if not isinstance(leaf, (jax.Array, np.ndarray)):
        return type(leaf)

    array_part = (leaf.shape, str(jax.dtypes.canonicalize_dtype(leaf.dtype)))
    if hasattr(leaf, "sharding"):
        return (*array_part, repr(leaf.sharding))
    return array_part


def _get_args_signature(args: tuple, kwargs: dict) -> str:
    """Summarise a call's arguments as a string.

    ``(args, kwargs)`` is flattened as a single PyTree; the result combines
    the tree structure with the per-leaf signature of every leaf.

    Args:
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        ``str()`` of the pair ``(tree structure, tuple of leaf signatures)``.
    """
    leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))
    per_leaf = tuple(_get_leaf_signature(leaf) for leaf in leaves)
    return str((treedef, per_leaf))


def load_cached_functions(verbose: bool = True) -> None:
    """Pre-load every readable cached executable from disk into ``COMPILED_CACHE``.

    Each sub-directory of ``COMPILE_FUNC_DIR`` is treated as one cache entry
    whose key is the directory name. An entry is loaded when it contains a
    ``COMPILED_FILE_NAME`` file that unpickles and deserializes successfully;
    entries that fail are skipped.

    Args:
        verbose: When true, warn about entries that could not be loaded and
            print a summary line if at least one entry was loaded.
    """
    # COMPILE_FUNC_DIR is a plain string when it was taken from the
    # environment, so coerce it before using path methods.
    cache_root = Path(COMPILE_FUNC_DIR)
    if not cache_root.is_dir():
        return

    loaded_count = 0
    for cache_key_dir in cache_root.iterdir():
        if not cache_key_dir.is_dir():
            continue
        filepath = cache_key_dir / COMPILED_FILE_NAME
        if not filepath.exists():
            continue

        cache_key = cache_key_dir.name
        try:
            with open(filepath, "rb") as f:
                # NOTE: unpickling runs arbitrary code; the cache directory
                # must only contain files written by a trusted process.
                serialized, in_tree, out_tree = pickle.load(f)
            compiled_func = deserialize_and_load(serialized, in_tree, out_tree)
        except Exception as e:
            # Best effort: one unreadable entry must not block the others.
            if verbose:
                warnings.warn(f"Could not pre-load ejit cache for key {cache_key}. Error: {e}", stacklevel=2)
            continue

        COMPILED_CACHE[cache_key] = compiled_func
        loaded_count += 1

    if verbose and loaded_count > 0:
        print(f"Pre-loaded {loaded_count} functions into ejit's persistent memory cache.")
def save_compiled_fn(path: str | os.PathLike, fn: Compiled, prefix: str | None = None):
    """Save a compiled JAX function to disk for later reuse.

    Serializes a compiled function along with its input/output tree
    structures and pickles the three pieces into a single file.

    Args:
        path: Directory path where the compiled function will be saved.
            Will be created if it doesn't exist.
        fn: Compiled JAX function (output of ``lowered.compile()``).
        prefix: Optional prefix for the filename. Useful for organizing
            multiple compiled functions in the same directory.

    Files Created:
        - ``{prefix}-compiled.executable`` when ``prefix`` is given,
          otherwise ``compiled.executable``.

    Example:
        >>> # Compile a function
        >>> jitted = jax.jit(my_function)
        >>> lowered = jitted.lower(sample_input)
        >>> compiled = lowered.compile()
        >>>
        >>> # Save to disk
        >>> from pathlib import Path
        >>> cache_dir = Path("./my_cache")
        >>> save_compiled_fn(cache_dir, compiled, prefix="model_v1")
        >>>
        >>> # File created: ./my_cache/model_v1-compiled.executable

    Warns:
        UserWarning: If serializing the function or writing the file fails.
            The failure is reported as a warning, not raised.

    Notes:
        - Uses pickle for serialization (standard security caveats apply)
    """
    from pathlib import Path

    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    filename = path / (f"{prefix}-{COMPILED_FILE_NAME}" if prefix else COMPILED_FILE_NAME)
    try:
        # Serialize before opening the file so that a serialization failure
        # is reported as a warning and does not leave an empty file behind.
        serialized, in_tree, out_tree = serialize(fn)
        with open(filename, "wb") as f:
            pickle.dump((serialized, in_tree, out_tree), f)
    except Exception as e:
        warnings.warn(f"Could not save compiled function to {filename}: {e}", stacklevel=2)
def load_compiled_fn(path: str | os.PathLike, prefix: str | None = None):
    """Load a compiled function from disk.

    Uses the same file naming as ``save_compiled_fn``:
    ``{prefix}-compiled.executable`` when ``prefix`` is given, otherwise
    ``compiled.executable``.

    Args:
        path: Directory containing the saved file. May be a string or any
            path-like object.
        prefix: Optional filename prefix that was used when saving.

    Returns:
        The result of ``deserialize_and_load`` for the stored payload.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).

    Notes:
        The file is read with pickle, which can execute arbitrary code;
        only load files written by a trusted process.
    """
    from pathlib import Path

    # Only add the "-" separator when there is a prefix, so the name matches
    # what ``save_compiled_fn`` writes for prefix-less saves.
    filename = Path(path) / (f"{prefix}-{COMPILED_FILE_NAME}" if prefix else COMPILED_FILE_NAME)
    with open(filename, "rb") as f:
        serialized, in_tree, out_tree = pickle.load(f)
    return deserialize_and_load(
        serialized=serialized,
        in_tree=in_tree,
        out_tree=out_tree,
    )
def hash_fn(self) -> int:
    """Hash an object from the values stored in its ``__dict__``.

    Only values that are instances of ``float``, ``int``, ``bool``, ``dict``
    or ``list`` contribute. Their ``str()`` forms are concatenated in
    ``__dict__`` order and the result is passed to ``get_safe_hash_int``.

    Args:
        self: The object whose attribute values are hashed.

    Returns:
        The integer hash of the concatenated attribute values.
    """
    contributing_types = (float, int, bool, dict, list)
    pieces = [str(value) for value in self.__dict__.values() if isinstance(value, contributing_types)]
    return get_safe_hash_int("".join(pieces))
def get_safe_hash_int(text, algorithm="md5"):
    """Hash ``str(text)`` with a ``hashlib`` algorithm and return it as an integer.

    Args:
        text: Value to hash; it is converted with ``str()`` and encoded.
        algorithm: Name of the ``hashlib`` constructor to use.

    Returns:
        The digest interpreted as a big-endian integer.

    Raises:
        ValueError: If an ``AttributeError`` occurs, e.g. ``hashlib`` has no
            attribute named ``algorithm``.
        Exception: If any other error occurs while hashing.
    """
    try:
        as_text = str(text)
        make_hasher = getattr(hashlib, algorithm)
        digest = make_hasher(as_text.encode()).digest()
        return int.from_bytes(digest, byteorder="big")
    except AttributeError as e:
        raise ValueError(f"Unsupported hash algorithm: {algorithm}") from e
    except Exception as e:
        raise Exception(f"Error generating hash: {e!s}") from e
def get_hash_of_lowering(lowered_func: Lowered):
    """Return the SHA-256 hex digest of a lowered function's text form.

    Args:
        lowered_func: Lowered function; its ``as_text()`` output is hashed.

    Returns:
        Hexadecimal SHA-256 digest of the UTF-8 encoded text.
    """
    encoded_text = lowered_func.as_text().encode("utf-8")
    return hashlib.sha256(encoded_text).hexdigest()
def smart_compile(
    lowered_func: Lowered,
    tag: str | None = None,
    verbose: bool = True,
    cache_key: tuple[str, tuple] | None = None,
) -> tuple[Compiled, tuple[str, tuple] | None]:
    """Compile a lowered JAX function, reusing an on-disk executable when possible.

    The cache entry is a directory under ``COMPILE_FUNC_DIR`` named after the
    SHA-256 of the lowered text (prefixed with ``tag`` when given). If that
    entry holds an executable and ``RECOMPILE_FORCE`` is off, the executable
    and its stored signature are loaded. Otherwise the function is compiled
    and, when ``ECACHE_COMPILES`` is on, written to the entry together with
    ``cache_key``.

    Args:
        lowered_func: The lowered function to compile.
        tag: Optional label added to the cache directory name and to
            warning messages.
        verbose: When true, emit warnings if loading or saving fails.
        cache_key: Signature stored next to a newly saved executable and
            returned when nothing is loaded from disk.

    Returns:
        A ``(compiled_function, signature)`` pair. ``signature`` is the value
        read from disk when one was read, otherwise ``cache_key``.

    Notes:
        Cache files are read with pickle, which can execute arbitrary code;
        the cache directory must only contain trusted files. Load and save
        failures are best effort: they produce warnings, never exceptions.
    """
    func_hash = get_hash_of_lowering(lowered_func)
    foldername = str(func_hash) if tag is None else f"{tag}-{func_hash}"
    # COMPILE_FUNC_DIR is a plain string when it was taken from the
    # environment, so coerce it before joining.
    func_dir = Path(COMPILE_FUNC_DIR) / foldername
    filepath = func_dir / COMPILED_FILE_NAME
    signature_filepath = func_dir / SIGNATURE_FILE_NAME
    post_fix = f" (TAG : {tag})" if tag else ""
    signature = cache_key

    if filepath.exists() and not RECOMPILE_FORCE:
        try:
            with open(filepath, "rb") as f:
                serialized, in_tree, out_tree = pickle.load(f)
            with open(signature_filepath, "rb") as f:
                signature = pickle.load(f)
            compiled_func = deserialize_and_load(
                serialized=serialized,
                in_tree=in_tree,
                out_tree=out_tree,
            )
            return compiled_func, signature
        except Exception as e:
            # Fall through to a fresh compile when the cache entry is unusable.
            if verbose:
                warnings.warn(
                    f"couldn't load compiled function due to {e}" + post_fix,
                    stacklevel=4,
                )

    compiled_func: Compiled = lowered_func.compile()
    if ECACHE_COMPILES:
        try:
            serialized, in_tree, out_tree = serialize(compiled_func)
            func_dir.mkdir(parents=True, exist_ok=True)
            with open(filepath, "wb") as f:
                pickle.dump((serialized, in_tree, out_tree), f)
            with open(signature_filepath, "wb") as f:
                pickle.dump(cache_key, f)
        except Exception as e:
            # Caching is optional; the compiled function is still returned.
            if verbose:
                warnings.warn(
                    f"couldn't save and serialize compiled function due to {e}" + post_fix,
                    stacklevel=4,
                )
    return compiled_func, signature
class NoCompileContext:
    """Context manager that raises if JAX lowers something new inside the block.

    While active, ``pxla._cached_lowering_to_hlo`` is replaced by a wrapper
    that compares the cache-miss counter before and after every call. A higher
    counter after the call is treated as a new compilation and raises
    ``RuntimeError``. Intended for hot paths that should only hit cached
    executables.
    """

    def __init__(self, message: str = "JAX attempted to compile a new executable inside ForbidCompile."):
        """Remember the failure message; nothing is patched until ``__enter__``.

        Args:
            message: Text of the ``RuntimeError`` raised on a cache miss.
        """
        self._original_func = None
        self.message = message

    def __enter__(self):
        """Install the miss-detecting wrapper around JAX's cached lowering."""
        cached_lowering = pxla._cached_lowering_to_hlo
        # Kept so that ``__exit__`` can put the real function back.
        self._original_func = cached_lowering

        @wraps(cached_lowering)
        def guarded(*args, **kwargs):
            misses_at_entry = cached_lowering.cache_info().misses
            outcome = cached_lowering(*args, **kwargs)
            # The check runs after the call, so the lowering itself has
            # already happened by the time the error is raised.
            if cached_lowering.cache_info().misses > misses_at_entry:
                raise RuntimeError(self.message)
            return outcome

        pxla._cached_lowering_to_hlo = guarded

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the original lowering function; never suppress exceptions."""
        if self._original_func:
            pxla._cached_lowering_to_hlo = self._original_func
        return False