# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
import json
import types
import typing as tp
from functools import wraps, lru_cache
import typing_extensions
from jax import tree_util as tu
# Generic type variable used by this module's annotations (the class handled by
# `auto_pytree`, the return type of `replace`, ...).
T = tp.TypeVar("T")
# NOTE(review): the aliases below are not referenced in the visible part of
# this file; the descriptions only restate the declared types.
# Mapping from arbitrary keys to single-argument callables.
FnDict = tp.Dict[tp.Any, tp.Callable[[tp.Any], tp.Any]]
# Mapping from arbitrary keys to arbitrary values.
TreeDict = tp.Dict[tp.Any, tp.Any]
# Variable-length tuple of arbitrary components.
Path = tp.Tuple[tp.Any, ...]
# Either a plain boolean or a single-argument predicate.
FilterSpec = tp.Union[bool, tp.Callable[[tp.Any], bool]]
# Single-argument predicate returning a bool.
IsLeafFn = tp.Callable[[tp.Any], bool]
@lru_cache(maxsize=1024)
def _is_non_jax_type(typ: tp.Type) -> bool:
    """
    Checks if a given type hint denotes a non-JAX type.

    Non-JAX types are strings, bytes, functions, methods, classes and
    callables; `auto_pytree` uses this check to decide which fields become
    metadata. Results are memoized, so `typ` must be hashable.

    Unions are inspected member by member, both `tp.Union[...]` /
    `tp.Optional[...]` and the PEP 604 spelling `X | Y`. Subscripted generics
    (e.g. `tp.Callable[[int], int]`, `tp.Type[Foo]`) are judged by their
    origin class, because `issubclass` cannot be applied to the alias itself.

    Args:
        typ: The type (or type hint) to check.

    Returns:
        True if the type is considered a non-JAX type, False otherwise.
        `tp.Any` is always reported as False.
    """
    non_jax_types = (
        str,
        bytes,
        types.FunctionType,
        types.MethodType,
        type,
        tp.Callable,
    )
    if typ is tp.Any:
        return False

    origin = tp.get_origin(typ)
    # `X | Y` reports `types.UnionType` as its origin on interpreters that
    # have it; fall back to `tp.Union` so the comparison stays harmless.
    pep604_union = getattr(types, "UnionType", tp.Union)
    if origin is tp.Union or origin is pep604_union:
        return any(_is_non_jax_type(arg) for arg in tp.get_args(typ))

    # A subscripted generic is not a class; test the class it was built from.
    candidate = typ if origin is None else origin
    for non_jax_type in non_jax_types:
        try:
            if issubclass(candidate, non_jax_type):
                return True
        except TypeError:
            # `candidate` is not a class (e.g. `tp.Literal`); try the next one.
            continue
    return False
def field(
    pytree_node: bool = True, *, metadata: tp.Optional[tp.Dict] = None, **kwargs
) -> dataclasses.Field:
    """
    Drop-in replacement for `dataclasses.field` that tags PyTree membership.

    The flag is recorded under the `"pytree_node"` key of the field metadata.
    The caller's `metadata` dictionary is copied, never modified.

    Args:
        pytree_node: True (default) to mark the field as PyTree data, False to
            mark it as metadata.
        metadata: Optional extra metadata merged into the field metadata.
        **kwargs: Forwarded unchanged to `dataclasses.field`.

    Returns:
        The `dataclasses.Field` object produced by `dataclasses.field`.
    """
    merged = (metadata or {}).copy()
    merged.update(pytree_node=pytree_node)
    return dataclasses.field(**kwargs, metadata=merged)
class PyTreeClassInfo:
    """Record describing how a registered PyTree class splits its fields."""

    __slots__ = ["data_fields", "meta_fields", "frozen", "type_hints"]

    def __init__(
        self,
        data_fields: tp.Tuple[str, ...],
        meta_fields: tp.Tuple[str, ...],
        frozen: bool,
        type_hints: tp.Dict[str, tp.Type],
    ):
        """
        Stores the given values on the instance without further processing.

        Args:
            data_fields: Names of the fields treated as PyTree data (children).
            meta_fields: Names of the fields treated as PyTree metadata.
            frozen: Whether the dataclass was created frozen.
            type_hints: Mapping from field name to its type hint.
        """
        self.data_fields, self.meta_fields = data_fields, meta_fields
        self.frozen, self.type_hints = frozen, type_hints
