easydel.utils.compiling_utils

Compilation utilities for JAX function optimization.

Provides enhanced JIT compilation with persistent caching to disk, reducing compilation overhead across script runs.

Functions:

  • ejit: Enhanced JIT with persistent caching

  • save_compiled_fn: Save a compiled function to disk

  • load_compiled_fn: Load a compiled function from disk

  • load_cached_functions: Load multiple cached functions

  • smart_compile: Smart compilation with automatic caching

  • hash_fn: Generate a hash for a function signature

Constants:

  • RECOMPILE_FORCE: Force recompilation flag

  • ECACHE_COMPILES: Enable compilation caching

  • CACHE_DIR: Cache directory path

  • COMPILE_FUNC_DIR: Compiled functions directory

  • COMPILED_CACHE: In-memory cache of compiled functions

Key Features:
  • Persistent disk caching of compiled functions

  • Automatic cache invalidation on changes

  • Hardware-specific signatures

  • Two-level caching (memory + disk)

  • Graceful fallback on errors
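The two-level caching and graceful fallback listed above can be illustrated with a small pure-Python sketch. It is not easydel's implementation: the class name `TwoLevelCache`, the `.pkl` file layout, and the `build` callback are assumptions for illustration, and the cached value simply stands in for a compiled executable.

```python
import os
import pickle
import tempfile
from typing import Any, Callable


class TwoLevelCache:
    """Hypothetical memory (L1) + disk (L2) cache; not easydel's implementation."""

    def __init__(self, cache_dir: str) -> None:
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self.memory: dict[str, Any] = {}

    def _path(self, key: str) -> str:
        return os.path.join(self.cache_dir, f"{key}.pkl")

    def get_or_build(self, key: str, build: Callable[[], Any]) -> Any:
        # L1: in-process dictionary lookup, no I/O.
        if key in self.memory:
            return self.memory[key]
        # L2: try the on-disk copy; a missing or unreadable file falls through.
        try:
            with open(self._path(key), "rb") as f:
                value = pickle.load(f)
        except Exception:
            value = build()  # graceful fallback: rebuild from scratch
            try:
                with open(self._path(key), "wb") as f:
                    pickle.dump(value, f)
            except Exception:
                pass  # a failed disk write must not break the caller
        self.memory[key] = value
        return value


calls = []


def build() -> dict:
    calls.append(1)  # count how often the "expensive" build runs
    return {"result": 42}


with tempfile.TemporaryDirectory() as d:
    cache = TwoLevelCache(d)
    first = cache.get_or_build("k", build)   # miss: builds and writes to disk
    second = cache.get_or_build("k", build)  # served from memory
    fresh = TwoLevelCache(d)                 # simulates a new script run
    third = fresh.get_or_build("k", build)   # served from disk
```

Checking memory first means repeated calls in one process never touch the disk; the disk copy only pays off in a new process, which is the "across script runs" benefit described at the top of this module.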

Example

>>> import jax.numpy as jnp
>>> from easydel.utils.compiling_utils import ejit
>>>
>>> @ejit
... def optimized_fn(x, y):
...     return x @ y + x.T @ y.T
>>>
>>> a = jnp.ones((128, 128))
>>> b = jnp.ones((128, 128))
>>>
>>> # First call compiles and caches the executable
>>> result = optimized_fn(a, b)
>>> # Later calls with the same signature, in this run or a later one,
>>> # reuse the cached executable instead of recompiling
>>> result = optimized_fn(a, b)

class easydel.utils.compiling_utils.NoCompileContext(message: str = 'JAX attempted to compile a new executable inside ForbidCompile.')

Bases: object

Context manager that fails if JAX triggers a new compilation.

Useful around hot paths that are expected to hit cached executables only.

easydel.utils.compiling_utils.get_hash_of_lowering(lowered_func: Lowered)
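This entry has no description. One plausible way to fingerprint a `Lowered` object, assumed here rather than taken from easydel, is to hash its lowered module text together with the device kind and JAX version, which also reflects the "hardware-specific signatures" feature. `lowering_signature` is a hypothetical name; `Lowered.as_text()`, `Device.device_kind`, and `jax.__version__` are standard JAX attributes.

```python
import hashlib

import jax
import jax.numpy as jnp


def lowering_signature(lowered: jax.stages.Lowered) -> str:
    """Hypothetical helper: fingerprint a lowering plus the hardware it targets."""
    h = hashlib.md5()
    # The module text changes when the function body, input shapes, or dtypes
    # change, so it also works as a cache-invalidation key.
    h.update(lowered.as_text().encode("utf-8"))
    # Compiled executables are hardware-specific: fold in the device kind and
    # the JAX version so a cache built elsewhere is not reused by mistake.
    h.update(jax.devices()[0].device_kind.encode("utf-8"))
    h.update(jax.__version__.encode("utf-8"))
    return h.hexdigest()


def f(x):
    return jnp.sin(x) @ x.T


sig_a = lowering_signature(jax.jit(f).lower(jnp.ones((4, 4))))
sig_b = lowering_signature(jax.jit(f).lower(jnp.ones((4, 4))))  # same signature
sig_c = lowering_signature(jax.jit(f).lower(jnp.ones((8, 8))))  # new shape
```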
easydel.utils.compiling_utils.get_safe_hash_int(text, algorithm='md5')

Generate an integer hash of text using the specified algorithm, with safety checks.
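A minimal sketch of what such a helper could look like, assuming the "safety checks" mean validating the algorithm name and coercing non-string input; `safe_hash_int` and the allow-list are hypothetical, while `hashlib.new` is the standard-library call. Unlike the built-in `hash()`, whose string hashing is salted per process (`PYTHONHASHSEED`), a digest-based integer stays the same across script runs, which is what a disk cache key needs.

```python
import hashlib

# Hypothetical allow-list of fixed-length digests (shake_* would need a length).
_ALLOWED = {"md5", "sha1", "sha224", "sha256", "sha384", "sha512", "blake2b", "blake2s"}


def safe_hash_int(text, algorithm: str = "md5") -> int:
    """Hypothetical sketch: a process-independent integer hash of ``text``."""
    if algorithm not in _ALLOWED:
        raise ValueError(f"Unsupported hash algorithm: {algorithm!r}")
    # Accept bytes as-is; coerce anything else to its string form first.
    data = text if isinstance(text, bytes) else str(text).encode("utf-8")
    # Interpret the hex digest as one (large) non-negative integer.
    return int(hashlib.new(algorithm, data).hexdigest(), 16)


key = safe_hash_int("abc")

try:
    safe_hash_int("abc", algorithm="crc32")  # not in the allow-list
    rejected = False
except ValueError:
    rejected = True
```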

easydel.utils.compiling_utils.hash_fn(self) -> int

Generate a hash for an object based on the values in its instance dictionary (__dict__).
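The `self` parameter suggests the function is meant to be attached to a class as its `__hash__`; the sketch below assumes that. `dict_values_hash` and `TrainConfig` are hypothetical names. Note that only attribute values are hashed, not attribute names, so two objects with the same values in the same order hash equally.

```python
import hashlib


def dict_values_hash(obj) -> int:
    """Hypothetical sketch: hash an object from the values in its __dict__."""
    # Only values are used (not names), joined by repr in insertion order.
    payload = "|".join(repr(value) for value in vars(obj).values())
    return int(hashlib.md5(payload.encode("utf-8")).hexdigest(), 16)


class TrainConfig:
    """Hypothetical config object used only to exercise the hash."""

    def __init__(self, lr: float, steps: int) -> None:
        self.lr = lr
        self.steps = steps

    # Use the sketch as the instance hash; hash() truncates it to Py_ssize_t.
    __hash__ = dict_values_hash
```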

easydel.utils.compiling_utils.load_cached_functions(verbose: bool = True) -> None

Pre-loads all valid cached functions from disk into the persistent L2 cache.

easydel.utils.compiling_utils.load_compiled_fn(path: str | os.PathLike, prefix: str | None = None)

Load a compiled function from disk.

easydel.utils.compiling_utils.save_compiled_fn(path: str | os.PathLike, fn: Compiled, prefix: str | None = None)

Save a compiled JAX function to disk for later reuse.

Serializes a compiled function along with its input/output tree structures, allowing it to be loaded and executed in future Python sessions.

Parameters
  • path – Directory path where the compiled function will be saved. Will be created if it doesn’t exist.

  • fn – Compiled JAX function (output of lowered.compile()).

  • prefix – Optional prefix for the filename. Useful for organizing multiple compiled functions in the same directory.

Files Created:
  • {prefix}-compiled.executable: Serialized function and metadata

Example

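>>> import jax
>>> import jax.numpy as jnp
>>> from pathlib import Path
>>> from easydel.utils.compiling_utils import save_compiled_fn
>>>
>>> def my_function(x):
...     return jnp.tanh(x) * 2.0
>>>
>>> # Compile a function
>>> sample_input = jnp.ones((8, 8))
>>> lowered = jax.jit(my_function).lower(sample_input)
>>> compiled = lowered.compile()
>>>
>>> # Save to disk
>>> cache_dir = Path("./my_cache")
>>> save_compiled_fn(cache_dir, compiled, prefix="model_v1")
>>>
>>> # File created: ./my_cache/model_v1-compiled.executable

Raises

Nothing. If serialization fails, a warning is logged instead of an exception being raised.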

Notes

  • Compiled functions are hardware-specific

  • Large models may produce large cache files

  • Uses pickle for serialization (standard security caveats apply)

easydel.utils.compiling_utils.smart_compile(lowered_func: Lowered, tag: str | None = None, verbose: bool = True, cache_key: tuple[str, tuple] | None = None) -> tuple[Compiled, tuple[str, tuple] | None]

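Compile a lowered JAX function, reusing a cached executable when one is available. Returns the compiled function together with the cache key, which may be None.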