Source code for eformer.pytree._pytree

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
import json
import types
import typing as tp
from functools import wraps, lru_cache

import typing_extensions
from jax import tree_util as tu

T = tp.TypeVar("T")
FnDict = tp.Dict[tp.Any, tp.Callable[[tp.Any], tp.Any]]
TreeDict = tp.Dict[tp.Any, tp.Any]
Path = tp.Tuple[tp.Any, ...]
FilterSpec = tp.Union[bool, tp.Callable[[tp.Any], bool]]
IsLeafFn = tp.Callable[[tp.Any], bool]


@lru_cache(maxsize=1024)
def _is_non_jax_type(typ: tp.Type) -> bool:
	"""
	Checks if a given type is considered a non-JAX type.

	Non-JAX types are typically those that should not be treated as leaves
	in a PyTree structure by default, such as strings, functions, or types.

	Args:
	    typ: The type to check.

	Returns:
	    True if the type is considered a non-JAX type, False otherwise.
	"""
	NON_JAX_TYPES = (
		str,
		bytes,
		types.FunctionType,
		types.MethodType,
		type,
		tp.Callable,
	)

	if typ is tp.Any:
		return False

	origin = tp.get_origin(typ)
	if origin is tp.Union:
		args = tp.get_args(typ)
		return any(_is_non_jax_type(arg) for arg in args)

	for non_jax_type in NON_JAX_TYPES:
		try:
			if issubclass(typ, non_jax_type):
				return True
		except TypeError:
			# issubclass raises TypeError if typ is not a class (e.g., tp.Callable)
			pass

	return False


def field(
	pytree_node: bool = True, *, metadata: tp.Optional[tp.Dict] = None, **kwargs
) -> dataclasses.Field:
	"""
	A dataclass field replacement that allows specifying whether a field
	should be treated as a PyTree node.

	Args:
	    pytree_node: If True (default), the field is treated as a PyTree leaf/node.
	                 If False, the field is treated as metadata.
	    metadata: Optional dictionary of metadata to pass to `dataclasses.field`.
	    **kwargs: Additional keyword arguments passed to `dataclasses.field`.

	Returns:
	    A dataclasses.Field object.
	"""
	metadata_dict = (metadata or {}).copy()
	metadata_dict["pytree_node"] = pytree_node
	return dataclasses.field(metadata=metadata_dict, **kwargs)


class PyTreeClassInfo:
	"""Stores metadata about a class registered as a PyTree."""

	__slots__ = ["data_fields", "meta_fields", "frozen", "type_hints"]

	def __init__(
		self,
		data_fields: tp.Tuple[str, ...],
		meta_fields: tp.Tuple[str, ...],
		frozen: bool,
		type_hints: tp.Dict[str, tp.Type],
	):
		"""
		Initializes the PyTreeClassInfo.

		Args:
		    data_fields: Tuple of field names treated as PyTree data (children).
		    meta_fields: Tuple of field names treated as PyTree metadata.
		    frozen: Boolean indicating if the original dataclass was frozen.
		    type_hints: Dictionary mapping field names to their type hints.
		"""
		self.data_fields = data_fields
		self.meta_fields = meta_fields
		self.frozen = frozen
		self.type_hints = type_hints


_CLASS_INFO_REGISTRY: tp.Dict[tp.Type, PyTreeClassInfo] = {}


