easydel.utils.compiling_utils
Compilation utilities for JAX function optimization.
Provides enhanced JIT compilation with persistent caching to disk, reducing compilation overhead across script runs.
- Functions:
ejit: Enhanced JIT with persistent caching
save_compiled_fn: Save compiled function to disk
load_compiled_fn: Load compiled function from disk
load_cached_functions: Load multiple cached functions
smart_compile: Smart compilation with auto-caching
hash_fn: Generate hash for function signature
- Constants:
RECOMPILE_FORCE: Force recompilation flag
ECACHE_COMPILES: Enable compilation caching
CACHE_DIR: Cache directory path
COMPILE_FUNC_DIR: Compiled functions directory
COMPILED_CACHE: In-memory cache of compiled functions
- Key Features:
Persistent disk caching of compiled functions
Automatic cache invalidation on changes
Hardware-specific signatures
Two-level caching (memory + disk)
Graceful fallback on errors
Example
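>>> import jax.numpy as jnp
>>> from easydel.utils.compiling_utils import ejit
>>>
>>> @ejit
... def optimized_fn(x, y):
...     return x @ y + x.T @ y.T
>>>
>>> a = jnp.ones((4, 4))
>>> b = jnp.ones((4, 4))
>>> # First call compiles and caches the executable
>>> result = optimized_fn(a, b)
>>> # Repeated calls, and later runs of the script, reuse the cached executable
>>> result = optimized_fn(a, b)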
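The two-level lookup listed under Key Features (memory first, then disk, then compile) can be sketched with the standard library alone. This is a toy illustration under stated assumptions, not EasyDEL's implementation: TwoLevelCache, signature and get_or_build are hypothetical names, a byte string stands in for a compiled executable, and one pickle file per key stands in for the real on-disk format. The key mixes the function name, input shapes and a device kind, so entries built for other hardware never match.

```python
from __future__ import annotations

import hashlib
import pickle
import tempfile
from pathlib import Path
from typing import Any, Callable


class TwoLevelCache:
    """Toy memory + disk cache keyed by a signature hash (hypothetical helper)."""

    def __init__(self, cache_dir: Path) -> None:
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.memory: dict[str, Any] = {}  # level 1: per-process dict
        self.builds = 0  # how many times the expensive build step ran

    @staticmethod
    def signature(fn_name: str, arg_shapes: tuple, device_kind: str) -> str:
        # Include the device kind so entries from other hardware never match.
        raw = repr((fn_name, arg_shapes, device_kind)).encode("utf-8")
        return hashlib.sha256(raw).hexdigest()

    def get_or_build(self, key: str, build: Callable[[], Any]) -> Any:
        if key in self.memory:  # level 1 hit
            return self.memory[key]
        path = self.cache_dir / f"{key}.pkl"
        if path.exists():  # level 2 hit, e.g. written by an earlier run
            try:
                value = pickle.loads(path.read_bytes())
                self.memory[key] = value
                return value
            except Exception:
                path.unlink(missing_ok=True)  # unreadable entry: rebuild below
        value = build()  # miss on both levels: do the expensive work
        self.builds += 1
        self.memory[key] = value
        try:
            path.write_bytes(pickle.dumps(value))
        except Exception:
            pass  # graceful fallback: keep the in-memory copy only
        return value


with tempfile.TemporaryDirectory() as tmp:
    key = TwoLevelCache.signature("optimized_fn", ((4, 4), (4, 4)), "cpu")
    first = TwoLevelCache(Path(tmp))
    first.get_or_build(key, lambda: b"fake-executable")  # builds
    first.get_or_build(key, lambda: b"fake-executable")  # memory hit
    second = TwoLevelCache(Path(tmp))  # simulates a new script run
    restored = second.get_or_build(key, lambda: b"fake-executable")  # disk hit
    builds_per_run = (first.builds, second.builds)
```

The second instance never calls build: it finds the pickle written by the first one, which is the behaviour the module describes for later script runs.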
- class easydel.utils.compiling_utils.NoCompileContext(message: str = 'JAX attempted to compile a new executable inside ForbidCompile.')
Bases: object
Context manager that fails if JAX triggers a new compilation.
Useful around hot paths that are expected to hit cached executables only.
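How NoCompileContext detects a JAX compilation is not documented here, so the sketch below only illustrates the pattern with a hypothetical stand-in: ToyJit counts cache misses as "compilations", and ForbidCompileSketch snapshots that counter on entry and raises if it grew. Unlike the real class, which may fail at the moment of compilation, this sketch checks on exit. With real JAX, compilation events can be observed by enabling the jax_log_compiles config option.

```python
from __future__ import annotations

import functools
from typing import Callable


class ToyJit:
    """Hypothetical stand-in for a JIT cache that counts compilations."""

    def __init__(self) -> None:
        self.compiles = 0

    def __call__(self, fn: Callable) -> Callable:
        cache: dict[tuple, Callable] = {}

        @functools.wraps(fn)
        def wrapper(*args):
            key = tuple(type(a) for a in args)  # toy signature: argument types
            if key not in cache:
                self.compiles += 1  # cache miss == "new compilation"
                cache[key] = fn
            return cache[key](*args)

        return wrapper


class ForbidCompileSketch:
    """Raise on exit if any compilation happened inside the block."""

    def __init__(self, jit: ToyJit, message: str = "new compilation inside ForbidCompileSketch") -> None:
        self.jit = jit
        self.message = message

    def __enter__(self) -> "ForbidCompileSketch":
        self._start = self.jit.compiles  # snapshot the counter
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Only complain if the block itself did not already fail.
        if exc_type is None and self.jit.compiles != self._start:
            raise RuntimeError(self.message)
        return False  # never swallow exceptions from the block


jit = ToyJit()
add = jit(lambda x, y: x + y)
add(1, 2)  # warm-up: compiles the (int, int) variant

with ForbidCompileSketch(jit):
    warm_result = add(3, 4)  # cache hit: no error

try:
    with ForbidCompileSketch(jit):
        add(1.0, 2.0)  # (float, float) is new: triggers a compile
    cold_path_raised = False
except RuntimeError:
    cold_path_raised = True
```

The warm-up call before the guarded block mirrors the intended use: run each input signature once, then wrap the hot path so that any unexpected new signature surfaces as an error.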
- easydel.utils.compiling_utils.get_safe_hash_int(text, algorithm='md5')
Generate an integer hash of text using the specified algorithm (default 'md5'), with safety checks.
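A digest-based integer is useful for cache keys because, unlike Python's built-in hash() on strings (which is salted per process), it gives the same value in every run. The sketch below assumes the algorithm name is a hashlib algorithm and that "safety checks" means validating the algorithm and coercing the input to bytes; safe_hash_int is a hypothetical name, not the library function.

```python
from __future__ import annotations

import hashlib


def safe_hash_int(text: str | bytes, algorithm: str = "md5") -> int:
    """Hypothetical sketch: stable integer hash of text via hashlib."""
    if algorithm not in hashlib.algorithms_available:
        raise ValueError(f"unsupported hash algorithm: {algorithm!r}")
    if isinstance(text, str):
        data = text.encode("utf-8")
    elif isinstance(text, bytes):
        data = text
    else:
        data = str(text).encode("utf-8")  # assumption: other objects hash via str()
    return int(hashlib.new(algorithm, data).hexdigest(), 16)


print(safe_hash_int("") == int("d41d8cd98f00b204e9800998ecf8427e", 16))  # True
```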
- easydel.utils.compiling_utils.hash_fn(self) -> int
Generate an integer hash for an object based on its dictionary values.
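Given its self parameter, hash_fn appears meant to be attached to a class as its __hash__. The sketch below is a hypothetical reimplementation under the assumption that "dictionary values" means the values of the instance's __dict__; TraceConfig is an invented example class. Because it relies on the built-in hash() of strings, the result is stable within one process but not across runs, so it suits in-memory lookups rather than disk file names.

```python
from __future__ import annotations

from dataclasses import dataclass


def hash_fn(self) -> int:
    """Hypothetical sketch: hash an object from the values in its __dict__."""
    # str() turns unhashable values (lists, dicts, arrays) into hashable keys.
    return hash(tuple(str(value) for value in self.__dict__.values()))


@dataclass
class TraceConfig:  # hypothetical config object, for illustration only
    name: str
    shapes: list

    __hash__ = hash_fn  # an explicit __hash__ is kept by @dataclass(eq=True)


a = TraceConfig("matmul", [(4, 4), (4, 4)])
b = TraceConfig("matmul", [(4, 4), (4, 4)])
print(hash(a) == hash(b), a == b)  # True True
```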
- easydel.utils.compiling_utils.load_cached_functions(verbose: bool = True) -> None
Pre-loads all valid cached functions from disk into the persistent L2 cache.
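A preload pass of this kind amounts to scanning the cache directory, deserializing every entry that still reads cleanly, and skipping the rest. The sketch below is a hypothetical stand-in: it reuses the "{prefix}-compiled.executable" file name documented for save_compiled_fn, uses pickle as the docs state, and returns a plain dict in place of the library's in-process cache. What the real function counts as "valid" is not documented here; this sketch only skips files that fail to unpickle.

```python
from __future__ import annotations

import pickle
import tempfile
from pathlib import Path
from typing import Any

SUFFIX = "-compiled.executable"  # file name pattern documented for save_compiled_fn


def preload_cache(cache_dir: Path, verbose: bool = True) -> dict[str, Any]:
    """Hypothetical sketch: load every readable cache file, keyed by its prefix."""
    loaded: dict[str, Any] = {}
    for path in sorted(Path(cache_dir).glob(f"*{SUFFIX}")):
        prefix = path.name[: -len(SUFFIX)]
        try:
            loaded[prefix] = pickle.loads(path.read_bytes())
        except Exception as err:  # corrupt or truncated entry: skip it
            if verbose:
                print(f"skipping {path.name}: {err!r}")
    if verbose:
        print(f"loaded {len(loaded)} cached function(s)")
    return loaded


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / f"model_v1{SUFFIX}").write_bytes(pickle.dumps({"executable": b"..."}))
    (root / f"broken{SUFFIX}").write_bytes(b"")  # simulates an interrupted write
    cache = preload_cache(root, verbose=False)
```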
- easydel.utils.compiling_utils.load_compiled_fn(path: str | os.PathLike, prefix: str | None = None)
Load a compiled function from disk.
- easydel.utils.compiling_utils.save_compiled_fn(path: str | os.PathLike, fn: Compiled, prefix: str | None = None)
Save a compiled JAX function to disk for later reuse.
Serializes a compiled function along with its input/output tree structures, allowing it to be loaded and executed in future Python sessions.
- Parameters
path – Directory path where the compiled function will be saved. Will be created if it doesn’t exist.
fn – Compiled JAX function (output of lowered.compile()).
prefix – Optional prefix for the filename. Useful for organizing multiple compiled functions in the same directory.
- Files Created:
{prefix}-compiled.executable: Serialized function and metadata
Example
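>>> import jax
>>> import jax.numpy as jnp
>>> from pathlib import Path
>>> from easydel.utils.compiling_utils import save_compiled_fn
>>>
>>> # Compile a function
>>> def my_function(x):
...     return jnp.sin(x) * 2
>>> sample_input = jnp.ones((8,))
>>> jitted = jax.jit(my_function)
>>> lowered = jitted.lower(sample_input)
>>> compiled = lowered.compile()
>>>
>>> # Save to disk
>>> cache_dir = Path("./my_cache")
>>> save_compiled_fn(cache_dir, compiled, prefix="model_v1")
>>>
>>> # File created: ./my_cache/model_v1-compiled.executable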
- Raises
Nothing on serialization failure: the error is logged as a warning instead of being raised.
Notes
Compiled functions are hardware-specific
Large models may produce large cache files
Uses pickle for serialization (standard security caveats apply: never load cache files from untrusted sources, because unpickling can execute arbitrary code)
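The documented file layout and error handling can be mirrored in a small save/load round trip. This is a hypothetical sketch, not save_compiled_fn or load_compiled_fn: a dict of placeholder bytes and strings stands in for the serialized executable and its input/output tree structures, and the file name used when prefix is None is an assumption. For real executables, JAX provides an experimental module, jax.experimental.serialize_executable; whether EasyDEL builds on it is not stated in these docs.

```python
from __future__ import annotations

import logging
import os
import pickle
import tempfile
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)


def cache_file(path: str | os.PathLike, prefix: str | None) -> Path:
    # "{prefix}-compiled.executable" as documented; the name used when
    # prefix is None is an assumption of this sketch.
    name = f"{prefix}-compiled.executable" if prefix else "compiled.executable"
    return Path(path) / name


def save_payload(path: str | os.PathLike, payload: Any, prefix: str | None = None) -> Path:
    target = cache_file(path, prefix)
    target.parent.mkdir(parents=True, exist_ok=True)  # create the directory if needed
    try:
        target.write_bytes(pickle.dumps(payload))
    except Exception as err:
        logger.warning("could not save %s: %s", target, err)  # log, don't raise
    return target


def load_payload(path: str | os.PathLike, prefix: str | None = None) -> Any:
    # Only unpickle files you created yourself: pickle can run arbitrary code.
    return pickle.loads(cache_file(path, prefix).read_bytes())


with tempfile.TemporaryDirectory() as tmp:
    payload = {"executable": b"\x00\x01", "in_tree": "(*, *)", "out_tree": "*"}
    saved = save_payload(Path(tmp) / "my_cache", payload, prefix="model_v1")
    restored = load_payload(Path(tmp) / "my_cache", prefix="model_v1")
    saved_name = saved.name
```

Keeping the tree structures next to the executable bytes is what lets a later session rebuild the call convention without re-tracing the original Python function.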