easydel.utils.compiling_utils

easydel.utils.compiling_utils.cache_compiles(tag: Optional[str] = None, static_argnames: Optional[List[str]] = None)
easydel.utils.compiling_utils.cjit(fn: Callable, static_argnums: Optional[Tuple[int]] = None, static_argnames: Optional[Tuple[str]] = None)
easydel.utils.compiling_utils.compile_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)

Compiles a JAX function with optional sharding and mesh configuration.

Parameters
  • func – The JAX function to compile.

  • func_input_args – Input arguments for the function.

  • func_input_kwargs – Input keyword arguments for the function.

  • mesh – Optional JAX mesh for distributed execution.

  • in_shardings – Optional input sharding specifications.

  • out_shardings – Optional output sharding specifications.

  • static_argnums – Indices of static arguments.

  • donate_argnums – Indices of arguments to donate.

Returns

Compiled JAX function.
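A minimal sketch of the same jit, lower, compile flow using only public JAX APIs (`jax.jit`, `.lower()`, `.compile()`). The helper name `compile_sketch` is hypothetical and this is not EasyDeL's implementation; binding bare `PartitionSpec`s to `mesh` through `NamedSharding` is an assumption about how the `mesh` argument is used.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def compile_sketch(func, args, kwargs, mesh=None, in_shardings=None,
                   out_shardings=None, static_argnums=None, donate_argnums=None):
    """Hypothetical helper: jit -> lower -> compile with optional shardings."""

    def to_sharding(tree):
        # Assumption: bare PartitionSpecs are bound to `mesh` via NamedSharding.
        if mesh is None:
            return tree
        return jax.tree_util.tree_map(
            lambda s: NamedSharding(mesh, s) if isinstance(s, P) else s,
            tree,
            is_leaf=lambda s: isinstance(s, P),
        )

    jit_kwargs = {}
    # Only forward shardings that were actually given, so JAX keeps its defaults.
    if in_shardings is not None:
        jit_kwargs["in_shardings"] = to_sharding(in_shardings)
    if out_shardings is not None:
        jit_kwargs["out_shardings"] = to_sharding(out_shardings)
    jitted = jax.jit(func, static_argnums=static_argnums,
                     donate_argnums=donate_argnums, **jit_kwargs)
    # Ahead-of-time: trace and lower for these concrete inputs, then compile.
    return jitted.lower(*args, **kwargs).compile()


# Usage on a one-device mesh with a single "data" axis.
mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
x = jnp.arange(8.0).reshape(4, 2)
compiled = compile_sketch(lambda a: a * 2, (x,), {}, mesh=mesh,
                          in_shardings=(P("data"),), out_shardings=P("data"))
y = compiled(x)
```

Forwarding only the shardings that were supplied leaves JAX free to choose placement for everything else.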

easydel.utils.compiling_utils.get_hash_of_lowering(lowered_func: Any)
easydel.utils.compiling_utils.get_safe_hash_int(text, algorithm='md5')
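`get_safe_hash_int(text, algorithm='md5')` is undocumented. Assuming it maps a string to an integer through a named `hashlib` algorithm, a sketch could look like this; `safe_hash_int_sketch` is a hypothetical name. A digest-based hash stays stable across processes, whereas Python's built-in `hash()` on `str` is randomized per process unless `PYTHONHASHSEED` is fixed.

```python
import hashlib


def safe_hash_int_sketch(text: str, algorithm: str = "md5") -> int:
    """Hypothetical sketch: process-independent integer hash of a string."""
    # hashlib.new accepts an algorithm name such as "md5" or "sha256".
    digest = hashlib.new(algorithm, text.encode("utf-8")).hexdigest()
    # Interpret the hex digest as one (large) base-16 integer.
    return int(digest, 16)


h_md5 = safe_hash_int_sketch("abc")
h_sha = safe_hash_int_sketch("abc", algorithm="sha256")
```

Such a function could, for example, be applied to the text of a lowered program (`jax.jit(f).lower(x).as_text()`); whether `get_hash_of_lowering` does exactly this is not stated here.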
easydel.utils.compiling_utils.get_signature(args, kwargs) → Tuple

Get a hashable signature of args/kwargs shapes and dtypes.
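A sketch of such a signature, assuming it is built by flattening the arguments with `jax.tree_util` and recording each array leaf's shape and dtype; `signature_sketch` is a hypothetical name. Non-array leaves are kept by value, which assumes they are hashable.

```python
import jax
import jax.numpy as jnp


def signature_sketch(args, kwargs):
    """Hypothetical sketch: hashable (structure, shapes/dtypes) signature."""
    # Flatten args and kwargs together; dict keys are sorted by tree_flatten.
    leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))
    leaf_sig = tuple(
        (tuple(leaf.shape), str(leaf.dtype))
        if hasattr(leaf, "shape") and hasattr(leaf, "dtype")
        else ("static", leaf)  # plain Python values are kept by value
        for leaf in leaves
    )
    # PyTreeDef is hashable, so the whole tuple can serve as a dict key.
    return treedef, leaf_sig


a = signature_sketch((jnp.ones((2, 3)),), {"y": jnp.zeros(4, jnp.int32)})
b = signature_sketch((jnp.full((2, 3), 7.0),), {"y": jnp.ones(4, jnp.int32)})
c = signature_sketch((jnp.ones((3, 3)),), {"y": jnp.zeros(4, jnp.int32)})
d = signature_sketch((jnp.ones((2, 3)), 5), {})
```

Two calls with the same argument structure, shapes, and dtypes produce equal signatures regardless of array values, which is what makes the signature usable as a cache key.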

easydel.utils.compiling_utils.hash_fn(self) → int
easydel.utils.compiling_utils.is_jit_wrapped(fn)
easydel.utils.compiling_utils.load_compiled_fn(path: Union[str, PathLike], prefix: Optional[str] = None)
easydel.utils.compiling_utils.lower_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)

Lowers a JAX function with optional sharding and mesh configuration.

Parameters
  • func – The JAX function to lower.

  • func_input_args – Input arguments for the function.

  • func_input_kwargs – Input keyword arguments for the function.

  • mesh – Optional JAX mesh for distributed execution.

  • in_shardings – Optional input sharding specifications.

  • out_shardings – Optional output sharding specifications.

  • static_argnums – Indices of static arguments.

  • donate_argnums – Indices of arguments to donate.

Returns

Lowered JAX function.
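Lowering traces the function and emits its StableHLO program without compiling it, so the IR can be inspected or hashed before paying the compile cost. A sketch with public JAX APIs; `lower_sketch` is a hypothetical helper, not the library function.

```python
import jax
import jax.numpy as jnp


def lower_sketch(func, args, kwargs, static_argnums=None):
    """Hypothetical helper: trace and lower `func` without compiling it."""
    return jax.jit(func, static_argnums=static_argnums).lower(*args, **kwargs)


def scale(x, factor):
    return x * factor


# `factor` (position 1) is static, so its value 3 is baked into the program.
lowered = lower_sketch(scale, (jnp.ones((2, 2)), 3), {}, static_argnums=(1,))
ir_text = lowered.as_text()       # StableHLO module as a string; nothing compiled yet
compiled = lowered.compile()      # compilation is a separate, explicit step
out = compiled(jnp.ones((2, 2)))  # only the non-static argument is passed
```

Because `factor` is static, the lowered program is specialized to its value, and the compiled executable is called with the dynamic argument only.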

easydel.utils.compiling_utils.save_compiled_fn(path: Union[str, PathLike], fn: Any, prefix: Optional[str] = None)
easydel.utils.compiling_utils.smart_compile(lowered_func: Any, tag: Optional[str] = None) → Any