[docs]@typing_extensions.dataclass_transform(field_specifiers=(field,)) def auto_pytree( cls: tp.Optional[tp.Type[T]] = None, meta_fields: tp.Optional[tp.Tuple[str, ...]] = None, json_serializable: bool = True, frozen: bool = False, ): """ A class decorator that automatically registers a dataclass as a JAX PyTree. It uses `dataclasses.dataclass` to make the class a dataclass if it isn't already, determines which fields are data (PyTree children) and which are metadata, and registers the class with `jax.tree_util.register_dataclass`. Fields are considered metadata if: - They are explicitly listed in `meta_fields`. - They are marked with `field(pytree_node=False)`. - Their type hint suggests they are non-JAX types (checked by `_is_non_jax_type`). Args: cls: The class to be decorated. meta_fields: A tuple of field names to always treat as metadata. json_serializable: If True (default), adds `to_dict`, `from_dict`, `to_json`, and `from_json` methods to the class. frozen: If True, makes the dataclass frozen (immutable). Defaults to False. Returns: The decorated class, registered as a PyTree. """ def wrap(cls_inner: tp.Type[T]) -> tp.Type[T]: """Internal wrapper function for the decorator.""" cls_inner = dataclasses.dataclass(cls_inner, frozen=frozen) fields = [f for f in dataclasses.fields(cls_inner) if f.init] all_field_names = tuple(f.name for f in fields) final_meta_fields: tp.Set[str] = set(meta_fields or ()) # Determine meta fields based on field metadata for field_obj in fields: field_metadata = field_obj.metadata if field_metadata and "pytree_node" in field_metadata: if field_metadata["pytree_node"] is False: final_meta_fields.add(field_obj.name) elif ( field_metadata["pytree_node"] is True and field_obj.name in final_meta_fields ): # Explicitly marked as node, overrides meta_fields argument final_meta_fields.remove(field_obj.name) type_hints = tp.get_type_hints(cls_inner) # Determine meta fields based on type hints for field_obj in fields: if field_obj.name in final_meta_fields: continue # Already marked as meta if field_obj.metadata and field_obj.metadata.get("pytree_node") is True: continue # Explicitly marked as node field_type = type_hints.get(field_obj.name) if field_type is not None and _is_non_jax_type(field_type): final_meta_fields.add(field_obj.name) data_fields = tuple(f for f in all_field_names if f not in final_meta_fields) meta_fields_tuple = tuple(final_meta_fields) # Add a replace method similar to flax.struct.replace def _replace(self, **kwargs): """Creates a new instance with specified fields replaced.""" if not kwargs: return self return dataclasses.replace(self, **kwargs) cls_inner.replace = _replace # Add an enhanced repr def enhanced_repr(self): """Provides a more detailed representation of the object.""" cls_name = self.__class__.__name__ items = [] for k, v in self.__dict__.items(): if not k.startswith("_"): # Avoid private/internal attributes try: repr_str = str(v) if len(repr_str) > 200: # Truncate long representations repr_str = f"{v.__class__.__name__}(...)" items.append(f" {k} : {repr_str}") except TypeError: # Handle cases where str() might fail items.append(f" {k} : <unrepresentable>") return f"{cls_name}(\n" + "\n".join(items) + "\n)" cls_inner.__repr__ = enhanced_repr cls_inner.__str__ = enhanced_repr # Use the same for str # Store class info for potential later use (e.g., serialization) class_info = PyTreeClassInfo( data_fields=data_fields, meta_fields=meta_fields_tuple, frozen=frozen, type_hints=type_hints, ) _CLASS_INFO_REGISTRY[cls_inner] = class_info # Store basic info directly on the class for easier access if needed cls_inner.__pytree_meta__ = { "data_fields": data_fields, "meta_fields": meta_fields_tuple, "frozen": frozen, } # Add JSON serialization methods if requested if json_serializable: def to_dict(self) -> tp.Dict[str, tp.Any]: """Serializes the PyTree object to a dictionary.""" result = {} for field_obj in dataclasses.fields(self): value = getattr(self, field_obj.name) if value is Ellipsis: # Skip Ellipsis sentinel values continue # Basic type handling for JSON compatibility if isinstance(value, tuple): result[field_obj.name] = list(value) # Convert tuples to lists elif value is None: result[field_obj.name] = None elif hasattr(value, "to_dict") and callable(value.to_dict): # Recursively call to_dict if available result[field_obj.name] = value.to_dict() else: # Attempt direct JSON serialization, fallback to str try: json.dumps(value) # Check if serializable result[field_obj.name] = value except (TypeError, OverflowError): result[field_obj.name] = str(value) return result cls_inner.to_dict = to_dict @classmethod def from_dict(cls_inner_classmethod: tp.Type[T], data: tp.Dict[str, tp.Any]) -> T: """Deserializes a dictionary into a PyTree object.""" processed_data = {} # Retrieve type hints for potential type conversions class_info_local = _CLASS_INFO_REGISTRY.get(cls_inner_classmethod) type_hints_local = ( class_info_local.type_hints if class_info_local else tp.get_type_hints(cls_inner_classmethod) ) for field_obj in dataclasses.fields(cls_inner_classmethod): field_name = field_obj.name if field_name not in data: continue # Skip fields not present in the data value = data[field_name] field_type = type_hints_local.get(field_name) # Handle specific type conversions (e.g., list back to tuple) if ( value is not None and isinstance(value, list) and field_type is not None and tp.get_origin(field_type) is tuple # Check if original type was tuple ): processed_data[field_name] = tuple(value) # TODO: Add more robust deserialization for nested PyTrees if needed # elif hasattr(field_type, "from_dict") and callable(field_type.from_dict): # processed_data[field_name] = field_type.from_dict(value) else: processed_data[field_name] = value return cls_inner_classmethod(**processed_data) cls_inner.from_dict = from_dict def to_json(self, **kwargs) -> str: """Serializes the PyTree object to a JSON string.""" return json.dumps(self.to_dict(), **kwargs) cls_inner.to_json = to_json @classmethod def from_json(cls_inner_classmethod: tp.Type[T], json_str: str) -> T: """Deserializes a JSON string into a PyTree object.""" data = json.loads(json_str) return cls_inner_classmethod.from_dict(data) cls_inner.from_json = from_json # Patch json.JSONEncoder to handle PyTree objects by default if not hasattr(json.JSONEncoder, "_pytree_patched"): original_default = json.JSONEncoder.default @wraps(original_default) def json_default(encoder_self, obj): """JSON encoder default method patched to handle PyTrees.""" if hasattr(obj, "to_dict") and callable(obj.to_dict): return obj.to_dict() return original_default(encoder_self, obj) json.JSONEncoder.default = json_default json.JSONEncoder._pytree_patched = True # Mark as patched # Register the class with JAX return tu.register_dataclass( cls_inner, data_fields=data_fields, meta_fields=meta_fields_tuple, ) # Handle decorator usage with or without arguments if cls is None: return wrap return wrap(cls)
class _PyTreeNodeBase: """Base class providing a default `replace` method.""" def replace(self: T, **kwargs) -> T: """Creates a new instance with specified fields replaced. This method is typically overridden by the `auto_pytree` decorator if the class is decorated directly. It serves as a fallback or base for classes inheriting from PyTree/FrozenPyTree. Args: **kwargs: Field names and their new values. Returns: A new instance of the class with the specified fields updated. """ return dataclasses.replace(self, **kwargs)
[docs]@typing_extensions.dataclass_transform(field_specifiers=(field,)) class PyTree(_PyTreeNodeBase): """ Base class for mutable PyTree dataclasses. Inheriting from this class automatically applies the `auto_pytree` decorator to the subclass, registering it as a JAX PyTree. """ def __init_subclass__( cls, *, frozen: bool = False, # Keep frozen option, though typically False for PyTree json_serializable: bool = True, meta_fields: tp.Optional[tp.Tuple[str, ...]] = None, **kwargs, ): """ Applies `auto_pytree` to subclasses. Args: frozen: If True, makes the dataclass frozen. Defaults to False. json_serializable: If True (default), adds JSON serialization methods. meta_fields: Tuple of field names to always treat as metadata. **kwargs: Additional arguments passed to `auto_pytree`. """ super().__init_subclass__(**kwargs) auto_pytree( cls, meta_fields=meta_fields, json_serializable=json_serializable, frozen=frozen, # Pass frozen status )
@typing_extensions.dataclass_transform(field_specifiers=(field,)) class FrozenPyTree(_PyTreeNodeBase): """ Base class for immutable (frozen) PyTree dataclasses. Inheriting from this class automatically applies the `auto_pytree` decorator with `frozen=True` to the subclass, registering it as a frozen JAX PyTree. """ def __init_subclass__( cls, *, json_serializable: bool = True, meta_fields: tp.Optional[tp.Tuple[str, ...]] = None, **kwargs, ): """ Applies `auto_pytree` with frozen=True to subclasses. Args: json_serializable: If True (default), adds JSON serialization methods. meta_fields: Tuple of field names to always treat as metadata. **kwargs: Additional arguments passed to `auto_pytree`. """ super().__init_subclass__(**kwargs) auto_pytree( cls, meta_fields=meta_fields, json_serializable=json_serializable, frozen=True, # Ensure subclasses are frozen )