# Module-level registry mapping each class processed by `auto_pytree` to its
# `PyTreeClassInfo`; the generated `from_dict` looks classes up here to reuse
# the stored type hints.
_CLASS_INFO_REGISTRY: tp.Dict[tp.Type, PyTreeClassInfo] = {}
@typing_extensions.dataclass_transform(field_specifiers=(field,))
def auto_pytree(
    cls: tp.Optional[tp.Type[T]] = None,
    meta_fields: tp.Optional[tp.Tuple[str, ...]] = None,
    json_serializable: bool = True,
    frozen: bool = False,
):
    """
    A class decorator that automatically registers a dataclass as a JAX PyTree.

    It applies `dataclasses.dataclass` to the class, determines which `init`
    fields are data (PyTree children) and which are metadata, and registers the
    class with `jax.tree_util.register_dataclass`.

    Fields are considered metadata if:
    - They are explicitly listed in `meta_fields` (unless the field itself is
      declared with `field(pytree_node=True)`, which takes precedence).
    - They are marked with `field(pytree_node=False)`.
    - Their type hint suggests a non-JAX type (checked by `_is_non_jax_type`).

    Besides registration, the decorator installs `replace`, `__repr__` and
    `__str__` on the class, stores `__pytree_meta__` on it and records a
    `PyTreeClassInfo` in `_CLASS_INFO_REGISTRY`.

    Can be used bare (`@auto_pytree`) or with arguments
    (`@auto_pytree(frozen=True)`).

    Args:
        cls: The class to be decorated; None when called with arguments only.
        meta_fields: A tuple of field names to always treat as metadata.
        json_serializable: If True (default), adds `to_dict`, `from_dict`,
            `to_json` and `from_json` to the class and, once per process,
            patches `json.JSONEncoder.default` so any object exposing a
            callable `to_dict` can be encoded.
        frozen: If True, makes the dataclass frozen (immutable). Defaults to False.

    Returns:
        The decorated class, registered as a PyTree, or the decorator itself
        when `cls` is None.
    """

    def wrap(cls_inner: tp.Type[T]) -> tp.Type[T]:
        """Turns `cls_inner` into a dataclass, augments it and registers it."""
        cls_inner = dataclasses.dataclass(cls_inner, frozen=frozen)
        # Only `init=True` fields take part in the data/metadata split.
        fields = [f for f in dataclasses.fields(cls_inner) if f.init]
        all_field_names = tuple(f.name for f in fields)
        final_meta_fields: tp.Set[str] = set(meta_fields or ())

        # Pass 1: explicit `field(pytree_node=...)` markers.
        for field_obj in fields:
            field_metadata = field_obj.metadata
            if field_metadata and "pytree_node" in field_metadata:
                if field_metadata["pytree_node"] is False:
                    final_meta_fields.add(field_obj.name)
                elif (
                    field_metadata["pytree_node"] is True
                    and field_obj.name in final_meta_fields
                ):
                    # Explicitly marked as node, overrides meta_fields argument
                    final_meta_fields.remove(field_obj.name)

        type_hints = tp.get_type_hints(cls_inner)

        # Pass 2: infer metadata from type hints for still-undecided fields.
        for field_obj in fields:
            if field_obj.name in final_meta_fields:
                continue  # Already marked as meta
            if field_obj.metadata and field_obj.metadata.get("pytree_node") is True:
                continue  # Explicitly marked as node
            field_type = type_hints.get(field_obj.name)
            if field_type is not None and _is_non_jax_type(field_type):
                final_meta_fields.add(field_obj.name)

        data_fields = tuple(f for f in all_field_names if f not in final_meta_fields)
        # Set iteration order depends on string hashing, which is randomised
        # per process; use declaration order so the registered metadata order
        # is reproducible. Names that are not init fields are kept (sorted,
        # last) so registration still receives exactly what was requested.
        declared_meta = [f for f in all_field_names if f in final_meta_fields]
        undeclared_meta = sorted(final_meta_fields.difference(all_field_names))
        meta_fields_tuple = tuple(declared_meta + undeclared_meta)

        # Add a replace method similar to flax.struct.replace
        def _replace(self, **kwargs):
            """Creates a new instance with specified fields replaced."""
            if not kwargs:
                # Nothing to change: hand back the same object.
                return self
            return dataclasses.replace(self, **kwargs)

        cls_inner.replace = _replace

        # Add an enhanced repr
        def enhanced_repr(self):
            """Provides a multi-line representation listing public attributes."""
            cls_name = self.__class__.__name__
            items = []
            for k, v in self.__dict__.items():
                if not k.startswith("_"):  # Avoid private/internal attributes
                    try:
                        repr_str = str(v)
                        if len(repr_str) > 200:  # Truncate long representations
                            repr_str = f"{v.__class__.__name__}(...)"
                        items.append(f" {k} : {repr_str}")
                    except TypeError:
                        # Handle cases where str() might fail
                        items.append(f" {k} : <unrepresentable>")
            return f"{cls_name}(\n" + "\n".join(items) + "\n)"

        cls_inner.__repr__ = enhanced_repr
        cls_inner.__str__ = enhanced_repr  # Use the same for str

        # Store class info for potential later use (e.g., serialization)
        class_info = PyTreeClassInfo(
            data_fields=data_fields,
            meta_fields=meta_fields_tuple,
            frozen=frozen,
            type_hints=type_hints,
        )
        _CLASS_INFO_REGISTRY[cls_inner] = class_info
        # Store basic info directly on the class for easier access if needed
        cls_inner.__pytree_meta__ = {
            "data_fields": data_fields,
            "meta_fields": meta_fields_tuple,
            "frozen": frozen,
        }

        # Add JSON serialization methods if requested
        if json_serializable:

            def to_dict(self) -> tp.Dict[str, tp.Any]:
                """Serializes the PyTree object to a dictionary."""
                result = {}
                for field_obj in dataclasses.fields(self):
                    value = getattr(self, field_obj.name)
                    if value is Ellipsis:  # Skip Ellipsis sentinel values
                        continue
                    # Basic type handling for JSON compatibility
                    if isinstance(value, tuple):
                        result[field_obj.name] = list(value)  # Convert tuples to lists
                    elif value is None:
                        result[field_obj.name] = None
                    elif hasattr(value, "to_dict") and callable(value.to_dict):
                        # Recursively call to_dict if available
                        result[field_obj.name] = value.to_dict()
                    else:
                        # Attempt direct JSON serialization, fallback to str
                        try:
                            json.dumps(value)  # Check if serializable
                            result[field_obj.name] = value
                        except (TypeError, OverflowError):
                            result[field_obj.name] = str(value)
                return result

            cls_inner.to_dict = to_dict

            @classmethod
            def from_dict(cls_inner_classmethod: tp.Type[T], data: tp.Dict[str, tp.Any]) -> T:
                """Deserializes a dictionary into a PyTree object."""
                processed_data = {}
                # Retrieve type hints for potential type conversions
                class_info_local = _CLASS_INFO_REGISTRY.get(cls_inner_classmethod)
                type_hints_local = (
                    class_info_local.type_hints
                    if class_info_local
                    else tp.get_type_hints(cls_inner_classmethod)
                )
                for field_obj in dataclasses.fields(cls_inner_classmethod):
                    field_name = field_obj.name
                    if field_name not in data:
                        continue  # Skip fields not present in the data
                    value = data[field_name]
                    field_type = type_hints_local.get(field_name)
                    # `to_dict` turns tuples into lists; undo that for fields
                    # annotated as a tuple, subscripted (`Tuple[int, ...]`) or
                    # bare (`tuple`).
                    wants_tuple = field_type is not None and (
                        field_type is tuple or tp.get_origin(field_type) is tuple
                    )
                    if isinstance(value, list) and wants_tuple:
                        processed_data[field_name] = tuple(value)
                    # TODO: Add more robust deserialization for nested PyTrees
                    # if needed (e.g. delegate to the field type's `from_dict`).
                    else:
                        processed_data[field_name] = value
                return cls_inner_classmethod(**processed_data)

            cls_inner.from_dict = from_dict

            def to_json(self, **kwargs) -> str:
                """Serializes the PyTree object to a JSON string."""
                return json.dumps(self.to_dict(), **kwargs)

            cls_inner.to_json = to_json

            @classmethod
            def from_json(cls_inner_classmethod: tp.Type[T], json_str: str) -> T:
                """Deserializes a JSON string into a PyTree object."""
                data = json.loads(json_str)
                return cls_inner_classmethod.from_dict(data)

            cls_inner.from_json = from_json

            # Patch json.JSONEncoder to handle PyTree objects by default.
            # This is a process-wide change, applied at most once (guarded by
            # the `_pytree_patched` marker).
            if not hasattr(json.JSONEncoder, "_pytree_patched"):
                original_default = json.JSONEncoder.default

                @wraps(original_default)
                def json_default(encoder_self, obj):
                    """JSON encoder default method patched to handle PyTrees."""
                    if hasattr(obj, "to_dict") and callable(obj.to_dict):
                        return obj.to_dict()
                    return original_default(encoder_self, obj)

                json.JSONEncoder.default = json_default
                json.JSONEncoder._pytree_patched = True  # Mark as patched

        # Register the class with JAX
        return tu.register_dataclass(
            cls_inner,
            data_fields=data_fields,
            meta_fields=meta_fields_tuple,
        )

    # Handle decorator usage with or without arguments
    if cls is None:
        return wrap
    return wrap(cls)
