easydel.utils.compiling_utils
- easydel.utils.compiling_utils.cache_compiles(tag: Optional[str] = None, static_argnames: Optional[List[str]] = None)
- easydel.utils.compiling_utils.cjit(fn: Callable, static_argnums: Optional[Tuple[int]] = None, static_argnames: Optional[Tuple[str]] = None)
- easydel.utils.compiling_utils.compile_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)
Compiles a JAX function with optional sharding and mesh configuration.
- Parameters
func – The JAX function to compile.
func_input_args – Input arguments for the function.
func_input_kwargs – Input keyword arguments for the function.
mesh – Optional JAX mesh for distributed execution.
in_shardings – Optional input sharding specifications.
out_shardings – Optional output sharding specifications.
static_argnums – Indices of static arguments.
donate_argnums – Indices of arguments to donate.
- Returns
Compiled JAX function.
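For orientation, the sketch below shows the standard ahead-of-time path in plain JAX (`jax.jit(...).lower(...).compile()`) that a helper like this presumably builds on. It is an assumption, not the EasyDeL implementation: `compile_with_jit` is a hypothetical name, and mesh/sharding handling is omitted to keep the example minimal.

```python
import jax
import jax.numpy as jnp


def compile_with_jit(func, args, kwargs, static_argnums=None, donate_argnums=None):
    """Hypothetical stand-in for compile_function using only public JAX APIs.

    Mesh and sharding handling are deliberately left out of this sketch.
    """
    jitted = jax.jit(func, static_argnums=static_argnums, donate_argnums=donate_argnums)
    # lower() traces func for the shapes/dtypes of these concrete arguments;
    # compile() then runs XLA and returns a jax.stages.Compiled executable.
    return jitted.lower(*args, **kwargs).compile()


def affine(x, w, b):
    return x @ w + b


x = jnp.ones((2, 3))
w = jnp.ones((3, 4))
b = jnp.zeros((4,))

compiled = compile_with_jit(affine, (x, w, b), {})
# The executable is specialized to these shapes; calling it with
# differently shaped inputs raises an error instead of recompiling.
out = compiled(x, w, b)
print(out.shape)  # (2, 4)
```

Compiling ahead of time like this moves the XLA compilation cost out of the first training or inference step, at the price of fixing input shapes up front.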
- easydel.utils.compiling_utils.get_signature_tree_util(args: Tuple[Any, ...], kwargs: Dict[str, Any]) → Tuple
- easydel.utils.compiling_utils.load_compiled_fn(path: Union[str, PathLike], prefix: Optional[str] = None)
- easydel.utils.compiling_utils.lower_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)
Lowers a JAX function with optional sharding and mesh configuration.
- Parameters
func – The JAX function to lower.
func_input_args – Input arguments for the function.
func_input_kwargs – Input keyword arguments for the function.
mesh – Optional JAX mesh for distributed execution.
in_shardings – Optional input sharding specifications.
out_shardings – Optional output sharding specifications.
static_argnums – Indices of static arguments.
donate_argnums – Indices of arguments to donate.
- Returns
Lowered JAX function.
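As an illustration of lowering with shardings, here is a minimal plain-JAX sketch. It assumes lowering is done through `jax.jit(...).lower(...)`; `lower_with_shardings` is a hypothetical helper, not part of EasyDeL, and how `lower_function` actually consumes its `mesh` argument is not documented here. The mesh is carried by `NamedSharding` objects instead, using a one-device mesh so it runs on a plain CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def lower_with_shardings(func, args, kwargs, in_shardings=None, out_shardings=None):
    """Hypothetical stand-in for lower_function built on public JAX APIs."""
    jit_kwargs = {}
    # Only forward shardings that were actually given, leaving jax.jit defaults intact.
    if in_shardings is not None:
        jit_kwargs["in_shardings"] = in_shardings
    if out_shardings is not None:
        jit_kwargs["out_shardings"] = out_shardings
    return jax.jit(func, **jit_kwargs).lower(*args, **kwargs)


# A 1-D mesh over a single device; on real hardware this would span many devices.
mesh = Mesh(np.array(jax.devices()[:1]), ("dp",))
replicated = NamedSharding(mesh, P())  # P() = fully replicated


def scale(x):
    return 2.0 * x


x = jax.device_put(jnp.arange(4.0), replicated)
lowered = lower_with_shardings(scale, (x,), {}, in_shardings=(replicated,), out_shardings=replicated)

text = lowered.as_text()  # StableHLO module text, useful for inspecting the program
compiled = lowered.compile()  # a Lowered object is turned into an executable by compile()
out = compiled(x)
print(out)  # [0. 2. 4. 6.]
```

Keeping lowering separate from compilation lets you inspect or serialize the intermediate representation before paying for XLA compilation.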