# Source code for easydel.utils.compiling_utils

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import functools
import hashlib
import os
import pickle
import typing as tp
import warnings

import jax
from jax.experimental.serialize_executable import deserialize_and_load, serialize

from .helpers import check_bool_flag, get_cache_dir

if tp.TYPE_CHECKING:
	from jax._src.stages import Compiled, Lowered
else:
	Compiled, Lowered = tp.Any, tp.Any

# Behaviour flags resolved once at import time through `check_bool_flag`
# (presumably backed by environment variables — confirm in `.helpers`).
# RECOMPILE_FORCE: when true, `smart_compile` / `cache_compiles` skip any
# executable already stored on disk and compile again.
RECOMPILE_FORCE = check_bool_flag("RECOMPILE_FORCE", False)
# ECACHE_COMPILES: when true, `smart_compile` writes freshly compiled
# executables to disk.
ECACHE_COMPILES = check_bool_flag("ECACHE_COMPILES", True)

# On-disk layout: <cache dir>/compiled_funcs/<folder per function>/compiled.func
CACHE_DIR = get_cache_dir()
COMPILE_FUNC_DIR = CACHE_DIR / "compiled_funcs"
# NOTE: import-time side effect — the directory is created as soon as this
# module is imported.
COMPILE_FUNC_DIR.mkdir(parents=True, exist_ok=True)
COMPILED_FILE_NAME = "compiled.func"

# Process-wide in-memory cache of compiled executables, shared by `cjit` and
# `cache_compiles`; keys are tuples built by those wrappers.
COMPILED_CACHE: tp.Dict[tp.Tuple, tp.Any] = {}


def is_jit_wrapped(fn):
    """Return True when ``fn`` exposes the attributes of a ``jax.jit`` wrapper.

    The check is purely structural: ``fn`` qualifies if it has all of the
    ``_fun``, ``lower``, ``eval_shape`` and ``trace`` attributes.
    """
    required_attributes = ("_fun", "lower", "eval_shape", "trace")
    for attribute in required_attributes:
        if not hasattr(fn, attribute):
            return False
    return True
def cjit(
    fn: tp.Callable,
    static_argnums: tp.Optional[tp.Tuple[int, ...]] = None,
    static_argnames: tp.Optional[tp.Tuple[str, ...]] = None,
):
    """Wrap an already-jitted function with an ahead-of-time compile cache.

    On the first call for a given argument signature the function is lowered,
    compiled through ``smart_compile`` (which may reuse an on-disk executable)
    and stored in ``COMPILED_CACHE``. Later calls with the same signature and
    the same static argument values invoke the stored executable directly.

    Args:
        fn: A function that has already been wrapped with ``jax.jit``.
        static_argnums: Positional indices that ``fn`` treats as static.
            A single int is accepted as shorthand for a one-element tuple.
        static_argnames: Keyword names that ``fn`` treats as static.
            A single str is accepted as shorthand for a one-element tuple.

    Returns:
        A wrapper with the same call signature as ``fn``.

    Raises:
        AssertionError: If ``fn`` does not look like a jit-wrapped function.
    """
    assert is_jit_wrapped(fn=fn), "function should be jit wrapped already"

    # Accept the scalar shorthands; iterating a bare str would otherwise
    # treat every character as a separate argument name.
    if isinstance(static_argnums, int):
        static_argnums = (static_argnums,)
    if isinstance(static_argnames, str):
        static_argnames = (static_argnames,)

    # Call-invariant, so computed once instead of on every invocation.
    static_arg_indices = frozenset(static_argnums or ())
    static_arg_names = frozenset(static_argnames or ())

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        dynamic_args = tuple(
            arg for i, arg in enumerate(args) if i not in static_arg_indices
        )
        dynamic_kwargs = {
            key: value
            for key, value in kwargs.items()
            if key not in static_arg_names
        }

        # Static values select *which* program gets compiled, so they must be
        # part of the key; otherwise a call with a different static value
        # would silently reuse the wrong executable. Static values have to be
        # hashable (the jitted function needs that for static arguments too).
        static_values = (
            tuple((i, arg) for i, arg in enumerate(args) if i in static_arg_indices),
            tuple(
                (key, kwargs[key]) for key in sorted(static_arg_names) if key in kwargs
            ),
        )
        signature = get_signature(dynamic_args, dynamic_kwargs)
        cache_key = (fn, signature, static_values)

        compiled_func = COMPILED_CACHE.get(cache_key)
        if compiled_func is None:
            # Lowering needs the full argument list (static ones included).
            lowered_func: Lowered = fn.lower(*args, **kwargs)
            compiled_func = smart_compile(lowered_func, "cached-jit")
            COMPILED_CACHE[cache_key] = compiled_func

        # The compiled executable is called with the dynamic arguments only.
        return compiled_func(*dynamic_args, **dynamic_kwargs)

    return wrapped