class _PyTreeNodeBase:
    """Common parent of the PyTree base classes; contributes `replace`."""

    def replace(self: T, **kwargs) -> T:
        """Returns a copy of this dataclass instance with some fields changed.

        Delegates to `dataclasses.replace`. Classes processed by `auto_pytree`
        get their own `replace` assigned on the class, so this definition acts
        as the inherited fallback.

        Args:
            **kwargs: Field names mapped to the values they should take.

        Returns:
            A new instance of the same class carrying the updated fields.
        """
        updated = dataclasses.replace(self, **kwargs)
        return updated
@typing_extensions.dataclass_transform(field_specifiers=(field,))
class PyTree(_PyTreeNodeBase):
    """
    Base class for mutable PyTree dataclasses.

    Inheriting from this class automatically applies the `auto_pytree`
    decorator to the subclass, registering it as a JAX PyTree. Options are
    given as class keywords, e.g. `class Foo(PyTree, meta_fields=("name",))`.
    """

    def __init_subclass__(
        cls,
        *,
        frozen: bool = False,  # Keep frozen option, though typically False for PyTree
        json_serializable: bool = True,
        meta_fields: tp.Optional[tp.Tuple[str, ...]] = None,
        **kwargs,
    ):
        """
        Applies `auto_pytree` to subclasses.

        Args:
            frozen: If True, makes the dataclass frozen. Defaults to False.
            json_serializable: If True (default), adds JSON serialization methods.
            meta_fields: Tuple of field names to always treat as metadata.
            **kwargs: Remaining class keywords, forwarded to the parent
                `__init_subclass__` (they are not passed to `auto_pytree`).
        """
        super().__init_subclass__(**kwargs)
        # `auto_pytree` modifies and registers `cls` in place, so its return
        # value is not needed here.
        auto_pytree(
            cls,
            meta_fields=meta_fields,
            json_serializable=json_serializable,
            frozen=frozen,  # Pass frozen status
        )
@typing_extensions.dataclass_transform(field_specifiers=(field,))
class FrozenPyTree(_PyTreeNodeBase):
    """
    Base class for immutable (frozen) PyTree dataclasses.

    Every subclass is passed through `auto_pytree` with `frozen=True` at class
    creation time, which turns it into a frozen dataclass registered as a JAX
    PyTree.
    """

    def __init_subclass__(
        cls,
        *,
        json_serializable: bool = True,
        meta_fields: tp.Optional[tp.Tuple[str, ...]] = None,
        **kwargs,
    ):
        """
        Runs `auto_pytree` with frozen=True on the new subclass.

        Args:
            json_serializable: If True (default), adds JSON serialization methods.
            meta_fields: Tuple of field names to always treat as metadata.
            **kwargs: Remaining class keywords, handed to the parent
                `__init_subclass__`.
        """
        super().__init_subclass__(**kwargs)
        decorator_options = {
            "meta_fields": meta_fields,
            "json_serializable": json_serializable,
            "frozen": True,  # Subclasses of this base are always frozen
        }
        auto_pytree(cls, **decorator_options)