easydel.utils.traversals

Utility functions for managing and manipulating nnx module states.

class easydel.utils.traversals.MetaValueRecreator(seed: int = 42)

Bases: object

Helper class for recreating meta values (such as RNG keys and counts) with state tracking.

get_count() → Array
get_rng() → PRNGKey
class easydel.utils.traversals.StateValidationResult(is_valid: bool, missing_keys: set, invalid_types: Dict[str, type])

Bases: Mapping

from_tuple()
invalid_types: Dict[str, type]
is_valid: bool
items() → a set-like object providing a view on D's items
keys() → a set-like object providing a view on D's keys
missing_keys: set
replace(**kwargs)
to_tuple()
values() → an object providing a view on D's values
easydel.utils.traversals.create_graphdef(module: Module, _add_rngs: bool = True, _rng_key: str = 'rngs', _seed: int = 0, **kwargs) → dict

Creates a graph definition from an nnx module.

This function initializes the module lazily and extracts the graph definition, which represents the structure of the module without any parameter values.

Parameters
  • module – The nnx module to create the graph definition from.

  • _add_rngs – Whether to add a rngs attribute to the module’s arguments for random number generation. Defaults to True.

  • _rng_key – The key to use for the rngs attribute. Defaults to “rngs”.

  • _seed – The seed value for random number generation. Defaults to 0.

  • **kwargs – Additional keyword arguments to pass to the module’s constructor.

Returns

The graph definition of the module.

Return type

dict

easydel.utils.traversals.deepcopy_model(model)

Creates a deep copy of a JAX model.

This function takes a JAX model, extracts its leaves (the individual components of the model), deep copies them, and then reconstructs the model with the copied leaves.

Parameters

model – A JAX model to be deep copied. This can be any nested structure of JAX arrays, lists, tuples, dicts, etc.

Returns

A deep copy of the input model with the same structure but with all leaves deep copied.

easydel.utils.traversals.diffrentiate_state(state: Dict[str, Any], init_state: Dict[str, Any], validate: bool = True) → Dict[str, VariableState]

Finds the attributes that are present in init_state but missing from state, with optional validation and error handling.

Parameters
  • state – Current state dictionary

  • init_state – Initial state dictionary

  • validate – Whether to perform validation

Returns

Dictionary of missing attributes

Raises

ValueError – If validation fails and validate=True
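
As a rough illustration of the behavior described above, the following pure-Python sketch collects the entries of an initial state whose keys are absent from the current state, with an optional validation step. The names `check_types` and `find_missing` are hypothetical and not part of easydel. Plain dictionaries stand in for nnx states, and the validation rule used here (shared keys must hold values of matching type) is an assumption rather than the library's documented rule.

```python
from typing import Any, Dict


def check_types(state: Dict[str, Any], init_state: Dict[str, Any]) -> Dict[str, type]:
    """Hypothetical check: report shared keys whose value types differ."""
    return {
        key: type(value)
        for key, value in state.items()
        if key in init_state and not isinstance(value, type(init_state[key]))
    }


def find_missing(
    state: Dict[str, Any], init_state: Dict[str, Any], validate: bool = True
) -> Dict[str, Any]:
    """Return the entries of init_state whose keys are absent from state."""
    if validate:
        invalid = check_types(state, init_state)
        if invalid:
            raise ValueError(f"State validation failed for keys: {sorted(invalid)}")
    return {key: value for key, value in init_state.items() if key not in state}


# Plain values stand in for nnx variable states.
init_state = {"kernel": 1.0, "bias": 0.0, "count": 0}
state = {"kernel": 2.5}
missing = find_missing(state, init_state)
```

Here `missing` holds the `bias` and `count` entries taken from `init_state`, which is the kind of dictionary a function such as redefine_state could then use to fill the gaps.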

easydel.utils.traversals.flatten_dict(xs: Union[dict, Mapping], keep_empty_nodes: bool = False, is_leaf: Optional[Callable[[tuple, Any], bool]] = None, sep: Optional[str] = None, fumap: bool = False) → Dict[Union[tuple, str], Any]

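Flattens a nested dictionary or mapping into a single-level dictionary, validating the input type. As the return annotation indicates, the keys of the result are tuples or strings.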

Parameters
  • xs – Dictionary or mapping to flatten

  • keep_empty_nodes – Whether to keep empty dictionary nodes

  • is_leaf – Optional function to determine leaf nodes

  • sep – Optional separator for string keys

Returns

Flattened dictionary

Raises

TypeError – If input is not a dictionary or mapping
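
The flattening described above can be sketched in plain Python. This is a minimal stand-in, not easydel's implementation: `flatten_nested` is a hypothetical name, the `is_leaf` and `fumap` parameters are left out, and the choice to produce tuple keys when `sep` is None and joined string keys otherwise is an assumption based on the signature.

```python
from collections.abc import Mapping
from typing import Any, Dict, Optional, Tuple, Union


def flatten_nested(
    xs: Mapping, sep: Optional[str] = None, keep_empty_nodes: bool = False
) -> Dict[Union[Tuple, str], Any]:
    """Flatten a nested mapping into a single-level dict keyed by path."""
    if not isinstance(xs, Mapping):
        raise TypeError(f"expected a dict or Mapping, got {type(xs).__name__}")

    flat: Dict[Tuple, Any] = {}

    def visit(node: Any, path: Tuple) -> None:
        if isinstance(node, Mapping):
            if not node and keep_empty_nodes and path:
                flat[path] = {}  # keep the empty node as a leaf
            for key, value in node.items():
                visit(value, path + (key,))
        else:
            flat[path] = node

    visit(xs, ())
    if sep is None:
        return flat
    # Join each path tuple into a single string key.
    return {sep.join(str(k) for k in path): value for path, value in flat.items()}


params = {"dense": {"kernel": 1, "bias": 2}, "norm": {"scale": 3}}
flat_tuple = flatten_nested(params)          # keys such as ("dense", "kernel")
flat_str = flatten_nested(params, sep=".")   # keys such as "dense.kernel"
```

String keys are convenient for logging and pattern matching, while tuple keys preserve the original key types and make it straightforward to rebuild the nested structure.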

easydel.utils.traversals.flatten_tree(xs: Dict, is_leaf: Optional[Callable[[Any], bool]] = None, sep: Optional[str] = None) → Dict[str, Any]

Flatten a JAX tree and convert paths to strings.

Parameters
  • xs – The JAX tree to flatten.

  • is_leaf – Optional function to determine leaf nodes.

  • sep – Separator to use when joining path elements.

Returns

A flattened dictionary with string keys representing the tree paths.

easydel.utils.traversals.init_garphstate(module: Module, _add_rngs: bool = True, _rng_key: str = 'rngs', _seed: int = 0, _lazy: bool = True, **kwargs) → dict

Initializes the graph state of an nnx module.

This function initializes the module and returns the graph state, which contains the initialized parameter values and other state information.

Parameters
  • module – The nnx module to initialize.

  • _add_rngs – Whether to add a rngs attribute to the module’s arguments for random number generation. Defaults to True.

  • _rng_key – The key to use for the rngs attribute. Defaults to “rngs”.

  • _seed – The seed value for random number generation. Defaults to 0.

  • _lazy – Whether to perform lazy initialization. If True, the module’s parameters will be initialized lazily when first used. Defaults to True.

  • **kwargs – Additional keyword arguments to pass to the module’s constructor.

Returns

The initialized graph state of the module.

Return type

dict

easydel.utils.traversals.int_key_to_string(xs)
easydel.utils.traversals.is_flatten(tree: dict) → bool

Checks if a dictionary represents a flattened tree.

A flattened tree is a dictionary where the keys are tuples representing the path to the leaf nodes. This function checks if any of the keys in the input dictionary is a tuple, indicating a flattened tree.

Parameters

tree – The dictionary to check.

Returns

True if the dictionary is a flattened tree, False otherwise.

Return type

bool
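
The check described above amounts to a one-line test over the dictionary's keys. The sketch below uses a hypothetical name, `looks_flattened`, and is an illustration of the rule as stated, not the library's source.

```python
def looks_flattened(tree: dict) -> bool:
    """Return True if any key is a tuple, i.e. the dict uses path-tuple keys."""
    return any(isinstance(key, tuple) for key in tree)


nested = {"dense": {"kernel": 1}}
flat = {("dense", "kernel"): 1}

nested_is_flat = looks_flattened(nested)
flat_is_flat = looks_flattened(flat)
```

Because only the key types are inspected, a dictionary flattened with a string separator would not be detected by this rule.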

easydel.utils.traversals.is_iterable(obj)
easydel.utils.traversals.merge_model_and_tree(model: M, tree: dict) → M

Attaches a parameter tree to an nnx model.

The parameter tree is a dictionary of parameter values. The function splits the nnx model into parameters and other model elements, then flattens both the parameter tree and the model’s parameters for traversal. For each key in the flattened model parameters whose value is not None (indicating an existing parameter), it substitutes the corresponding value from the input tree. Finally, it recreates the meta values in the “others” part of the model (such as RNG keys and counts) and merges the updated parameters and “others” back into a single nnx.Module object.

Parameters
  • tree – The parameter tree to attach.

  • model – The nnx model to attach the tree to.

Returns

The updated nnx model with the attached parameter tree.

Return type

nnx.Module
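
The replacement step at the core of this procedure can be sketched on already-flattened dictionaries. The sketch leaves out the nnx split and merge calls and the meta-value recreation, and `attach_tree` is a hypothetical helper name, so it illustrates only the per-key substitution rule described above.

```python
from typing import Any, Dict, Tuple

Path = Tuple[str, ...]


def attach_tree(flat_params: Dict[Path, Any], flat_tree: Dict[Path, Any]) -> Dict[Path, Any]:
    """Replace every non-None parameter with the matching value from flat_tree."""
    updated = {}
    for path, value in flat_params.items():
        if value is not None:
            # An existing parameter: take its new value from the tree.
            updated[path] = flat_tree[path]
        else:
            # No parameter at this path: keep the placeholder as it is.
            updated[path] = value
    return updated


# Flattened model parameters; None marks a slot that holds no parameter.
model_params = {("dense", "kernel"): 0.0, ("dense", "bias"): None}
new_values = {("dense", "kernel"): 1.5}
merged = attach_tree(model_params, new_values)
```

In this sketch a parameter that exists in the model but is absent from the tree raises a KeyError. Whether the library fails in the same way is not stated in this reference.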

easydel.utils.traversals.merge_state_and_tree(tree: dict, state: State) → State

Attaches a parameter tree to an nnx state.

The parameter tree is a dictionary of parameter values. The function splits the nnx state into parameters and other state elements, then flattens both the parameter tree and the state’s parameters for traversal. For each key in the flattened state parameters whose value is not None (indicating an existing parameter), it substitutes the corresponding value from the input tree. Finally, it recreates the meta values in the “others” part of the state (such as RNG keys and counts) and merges the updated parameters and “others” back into a single nnx.State object.

Parameters
  • tree – The parameter tree to attach.

  • state – The nnx state to attach the tree to.

Returns

The updated nnx state with the attached parameter tree.

Return type

nnx.State

easydel.utils.traversals.named_tree_map(f: Callable[[str, Any, Any], Any], tree: Dict, *rest: Any, is_leaf: Optional[Callable[[Any], bool]] = None, sep: Optional[str] = None) → Dict

An extended version of jax.tree_util.tree_map.

This function extends jax.tree_util.tree_map by providing the path (as a string) to the current leaf node as an argument to the mapped function f.

Parameters
  • f – The function to apply to each leaf node, taking the path and value as input.

  • tree – The JAX tree to map over.

  • *rest – Additional trees whose corresponding leaves are passed to f as extra arguments, as in jax.tree_util.tree_map.

  • is_leaf – Optional function to determine leaf nodes.

  • sep – Separator to use when joining path elements.

Returns

A new tree with the same structure as tree but with the values modified by f.
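
To make the idea of a path-aware map concrete, here is a pure-Python stand-in that works on nested dictionaries only. The name `map_with_names` is hypothetical, the default separator "/" is an assumption (the documented default for sep is None), and `*rest` is treated as additional trees of the same structure, following the convention of jax.tree_util.tree_map.

```python
from typing import Any, Callable, Dict


def map_with_names(
    f: Callable[..., Any], tree: Dict, *rest: Dict, sep: str = "/", _prefix: str = ""
) -> Dict:
    """Apply f(path, leaf, *other_leaves) to every leaf of a nested dict."""
    out = {}
    for key, value in tree.items():
        path = f"{_prefix}{sep}{key}" if _prefix else str(key)
        others = [other[key] for other in rest]  # matching subtrees or leaves
        if isinstance(value, dict):
            out[key] = map_with_names(f, value, *others, sep=sep, _prefix=path)
        else:
            out[key] = f(path, value, *others)
    return out


params = {"dense": {"kernel": 2.0, "bias": 1.0}}
grads = {"dense": {"kernel": 0.5, "bias": 0.25}}

# Subtract the gradient everywhere except at leaves whose path mentions "bias".
updated = map_with_names(
    lambda path, p, g: p if "bias" in path else p - g,
    params,
    grads,
)
```

Having the path available inside f is what makes name-based rules possible, such as freezing biases or applying weight decay only to kernels.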

easydel.utils.traversals.nnx_init(module: Type[M], _add_rngs: bool = True, _rng_key: str = 'rngs', _seed: int = 0, _lazy: bool = True, **kwargs) → M

Initializes an nnx module with lazy initialization support.

This function provides a convenient way to initialize nnx modules while handling random number generation and optional lazy initialization.

Parameters
  • module – The nnx module class to initialize.

  • _add_rngs – Whether to add a rngs attribute to the module’s arguments for random number generation. Defaults to True.

  • _rng_key – The key to use for the rngs attribute. Defaults to “rngs”.

  • _seed – The seed value for random number generation. Defaults to 0.

  • _lazy – Whether to perform lazy initialization. If True, the module’s parameters will be initialized lazily when first used. Defaults to True.

  • **kwargs – Additional keyword arguments to pass to the module’s constructor.

Returns

The initialized nnx module.

Return type

M

easydel.utils.traversals.recreate_meta_values(values: Dict[str, Any], seed: Optional[int] = None) → Dict[str, Any]

Recreates the meta values (such as RNG keys and counts) in a dictionary of values.

Parameters
  • values – Dictionary of values to recreate

  • seed – Optional seed for random number generation

Returns

Dictionary with recreated meta values

Raises

TypeError – For unexpected value types

easydel.utils.traversals.redefine_state(state: dict, missings: dict[str, flax.nnx.variablelib.VariableState]) → dict

Redefines missing attributes in a state dictionary.

This function takes a state dictionary state and a dictionary missings containing missing attributes. It iterates over the missings dictionary and redefines the missing attributes in the state dictionary based on their type.

Parameters
  • state – The state dictionary to redefine.

  • missings – A dictionary of missing attributes.

Returns

The redefined state dictionary.

Return type

dict

Raises

AttributeError – If an unexpected type is encountered in the missings dictionary.

easydel.utils.traversals.refine_graphs(*graphs: dict) → State

Refines and merges multiple graph representations into a single nnx.State.

This function takes multiple graph representations, which can be either dictionaries or nnx.State instances, and merges them into a single nnx.State object. It ensures that all inputs are converted to nnx.State instances before merging.

Parameters

*graphs – The graph representations to merge.

Returns

The merged nnx.State object.

Return type

nnx.State

easydel.utils.traversals.specs_to_name_sharding(tree: Dict, mesh: Optional[Mesh] = None) → Dict

Converts a dictionary of specifications to a dictionary of NamedSharding objects.

Parameters
  • tree (Dict) – A dictionary where the keys are names and the values are specifications.

  • mesh (Optional[Mesh]) – An optional Mesh object. If not provided, the default physical mesh from pxla.thread_resources.env.physical_mesh is used.

Returns

A dictionary where the keys are the same as the input dictionary, and the values are NamedSharding objects created from the specifications and the provided or default mesh.

Return type

Dict

easydel.utils.traversals.string_key_to_int(xs)
easydel.utils.traversals.tree_apply(fns: Dict[Any, Callable[[Any], Any]], tree: Dict[Any, Any]) → Dict[Any, Any]

Apply a dictionary of functions to a corresponding PyTree.

Parameters
  • fns – A dictionary where keys match the PyTree structure and values are functions.

  • tree – The PyTree to apply functions to.

Returns

A new PyTree with the same structure as tree, but with values modified by the functions in fns.
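
A small pure-Python sketch of this pairing of functions and values, restricted to nested dictionaries. `apply_fns` is a hypothetical name and the recursion is a stand-in for whatever tree utility the library uses internally.

```python
from typing import Any, Dict


def apply_fns(fns: Dict[Any, Any], tree: Dict[Any, Any]) -> Dict[Any, Any]:
    """Apply each function in fns to the value at the same position in tree."""
    out = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            out[key] = apply_fns(fns[key], value)  # descend into matching subtrees
        else:
            out[key] = fns[key](value)
    return out


tree = {"dense": {"kernel": 3, "bias": -1}, "step": 10}
fns = {
    "dense": {"kernel": lambda x: x * 2, "bias": abs},
    "step": lambda x: x + 1,
}
result = apply_fns(fns, tree)
```

A typical use of this pattern is applying a different transformation to each parameter, for example one placement or casting function per leaf.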

easydel.utils.traversals.tree_path_to_string(path: Tuple[Any, ...], sep: Optional[str] = None) → str

Convert a JAX tree path to a string representation.

Parameters
  • path – The JAX tree path tuple.

  • sep – Separator to use when joining path elements.

Returns

The string representation of the path.
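
The conversion can be pictured as joining path elements with a separator. In this sketch the path holds plain keys and indices, whereas real JAX paths hold key objects such as jax.tree_util.DictKey and jax.tree_util.SequenceKey. The name `path_to_string` and the default separator "/" are assumptions made for illustration.

```python
from typing import Any, Tuple


def path_to_string(path: Tuple[Any, ...], sep: str = "/") -> str:
    """Join the elements of a path tuple into one string."""
    return sep.join(str(element) for element in path)


path = ("layers", 0, "attention", "kernel")
name = path_to_string(path)
dotted = path_to_string(path, sep=".")
```

With these inputs, `name` is "layers/0/attention/kernel" and `dotted` is "layers.0.attention.kernel".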

easydel.utils.traversals.unflatten_dict(xs, sep=None)
easydel.utils.traversals.validate_state(state: Dict[str, Any], init_state: Dict[str, Any]) → StateValidationResult

Validates state against init_state before differentiation.