def hash_fn(self) -> int:
    """Derive an integer hash from the simple-valued attributes of ``self``.

    Only instance attributes whose values are ``float``, ``int``, ``bool``,
    ``dict`` or ``list`` contribute; their ``str()`` forms are concatenated in
    ``__dict__`` order and passed to ``get_safe_hash_int``.
    """
    hashed_kinds = (float, int, bool, dict, list)
    pieces = [
        str(value) for value in self.__dict__.values() if isinstance(value, hashed_kinds)
    ]
    return get_safe_hash_int("".join(pieces))
# @functools.lru_cache(maxsize=2048)
def get_safe_hash_int(text, algorithm="md5"):
    """Hash ``str(text)`` with a ``hashlib`` algorithm and return it as an int.

    Args:
        text: Any object; its ``str()`` form is what gets hashed.
        algorithm: Name of a ``hashlib`` attribute such as ``"md5"`` or
            ``"sha256"``.

    Returns:
        The digest interpreted as a big-endian unsigned integer.

    Raises:
        ValueError: If an ``AttributeError`` occurs, typically because
            ``hashlib`` has no attribute named ``algorithm``.
        Exception: For any other failure while hashing.
    """
    try:
        payload = str(text).encode()
        hasher_factory = getattr(hashlib, algorithm)
        digest = hasher_factory(payload).digest()
        return int.from_bytes(digest, byteorder="big")
    except AttributeError as e:
        raise ValueError(f"Unsupported hash algorithm: {algorithm}") from e
    except Exception as e:
        raise Exception(f"Error generating hash: {str(e)}") from e
def get_signature(args, kwargs) -> tp.Tuple:
    """Build a hashable key describing the arguments of a call.

    Values exposing both ``shape`` and ``dtype`` are described by
    ``(tuple(shape), str(dtype))``; every other value is described only by
    ``str(type(value))``, so its contents do not influence the key.

    Args:
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        ``(args_signature, kwargs_signature)`` where the keyword part is a
        tuple of ``(name, description)`` pairs taken from the sorted items.
    """

    def describe(value):
        looks_like_array = hasattr(value, "shape") and hasattr(value, "dtype")
        if not looks_like_array:
            return str(type(value))
        return (tuple(value.shape), str(value.dtype))

    positional = tuple(map(describe, args))
    keyword = tuple(
        (name, describe(value)) for name, value in sorted(kwargs.items())
    )
    return (positional, keyword)
def get_hash_of_lowering(lowered_func: Lowered):
    """Return the SHA-256 hex digest of a lowered function's text form.

    Args:
        lowered_func: Object whose ``as_text()`` returns the lowering as a
            string.

    Returns:
        The hexadecimal SHA-256 digest of that text encoded as UTF-8.
    """
    lowering_bytes = lowered_func.as_text().encode("utf-8")
    return hashlib.sha256(lowering_bytes).hexdigest()
