easydel.utils.traversals
Utility functions for managing and manipulating nnx module states.
- class easydel.utils.traversals.MetaValueRecreator(seed: int = 42)
Bases: object
Helper class for recreating meta values with state tracking.
- class easydel.utils.traversals.StateValidationResult(is_valid: bool, missing_keys: set, invalid_types: Dict[str, type])
Bases: Mapping
- from_tuple()
- invalid_types: Dict[str, type]
- is_valid: bool
- items() – a set-like object providing a view on the mapping’s items
- keys() – a set-like object providing a view on the mapping’s keys
- missing_keys: set
- replace(**kwargs)
- to_tuple()
- values() – an object providing a view on the mapping’s values
- easydel.utils.traversals.create_graphdef(module: Module, _add_rngs: bool = True, _rng_key: str = 'rngs', _seed: int = 0, **kwargs) → dict
Creates a graph definition from an nnx module.
This function initializes the module lazily and extracts the graph definition, which represents the structure of the module without any parameter values.
- Parameters
module – The nnx module to create the graph definition from.
_add_rngs – Whether to add an rngs entry, used for random number generation, to the keyword arguments passed to the module’s constructor. Defaults to True.
_rng_key – The keyword under which the rngs entry is passed. Defaults to “rngs”.
_seed – The seed value for random number generation. Defaults to 0.
**kwargs – Additional keyword arguments to pass to the module’s constructor.
- Returns
The graph definition of the module.
- Return type
dict
- easydel.utils.traversals.deepcopy_model(model)
Creates a deep copy of a JAX model.
This function takes a JAX model, extracts its leaves (the individual components of the model), deep copies them, and then reconstructs the model with the copied leaves.
- Parameters
model – A JAX model to be deep copied. This can be any nested structure of JAX arrays, lists, tuples, dicts, etc.
- Returns
A deep copy of the input model with the same structure but with all leaves deep copied.
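The leaf-wise copy described above can be illustrated with a plain-Python sketch. It is not EasyDeL’s implementation: the helper name deepcopy_leaves is hypothetical, and it only recognises dict, list and tuple as containers, whereas the documented function accepts arbitrary JAX model structures.

```python
import copy
from typing import Any


def deepcopy_leaves(tree: Any) -> Any:
    """Hypothetical stand-in for deepcopy_model.

    Rebuilds `tree` with the same container structure and deep-copies every
    leaf. Only dict, list and tuple are treated as containers.
    """
    if isinstance(tree, dict):
        return {key: deepcopy_leaves(value) for key, value in tree.items()}
    if type(tree) in (list, tuple):
        # Rebuild the same container type from the copied children.
        return type(tree)(deepcopy_leaves(value) for value in tree)
    # Leaf: copy it so the result shares no mutable state with the input.
    return copy.deepcopy(tree)


model = {"dense": {"kernel": [[1.0, 2.0]], "bias": [0.0]}, "steps": (1, 2)}
copied = deepcopy_leaves(model)
copied["dense"]["kernel"][0][0] = 99.0
print(model["dense"]["kernel"][0][0])  # 1.0: the original is untouched
```

Because containers are rebuilt and leaves are copied, mutating the copy never affects the original, and the structure (including tuple versus list) is preserved.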
- easydel.utils.traversals.diffrentiate_state(state: Dict[str, Any], init_state: Dict[str, Any], validate: bool = True) → Dict[str, VariableState]
Computes the attributes missing from state relative to init_state, with optional validation and error handling.
- Parameters
state – The current state dictionary.
init_state – The initial state dictionary.
validate – Whether to perform validation. Defaults to True.
- Returns
A dictionary of the missing attributes.
- Raises
ValueError – If validation fails and validate=True
- easydel.utils.traversals.flatten_dict(xs: Union[dict, Mapping], keep_empty_nodes: bool = False, is_leaf: Optional[Callable[[tuple, Any], bool]] = None, sep: Optional[str] = None, fumap: bool = False) → Dict[Union[tuple, str], Any]
Flattens a nested dictionary or mapping into a single-level dictionary, checking the input type first.
- Parameters
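xs – The dictionary or mapping to flatten.
keep_empty_nodes – Whether to keep empty dictionary nodes in the output. Defaults to False.
is_leaf – Optional function that receives a node’s path and value and returns True if the node should be treated as a leaf.
sep – Optional separator. If given, key paths are joined into string keys with it; otherwise keys are tuples.
- Returns
The flattened dictionary.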
- Raises
TypeError – If input is not a dictionary or mapping
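The behaviour described above can be sketched for plain nested mappings. This is a hypothetical illustration (the name flatten_nested is not part of EasyDeL): it assumes tuple keys by default, string keys joined with sep when a separator is given, and that kept empty nodes are stored as {}. The fumap flag is not modelled.

```python
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union

FlatKey = Union[Tuple[Any, ...], str]


def flatten_nested(
    xs: Mapping,
    keep_empty_nodes: bool = False,
    is_leaf: Optional[Callable[[Tuple[Any, ...], Any], bool]] = None,
    sep: Optional[str] = None,
) -> Dict[FlatKey, Any]:
    """Hypothetical sketch of flatten_dict for plain nested mappings."""
    if not isinstance(xs, Mapping):
        raise TypeError(f"expected a dict or mapping, got {type(xs).__name__}")

    flat: Dict[FlatKey, Any] = {}

    def make_key(path: Tuple[Any, ...]) -> FlatKey:
        # Tuple keys by default; string keys joined with `sep` when it is given.
        return path if sep is None else sep.join(str(part) for part in path)

    def visit(node: Any, path: Tuple[Any, ...]) -> None:
        # Non-mappings, and anything `is_leaf` claims, are stored as leaves.
        if not isinstance(node, Mapping) or (is_leaf is not None and is_leaf(path, node)):
            flat[make_key(path)] = node
        elif not node:
            # Assumption: a kept empty node is stored as {} (the root is never kept).
            if keep_empty_nodes and path:
                flat[make_key(path)] = {}
        else:
            for key, value in node.items():
                visit(value, path + (key,))

    visit(xs, ())
    return flat


params = {"dense": {"kernel": 1, "bias": 2}, "norm": {}}
print(flatten_nested(params))
# {('dense', 'kernel'): 1, ('dense', 'bias'): 2}
print(flatten_nested(params, keep_empty_nodes=True, sep="/"))
# {'dense/kernel': 1, 'dense/bias': 2, 'norm': {}}
```

Note that the empty "norm" node disappears entirely unless keep_empty_nodes is set, which matters if the flattened dictionary is later unflattened and compared with the original.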
- easydel.utils.traversals.flatten_tree(xs: Dict, is_leaf: Optional[Callable[[Any], bool]] = None, sep: Optional[str] = None) → Dict[str, Any]
Flatten a JAX tree and convert paths to strings.
- Parameters
xs – The JAX tree to flatten.
is_leaf – Optional function to determine leaf nodes.
sep – Separator to use when joining path elements.
- Returns
A flattened dictionary with string keys representing the tree paths.
- easydel.utils.traversals.init_garphstate(module: Module, _add_rngs: bool = True, _rng_key: str = 'rngs', _seed: int = 0, _lazy: bool = True, **kwargs) → dict
Initializes the graph state of an nnx module.
This function initializes the module and returns the graph state, which contains the initialized parameter values and other state information.
- Parameters
module – The nnx module to initialize.
_add_rngs – Whether to add an rngs entry, used for random number generation, to the keyword arguments passed to the module’s constructor. Defaults to True.
_rng_key – The keyword under which the rngs entry is passed. Defaults to “rngs”.
_seed – The seed value for random number generation. Defaults to 0.
_lazy – Whether to perform lazy initialization. If True, the module’s parameters will be initialized lazily when first used. Defaults to True.
**kwargs – Additional keyword arguments to pass to the module’s constructor.
- Returns
The initialized graph state of the module.
- Return type
dict
- easydel.utils.traversals.is_flatten(tree: dict) → bool
Checks if a dictionary represents a flattened tree.
A flattened tree is a dictionary where the keys are tuples representing the path to the leaf nodes. This function checks if any of the keys in the input dictionary is a tuple, indicating a flattened tree.
- Parameters
tree – The dictionary to check.
- Returns
True if the dictionary is a flattened tree, False otherwise.
- Return type
bool
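If the check works as described, it reduces to a one-line test. The sketch below uses the hypothetical name looks_flattened and highlights a consequence of that rule: a dictionary flattened with a string separator (for example via flatten_dict with sep="/") has string keys and is therefore not reported as flattened, and neither is an empty dictionary.

```python
def looks_flattened(tree: dict) -> bool:
    """Hypothetical sketch of is_flatten: True if any key is a path tuple."""
    return any(isinstance(key, tuple) for key in tree)


print(looks_flattened({("dense", "kernel"): 1}))  # True
print(looks_flattened({"dense": {"kernel": 1}}))  # False
print(looks_flattened({"dense/kernel": 1}))       # False: string-keyed flat dict
```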
- easydel.utils.traversals.merge_model_and_tree(model: M, tree: dict) → M
Attaches a parameter tree to an nnx model.
This function takes a parameter tree (a dictionary of parameter values) and attaches it to an existing nnx model. It first splits the model into parameters and other model elements, then flattens both the parameter tree and the model’s parameters for easy traversal. For each parameter key in the flattened model, if the corresponding value is not None (indicating an existing parameter), it is replaced with the matching value from the parameter tree. Finally, it recreates the meta values in the “others” part of the model (such as RNG keys and counts) and merges the updated parameters and “others” back into a single nnx.Module object.
- Parameters
model – The nnx model to attach the tree to.
tree – The parameter tree to attach.
- Returns
The updated nnx model with the attached parameter tree.
- Return type
nnx.Module
- easydel.utils.traversals.merge_state_and_tree(tree: dict, state: State) → State
Attaches a parameter tree to an nnx state.
This function takes a parameter tree (a dictionary of parameter values) and attaches it to an existing nnx state. It first splits the state into parameters and other state elements, then flattens both the parameter tree and the state’s parameters for easy traversal. For each parameter key in the flattened state, if the corresponding value is not None (indicating an existing parameter), it is replaced with the matching value from the parameter tree. Finally, it recreates the meta values in the “others” part of the state (such as RNG keys and counts) and merges the updated parameters and “others” back into a single nnx.State object.
- Parameters
tree – The parameter tree to attach.
state – The nnx state to attach the tree to.
- Returns
The updated nnx state with the attached parameter tree.
- Return type
nnx.State
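The parameter-replacement step shared by merge_model_and_tree and merge_state_and_tree can be sketched on plain nested dicts. This sketch omits the nnx split and merge and the meta-value recreation, uses hypothetical helper names, and adopts one reading of the description above: a None or absent entry in tree keeps the existing parameter value, and keys that exist only in tree are ignored.

```python
from typing import Any, Dict, Tuple


def _flatten(d: Dict, prefix: Tuple = ()) -> Dict[Tuple, Any]:
    # Map every leaf to its key path; empty dicts are kept as leaves.
    out: Dict[Tuple, Any] = {}
    for key, value in d.items():
        if isinstance(value, dict) and value:
            out.update(_flatten(value, prefix + (key,)))
        else:
            out[prefix + (key,)] = value
    return out


def _unflatten(flat: Dict[Tuple, Any]) -> Dict:
    # Rebuild nested dicts from key paths.
    root: Dict = {}
    for path, value in flat.items():
        node = root
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return root


def merge_params(params: Dict, tree: Dict) -> Dict:
    """Hypothetical sketch: overwrite entries of `params` with values from `tree`."""
    flat_params = _flatten(params)
    flat_tree = _flatten(tree)
    for key in flat_params:
        new_value = flat_tree.get(key)
        # Assumption: a None (or absent) entry in `tree` keeps the existing value.
        if new_value is not None:
            flat_params[key] = new_value
    return _unflatten(flat_params)


params = {"dense": {"kernel": 0.0, "bias": 0.0}}
loaded = {"dense": {"kernel": 1.5, "bias": None}}
print(merge_params(params, loaded))  # {'dense': {'kernel': 1.5, 'bias': 0.0}}
```

Iterating over the keys of the flattened parameters, rather than over the tree, means the structure of the result always matches the existing parameters.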
- easydel.utils.traversals.named_tree_map(f: Callable[[str, Any, Any], Any], tree: Dict, *rest: Any, is_leaf: Optional[Callable[[Any], bool]] = None, sep: Optional[str] = None) → Dict
An extended version of jax.tree_util.tree_map.
This function extends jax.tree_util.tree_map by providing the path (as a string) to the current leaf node as an argument to the mapped function f.
- Parameters
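f – The function to apply to each leaf node. It receives the path to the leaf (as a string), the leaf value, and the corresponding leaves from any trees in rest.
tree – The JAX tree to map over.
*rest – Additional trees with the same structure as tree, whose corresponding leaves are passed to f, as in jax.tree_util.tree_map.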
is_leaf – Optional function to determine leaf nodes.
sep – Separator to use when joining path elements.
- Returns
A new tree with the same structure as tree but with the values modified by f.
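The path-aware mapping can be sketched in plain Python for nested dicts. This is an illustration under assumptions, not EasyDeL’s code: named_map is a hypothetical name, the sketch handles nested dicts only (not general JAX pytrees or is_leaf), it joins path parts with "/" by default, and it treats rest as extra trees whose matching leaves are passed to f after the value.

```python
from typing import Any, Callable, Dict


def named_map(f: Callable[..., Any], tree: Dict, *rest: Dict, sep: str = "/") -> Dict:
    """Hypothetical sketch of named_tree_map for nested dicts only."""

    def visit(node: Any, others: tuple, path: tuple) -> Any:
        if isinstance(node, dict):
            # Recurse into each child, taking the matching child from every extra tree.
            return {
                key: visit(value, tuple(other[key] for other in others), path + (str(key),))
                for key, value in node.items()
            }
        # Leaf: call f with the joined path, the leaf, and the matching leaves of `rest`.
        return f(sep.join(path), node, *others)

    return visit(tree, rest, ())


params = {"dense": {"kernel": 2.0, "bias": 1.0}}
grads = {"dense": {"kernel": 0.5, "bias": 0.25}}
# Update every parameter except biases, using the path to decide.
updated = named_map(
    lambda name, p, g: p if name.endswith("bias") else p - 0.5 * g,
    params,
    grads,
)
print(updated)  # {'dense': {'kernel': 1.75, 'bias': 1.0}}
```

Selecting leaves by their path string, as in the bias exclusion above, is the main reason to use a path-aware map instead of a plain tree map.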
- easydel.utils.traversals.nnx_init(module: Type[M], _add_rngs: bool = True, _rng_key: str = 'rngs', _seed: int = 0, _lazy: bool = True, **kwargs) → M
Initializes an nnx module with lazy initialization support.
This function provides a convenient way to initialize nnx modules while handling random number generation and optional lazy initialization.
- Parameters
module – The nnx module class to instantiate.
_add_rngs – Whether to add an rngs entry, used for random number generation, to the keyword arguments passed to the module’s constructor. Defaults to True.
_rng_key – The keyword under which the rngs entry is passed. Defaults to “rngs”.
_seed – The seed value for random number generation. Defaults to 0.
_lazy – Whether to perform lazy initialization. If True, the module’s parameters will be initialized lazily when first used. Defaults to True.
**kwargs – Additional keyword arguments to pass to the module’s constructor.
- Returns
The initialized module instance.
- Return type
M (an instance of the module class)
- easydel.utils.traversals.recreate_meta_values(values: Dict[str, Any], seed: Optional[int] = None) → Dict[str, Any]
Recreates meta values (such as RNG keys and counts) in a dictionary of state values.
- Parameters
values – The dictionary of values whose meta values should be recreated.
seed – Optional seed for random number generation.
- Returns
A dictionary with the recreated meta values.
- Raises
TypeError – For unexpected value types
- easydel.utils.traversals.redefine_state(state: dict, missings: dict[str, flax.nnx.variablelib.VariableState]) → dict
Redefines missing attributes in a state dictionary.
This function takes a state dictionary state and a dictionary missings containing missing attributes. It iterates over the missings dictionary and redefines the missing attributes in the state dictionary based on their type.
- Parameters
state – The state dictionary to redefine.
missings – A dictionary of missing attributes.
- Returns
The redefined state dictionary.
- Return type
dict
- Raises
AttributeError – If an unexpected type is encountered in the missings dictionary.
- easydel.utils.traversals.refine_graphs(*graphs: dict) → State
Refines and merges multiple graph representations into a single nnx.State.
This function takes multiple graph representations, which can be either dictionaries or nnx.State instances, and merges them into a single nnx.State object. It ensures that all inputs are converted to nnx.State instances before merging.
- Parameters
*graphs – The graph representations to merge.
- Returns
The merged nnx.State object.
- Return type
nnx.State
- easydel.utils.traversals.specs_to_name_sharding(tree: Dict, mesh: Optional[Mesh] = None) → Dict
Converts a dictionary of specifications to a dictionary of NamedSharding objects.
- Parameters
tree (Dict) – A dictionary where the keys are names and the values are specifications.
mesh (Optional[Mesh]) – An optional Mesh object. If not provided, the default physical mesh from pxla.thread_resources.env.physical_mesh is used.
- Returns
A dictionary with the same keys as the input dictionary, whose values are NamedSharding objects created from the specifications and the provided or default mesh.
- Return type
Dict
- easydel.utils.traversals.tree_apply(fns: Dict[Any, Callable[[Any], Any]], tree: Dict[Any, Any]) → Dict[Any, Any]
Apply a dictionary of functions to a corresponding PyTree.
- Parameters
fns – A dictionary where keys match the PyTree structure and values are functions.
tree – The PyTree to apply functions to.
- Returns
A new PyTree with the same structure as tree, but with values modified by the functions in fns.
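A minimal sketch of applying a dictionary of functions to a matching tree, assuming fns mirrors tree exactly with one callable per leaf (as mapping over both trees together would require). The name apply_tree and the ValueError on a structure mismatch are assumptions for illustration, not EasyDeL behaviour.

```python
from typing import Any


def apply_tree(fns: Any, tree: Any) -> Any:
    """Hypothetical sketch of tree_apply for nested dicts with callables at the leaves of `fns`."""
    if isinstance(fns, dict):
        # Both trees must have exactly the same keys at every level.
        if not isinstance(tree, dict) or fns.keys() != tree.keys():
            raise ValueError("fns and tree must have the same structure")
        return {key: apply_tree(fns[key], tree[key]) for key in fns}
    # Leaf of `fns`: apply the function to the matching value in `tree`.
    return fns(tree)


fns = {"kernel": lambda x: x * 2, "meta": {"step": lambda s: s + 1}}
tree = {"kernel": 3, "meta": {"step": 10}}
print(apply_tree(fns, tree))  # {'kernel': 6, 'meta': {'step': 11}}
```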
- easydel.utils.traversals.tree_path_to_string(path: Tuple[Any, ...], sep: Optional[str] = None) → str
Convert a JAX tree path to a string representation.
- Parameters
path – The JAX tree path tuple.
sep – Separator to use when joining path elements.
- Returns
The string representation of the path.
- easydel.utils.traversals.validate_state(state: Dict[str, Any], init_state: Dict[str, Any]) → StateValidationResult
Validates state against init_state before differentiation.
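The validate-then-differentiate flow can be sketched on flat dictionaries. Everything here is an assumption rather than EasyDeL’s logic: ValidationSketch, check_state and missing_entries are hypothetical names, “missing” is taken to mean keys present in init_state but absent from state, and only value-type mismatches on shared keys count as invalid.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Set


@dataclass
class ValidationSketch:
    """Hypothetical stand-in with the same three fields as StateValidationResult."""

    is_valid: bool
    missing_keys: Set[Any] = field(default_factory=set)
    invalid_types: Dict[str, type] = field(default_factory=dict)


def check_state(state: Dict[Any, Any], init_state: Dict[Any, Any]) -> ValidationSketch:
    # Assumption: "missing" means present in init_state but absent from state.
    missing = set(init_state) - set(state)
    # Assumption: a shared key is invalid when its value type differs from
    # init_state's; the expected type is recorded under the stringified key.
    invalid = {
        str(key): type(init_state[key])
        for key in init_state.keys() & state.keys()
        if type(state[key]) is not type(init_state[key])
    }
    # Assumption: missing keys are expected (they are what gets differentiated),
    # so only type mismatches make the state invalid.
    return ValidationSketch(is_valid=not invalid, missing_keys=missing, invalid_types=invalid)


def missing_entries(state: Dict[Any, Any], init_state: Dict[Any, Any], validate: bool = True) -> Dict[Any, Any]:
    """Hypothetical sketch of diffrentiate_state: return the entries state lacks."""
    if validate:
        result = check_state(state, init_state)
        if not result.is_valid:
            raise ValueError(f"type mismatch for keys: {sorted(result.invalid_types)}")
    return {key: init_state[key] for key in init_state if key not in state}


init = {("dense", "kernel"): [0.0], ("rngs", "count"): 0}
current = {("dense", "kernel"): [1.0]}
print(missing_entries(current, init))  # {('rngs', 'count'): 0}
```

Separating the check from the differentiation lets a caller inspect missing_keys and invalid_types without raising, while the validate flag keeps the strict path as the default.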