def smart_compile(
    lowered_func: Lowered,
    tag: tp.Optional[str] = None,
) -> Compiled:
    """Compile a lowered function, reusing an on-disk executable when possible.

    The cache folder is named after the SHA-256 of the lowering text
    (optionally prefixed with ``tag``). If a serialized executable exists
    there and ``RECOMPILE_FORCE`` is off, it is deserialized and returned.
    Otherwise the function is compiled and, when ``ECACHE_COMPILES`` is on,
    the result is serialized to disk on a best-effort basis.

    Args:
        lowered_func: The lowered JAX function to compile.
        tag: Optional label used in the cache folder name and in warnings.

    Returns:
        The compiled executable (freshly compiled or loaded from disk).
    """
    func_hash = get_hash_of_lowering(lowered_func)
    foldername = str(func_hash) if tag is None else f"{tag}-{func_hash}"
    func_dir = COMPILE_FUNC_DIR / foldername
    filepath = func_dir / COMPILED_FILE_NAME
    post_fix = f" (TAG : {tag})" if tag else ""

    if filepath.exists() and not RECOMPILE_FORCE:
        try:
            # NOTE(security): unpickling executes arbitrary code; the cache
            # directory must only ever contain files written by this module.
            with open(filepath, "rb") as cache_file:
                serialized, in_tree, out_tree = pickle.load(cache_file)
            return deserialize_and_load(
                serialized=serialized,
                in_tree=in_tree,
                out_tree=out_tree,
            )
        except Exception as e:
            # A stale or corrupt cache entry must not be fatal: fall through
            # to a normal compile, which also rewrites the entry.
            warnings.warn(
                f"couldn't load compiled function due to {e}" + post_fix,
                stacklevel=4,
            )

    compiled_func: Compiled = lowered_func.compile()
    if ECACHE_COMPILES:
        # Persisting is best-effort on every path: a failure to serialize or
        # write only costs a recompile next time, so it is reported as a
        # warning rather than raised.
        try:
            serialized, in_tree, out_tree = serialize(compiled_func)
            func_dir.mkdir(parents=True, exist_ok=True)
            with open(filepath, "wb") as cache_file:
                pickle.dump((serialized, in_tree, out_tree), cache_file)
        except Exception as e:
            warnings.warn(
                f"couldn't save and serialize compiled function due to {e}" + post_fix,
                stacklevel=4,
            )
    return compiled_func
def save_compiled_fn(
    path: tp.Union[str, os.PathLike],
    fn: Compiled,
    prefix: tp.Optional[str] = None,
):
    """Serialize a compiled function into ``path``.

    The file is written as ``<prefix>-compiled.func`` inside ``path`` (with an
    empty prefix when none is given), matching what ``load_compiled_fn``
    reads back.

    Args:
        path: Target directory, as a ``str`` or any ``os.PathLike``. It is
            created, including parents, if missing.
        fn: The compiled JAX executable to serialize.
        prefix: Optional file-name prefix.

    Note:
        A failure while writing the file is reported as a warning instead of
        being raised; a failure in ``serialize`` itself propagates.
    """
    # os-level helpers accept both str and PathLike, unlike `path.mkdir` /
    # `path / name`, which only work on pathlib objects.
    os.makedirs(path, exist_ok=True)
    filename = os.path.join(path, (prefix or "") + "-" + COMPILED_FILE_NAME)
    serialized, in_tree, out_tree = serialize(fn)
    try:
        with open(filename, "wb") as out_file:
            pickle.dump((serialized, in_tree, out_tree), out_file)
    except Exception as e:  # noqa
        warnings.warn(f"couldn't save compiled function due to {e}", stacklevel=4)
def load_compiled_fn(
    path: tp.Union[str, os.PathLike],
    prefix: tp.Optional[str] = None,
):
    """Load a compiled function previously written by ``save_compiled_fn``.

    Args:
        path: Directory holding the file, as a ``str`` or any ``os.PathLike``.
        prefix: The prefix that was used when saving (empty when omitted).

    Returns:
        The executable rebuilt by ``deserialize_and_load``.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    # os.path.join accepts both str and PathLike, unlike `path / name`.
    filename = os.path.join(path, (prefix or "") + "-" + COMPILED_FILE_NAME)
    # NOTE(security): unpickling executes arbitrary code; only load files
    # that come from a trusted location.
    with open(filename, "rb") as in_file:
        serialized, in_tree, out_tree = pickle.load(in_file)
    return deserialize_and_load(
        serialized=serialized,
        in_tree=in_tree,
        out_tree=out_tree,
    )
def cache_compiles(
    tag: tp.Optional[str] = None,
    static_argnames: tp.Optional[tp.List[str]] = None,
):
    """Decorator factory that caches compiled executables in memory and on disk.

    For a function exposing ``lower`` (e.g. one wrapped by ``jax.jit``), each
    distinct call signature is lowered and compiled once. The executable is
    kept in ``COMPILED_CACHE`` and serialized under ``COMPILE_FUNC_DIR`` so a
    later process can load it instead of compiling. Functions without
    ``lower`` are called directly, with no caching.

    Args:
        tag: Optional label prepended to the on-disk cache folder name.
        static_argnames: Keyword names the wrapped function treats as static.
            They are passed to ``lower`` but removed before the compiled
            executable is called. A single str is accepted as shorthand.

    Returns:
        A decorator producing the caching wrapper.
    """
    if isinstance(static_argnames, str):
        static_argnames = [static_argnames]
    static_names = tuple(static_argnames or ())

    def create_wrapper(func: tp.Callable, tag: tp.Optional[str] = None) -> tp.Callable:
        original_func = getattr(func, "_fun", func)
        # Bytecode-derived id: stable across processes, which keeps the
        # on-disk folder names reproducible. It is NOT unique (functions that
        # differ only in constants or closures share bytecode), so the
        # in-memory key below also includes the function object itself.
        func_id = str(
            hashlib.sha256(
                original_func.__code__.co_code,
            )
            .hexdigest()
            .encode("utf-8")
        )

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not hasattr(func, "lower"):
                # Nothing to lower/compile: plain passthrough.
                return func(*args, **kwargs)

            signature = (func_id, get_signature(args, kwargs))
            # Static values decide which program is compiled, so the
            # in-memory key has to distinguish them (they must be hashable).
            static_values = tuple(
                (name, kwargs[name]) for name in static_names if name in kwargs
            )
            cache_key = (func, signature, static_values)
            # A static argument left at its default is simply absent from
            # kwargs; filtering avoids the KeyError a bare pop() would raise.
            dynamic_kwargs = {
                name: value
                for name, value in kwargs.items()
                if name not in static_names
            }

            compiled_func = COMPILED_CACHE.get(cache_key)
            if compiled_func is not None:
                return compiled_func(*args, **dynamic_kwargs)

            lowered = func.lower(*args, **kwargs)
            func_hash = get_hash_of_lowering(lowered)
            sig_hash = hashlib.sha256(str(signature).encode()).hexdigest()[:8]
            foldername = (
                f"{tag}-{func_hash}-{sig_hash}" if tag else f"{func_hash}-{sig_hash}"
            )
            func_dir = COMPILE_FUNC_DIR / foldername
            filepath = func_dir / "compiled.func"

            if filepath.exists() and not RECOMPILE_FORCE:
                try:
                    # NOTE(security): unpickling executes arbitrary code; the
                    # cache directory must only hold files written here.
                    with open(filepath, "rb") as f:
                        serialized, in_tree, out_tree = pickle.load(f)
                    compiled_func = deserialize_and_load(
                        serialized=serialized,
                        in_tree=in_tree,
                        out_tree=out_tree,
                    )
                except Exception as e:
                    # Stale/corrupt entry: recompile instead of crashing.
                    compiled_func = None
                    warnings.warn(
                        f"couldn't load compiled function due to {e}",
                        stacklevel=2,
                    )

            if compiled_func is None:
                compiled_func = lowered.compile()
                try:
                    serialized, in_tree, out_tree = serialize(compiled_func)
                    func_dir.mkdir(parents=True, exist_ok=True)
                    with open(filepath, "wb") as f:
                        pickle.dump((serialized, in_tree, out_tree), f)
                except Exception as e:
                    # Persisting is best-effort; report it the same way the
                    # rest of this module does.
                    warnings.warn(f"Failed to cache compilation: {e}", stacklevel=2)

            COMPILED_CACHE[cache_key] = compiled_func
            return compiled_func(*args, **dynamic_kwargs)

        wrapper._COMPILED_CACHE = COMPILED_CACHE
        return wrapper

    def decorator(func: tp.Callable) -> tp.Callable:
        return create_wrapper(func, tag)

    return decorator
def lower_function(
    func,
    func_input_args,
    func_input_kwargs,
    mesh=None,
    in_shardings=None,
    out_shardings=None,
    static_argnums=None,
    donate_argnums=None,
):
    """
    Lower a JAX function with optional sharding and mesh configuration.

    The function is wrapped with ``jax.jit`` and lowered for the given example
    inputs. When ``mesh`` is provided, both steps run inside ``with mesh:``.

    Args:
        func: The JAX function to lower.
        func_input_args: Positional arguments passed to ``lower``.
        func_input_kwargs: Keyword arguments passed to ``lower``.
        mesh: Optional context manager (e.g. a JAX mesh) to lower under.
        in_shardings: Optional input sharding specifications.
        out_shardings: Optional output sharding specifications.
        static_argnums: Indices of static arguments.
        donate_argnums: Indices of arguments to donate.

    Returns:
        The lowered JAX function.
    """

    def _jit_and_lower():
        jitted = jax.jit(
            func,
            in_shardings=in_shardings,
            out_shardings=out_shardings,
            static_argnums=static_argnums,
            donate_argnums=donate_argnums,
        )
        return jitted.lower(*func_input_args, **func_input_kwargs)

    if mesh is not None:
        with mesh:
            return _jit_and_lower()
    return _jit_and_lower()
def compile_function(
    func,
    func_input_args,
    func_input_kwargs,
    mesh=None,
    in_shardings=None,
    out_shardings=None,
    static_argnums=None,
    donate_argnums=None,
):
    """
    Compile a JAX function with optional sharding and mesh configuration.

    Delegates lowering to ``lower_function`` with the same arguments and then
    calls ``compile()`` on the lowered result.

    Args:
        func: The JAX function to compile.
        func_input_args: Input arguments for the function.
        func_input_kwargs: Input keyword arguments for the function.
        mesh: Optional JAX mesh for distributed execution.
        in_shardings: Optional input sharding specifications.
        out_shardings: Optional output sharding specifications.
        static_argnums: Indices of static arguments.
        donate_argnums: Indices of arguments to donate.

    Returns:
        The compiled JAX function.
    """
    lowering_options = {
        "mesh": mesh,
        "in_shardings": in_shardings,
        "out_shardings": out_shardings,
        "static_argnums": static_argnums,
        "donate_argnums": donate_argnums,
    }
    lowered = lower_function(
        func,
        func_input_args,
        func_input_kwargs,
        **lowering_options,
    )
    return lowered.compile()
if __name__ == "__main__":
    # Small manual demo of `cjit`: three calls sharing one argument signature.
    jnp = jax.numpy

    @cjit
    @jax.jit
    def my_function(x, y):
        return x * y + x

    first_x = jnp.array([1, 2, 3], dtype=jnp.float32)
    first_y = jnp.array([4, 5, 6], dtype=jnp.float32)

    # The first call lowers, compiles and caches the executable ...
    result1 = my_function(first_x, first_y)
    # ... the second call finds that executable in the cache and runs it.
    result2 = my_function(first_x, first_y)

    # New values with the same shapes/dtypes produce the same signature.
    second_x = jnp.array([1, 2, 3], dtype=jnp.float32)
    second_y = jnp.array([1, 1, 1], dtype=jnp.float32)
    result3 = my_function(second_x, second_y)

    print(result1, result2, result3)