easydel.infra.base_state

State management for EasyDeL models.

This module provides the EasyDeLState class, which encapsulates all stateful components of a model during training or inference, including parameters, optimizer state, and training metadata.

Classes:

EasyDeLState: Complete state container for models during training/inference

Key Features:
  • Unified state management for training and inference

  • Automatic sharding and partitioning support

  • Checkpoint saving and loading

  • Gradient application with optimizer integration

  • State serialization and deserialization

Example

>>> from easydel.infra import EasyDeLState
>>> import optax
>>>
>>> # Create state from an existing EasyDeL model instance
>>> state = EasyDeLState.create(
...     model=model,
...     tx=optax.adamw(learning_rate=1e-4),
...     init_opt_state=True
... )
>>>
>>> # Apply gradients (a pytree matching state.graphstate)
>>> state = state.apply_gradients(grads=gradients)
>>>
>>> # Save checkpoint
>>> state.save_state("checkpoint_path")
>>>
>>> # Load checkpoint
>>> state = EasyDeLState.load_state(
...     "checkpoint_path",
...     config=config
... )
class easydel.infra.base_state.EasyDeLState(step: int | jax.Array, graphdef: nn.GraphDef, graphstate: nn.GraphState, graphother: nn.GraphState, tx: optax.GradientTransformation, opt_state: optax.OptState | None, apply_fn: tp.Callable | None = None)

Bases: PyTreeNode

Complete state container for EasyDeL models.

Encapsulates all stateful components needed for training or inference, including model parameters, optimizer state, and training metadata. Provides methods for gradient updates, checkpointing, and state management.

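This class is designed to fit JAX’s functional programming model: as a PyTreeNode it can be passed through JAX transformations, and methods that modify the state (apply_gradients, replace, shard_state, and so on) return a new instance instead of mutating the existing one.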
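The copy-on-update behaviour can be pictured with a plain frozen dataclass. The sketch below is stdlib-only and hypothetical: ToyState and its fields are illustrative stand-ins, not EasyDeL types, and it only mimics the replace() pattern that EasyDeLState exposes.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for an immutable state container."""

    step: int
    params: dict

    def replace(self, **updates):
        # Build a new instance with the given fields swapped; `self` is untouched.
        return dataclasses.replace(self, **updates)


state = ToyState(step=0, params={"w": 1.0})
new_state = state.replace(step=state.step + 1)
print(state.step, new_state.step)  # 0 1

try:
    state.step = 5  # frozen dataclasses reject in-place mutation
except dataclasses.FrozenInstanceError:
    print("in-place mutation is not allowed")
```

Fields that are not replaced (here params) are shared by reference between the old and new instances, so creating a new state is cheap.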

step

Current training step count.

Type

int | jax.Array

graphdef

Model’s computation graph structure (non-pytree).

Type

nn.GraphDef

graphstate

Model parameter state (pytree).

Type

nn.GraphState

graphother

Non-parameter model state (pytree).

Type

nn.GraphState

tx

Optimizer transformation (non-pytree).

Type

optax.GradientTransformation

opt_state

Optimizer state, such as Adam’s moment estimates (pytree).

Type

optax.OptState | None

apply_fn

Optional model application function (non-pytree).

Type

tp.Callable | None

apply_gradients()

Update parameters with gradients.

create()

Factory method to create state from model.

save_state()

Save state to checkpoint.

load_state()

Load state from checkpoint.

shard_state()

Apply sharding to state.

gather_state()

Gather sharded state.

Example

>>> state = EasyDeLState.create(
...     model=my_model,
...     tx=optax.adam(1e-3)
... )
>>> # Training step
>>> grads = compute_gradients(...)
>>> state = state.apply_gradients(grads=grads)
apply_fn: tp.Callable | None = None
apply_gradients(*, grads)

Apply gradients to update parameters and optimizer state.

Performs a single optimization step using the provided gradients.

Parameters

grads – Gradient pytree matching the structure of graphstate.

Returns

New EasyDeLState with updated parameters, optimizer state, and incremented step count.

Raises

AssertionError – If opt_state or tx is not initialized.

Example

>>> grads = jax.grad(loss_fn)(state.graphstate)
>>> state = state.apply_gradients(grads=grads)
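Assuming apply_gradients follows the usual optax pattern (updates, new_opt_state = tx.update(grads, opt_state, params), then optax.apply_updates(params, updates), then step + 1), a single step can be sketched in plain Python. The momentum optimizer and every name below are hypothetical; dictionaries of floats stand in for graphstate and opt_state.

```python
from typing import Dict, NamedTuple, Tuple

Params = Dict[str, float]


class MomentumState(NamedTuple):
    """Hypothetical optimizer state: one velocity buffer per parameter."""

    velocity: Params


def momentum_init(params: Params) -> MomentumState:
    # Plays the role of tx.init(graphstate): buffers start at zero.
    return MomentumState({k: 0.0 for k in params})


def momentum_update(grads: Params, state: MomentumState,
                    lr: float = 0.1, beta: float = 0.9) -> Tuple[Params, MomentumState]:
    # Plays the role of tx.update(grads, opt_state): returns updates and new state.
    velocity = {k: beta * state.velocity[k] + grads[k] for k in grads}
    updates = {k: -lr * v for k, v in velocity.items()}
    return updates, MomentumState(velocity)


def apply_gradients(step: int, params: Params, opt_state: MomentumState,
                    grads: Params) -> Tuple[int, Params, MomentumState]:
    # One optimization step: compute updates, add them to params, bump the step.
    updates, new_opt_state = momentum_update(grads, opt_state)
    new_params = {k: params[k] + updates[k] for k in params}
    return step + 1, new_params, new_opt_state


step, params = 0, {"w": 1.0}
opt_state = momentum_init(params)
step, params, opt_state = apply_gradients(step, params, opt_state, {"w": 0.5})
print(step, params)  # step is 1 and w is roughly 0.95
```

The grads argument must have the same keys as params, which mirrors the requirement that grads match the structure of graphstate.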
classmethod create(*, step: int | None = None, graphdef: nn.GraphDef | None = None, graphstate: nn.GraphState | None = None, graphother: nn.GraphState | None = None, model: nn.Module | None = None, tx: optax.GradientTransformation | None = None, opt_state: optax.OptState | None = None, init_opt_state: bool = False) → Self

Creates a new EasyDeLState instance.

This class method provides a flexible way to initialize the state, either from an existing nn.Module or by providing the graph components (graphdef, graphstate, graphother) directly. It also handles optimizer state initialization.

Parameters
  • step (tp.Optional[int]) – The initial training step. Defaults to 0.

  • graphdef (tp.Optional[nn.GraphDef]) – The model’s graph definition.

  • graphstate (tp.Optional[nn.GraphState]) – The model’s parameter state.

  • graphother (tp.Optional[nn.GraphState]) – The model’s non-parameter state.

  • model (tp.Optional[nn.Module]) – An EasyDeL module instance. If provided, graphdef, graphstate, and graphother are derived from it. Cannot be provided simultaneously with graph components.

  • tx (tp.Optional[optax.GradientTransformation]) – The optimizer transformation.

  • opt_state (tp.Optional[optax.OptState]) – The initial optimizer state. Cannot be provided if init_opt_state is True.

  • init_opt_state (bool) – If True, initializes the optimizer state using tx.init(graphstate). Requires tx to be provided. Defaults to False.

Returns

A new instance of the state.

Return type

EasyDeLState

Raises
  • ValueError – If model and graph components are provided simultaneously.

  • ValueError – If graph components are provided partially.

  • ValueError – If init_opt_state is True and opt_state is also provided.

  • ValueError – If init_opt_state is True but tx is not provided.
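The four ValueError conditions above can be restated as a small validator. This is a hypothetical re-expression of the documented rules for illustration only; validate_create_args and raises_value_error are not EasyDeL functions, and the real create() may order or word its checks differently.

```python
def validate_create_args(*, graphdef=None, graphstate=None, graphother=None,
                         model=None, tx=None, opt_state=None,
                         init_opt_state=False):
    """Hypothetical restatement of the argument checks documented for create()."""
    given = [c is not None for c in (graphdef, graphstate, graphother)]
    if model is not None and any(given):
        raise ValueError("Pass either `model` or graph components, not both.")
    if any(given) and not all(given):
        raise ValueError("graphdef, graphstate and graphother must be given together.")
    if init_opt_state and opt_state is not None:
        raise ValueError("Cannot pass `opt_state` when `init_opt_state=True`.")
    if init_opt_state and tx is None:
        raise ValueError("`init_opt_state=True` requires `tx`.")


def raises_value_error(**kwargs) -> bool:
    """Return True if the hypothetical validator rejects these arguments."""
    try:
        validate_create_args(**kwargs)
    except ValueError:
        return True
    return False


print(raises_value_error(model=object(), tx=object(), init_opt_state=True))  # False
print(raises_value_error(graphdef=object(), graphstate=object()))            # True
```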

gather_model(partition_rules: PartitionLike = None, mesh: Mesh | None = None) → Self

Gathers the model parameters (graphstate and graphother) from distributed devices.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules used for the original sharding. If None, uses model config rules. Defaults to None.

  • mesh (tp.Optional[Mesh], optional) – The JAX device mesh to gather from. If None, uses model’s mesh. Defaults to None.

Returns

A new state object with gathered graphstate and graphother.

Return type

EasyDeLState

gather_optimizer_state(partition_rules=None)

Gathers the optimizer state from potentially distributed devices.

Parameters

partition_rules (PartitionLike, optional) – Partitioning rules used to determine how the state was sharded. If None, uses rules from the model’s config. Defaults to None.

Returns

A new state object with the gathered opt_state.

Return type

EasyDeLState

Raises

AssertionError – If opt_state is not initialized.

gather_state()

Gathers the entire state from distributed devices.

This is a convenience method that calls gather_model and gather_optimizer_state.

Returns

A new state object with both model and optimizer states gathered.

Return type

EasyDeLState

graphdef: nn.GraphDef
graphother: nn.GraphState
graphstate: nn.GraphState
init_tx(tx: optax.GradientTransformation, partition_rules: PartitionLike = None) → Self

Initializes the optimizer state (opt_state) for the current graphstate using the provided optimizer transformation (tx).

It automatically handles sharding based on the model’s partition rules.

Parameters
  • tx (optax.GradientTransformation) – The optimizer transformation to initialize with.

  • partition_rules (PartitionLike, optional) – Partitioning rules for the optimizer state. If None, uses the rules from the associated model’s config. Defaults to None.

Returns

A new state object with the initialized and potentially sharded opt_state and the provided tx.

Return type

EasyDeLState

load_optimizer(load_directory: str | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath, checkpointer: eformer.serialization.checkpointer.Checkpointer | None = None, tx_template: None | optax.GradientTransformation = None)

Loads the optimizer state from saved files.

Reads the optimizer state structure from a pickle file (OPTIMIZER_STRUCT_NAME) and the tensor data from a SafeTensors file (OPTIMIZER_NAME) within the specified directory.

Parameters

load_directory (tp.Union[str, os.PathLike]) – The directory containing the saved optimizer state files.

Returns

A new state object with the loaded opt_state.

Return type

EasyDeLState

Raises
  • FileNotFoundError – If the required optimizer files are not found.

  • Exception – If any error occurs during loading or deserialization.

classmethod load_state(load_directory: str | os.PathLike, device: jax.Device | None = 'cpu', dtype: jnp.dtype = jnp.bfloat16, param_dtype: jnp.dtype = jnp.bfloat16, precision: jax.lax.Precision | None = None, sharding_axis_dims: tp.Sequence[int] = (1, -1, 1, 1, 1), sharding_dcn_axis_dims: tp.Sequence[int] | None = None, sharding_axis_names: tp.Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp'), partition_axis: PartitionAxis | None = None, shard_fns: tp.Mapping[tuple, tp.Callable] | dict | None = None, backend: EasyDeLBackends | None = None, platform: EasyDeLPlatforms | None = None, config_kwargs: EasyDeLBaseConfigDict | None = None, model_task: TaskType = TaskType.AUTO_BIND, auto_shard_model: bool = True, partition_rules: tuple[tuple[str, PartitionSpec], ...] | None = None, quantization_config: EasyDeLQuantizationConfig | None = None, quantize_tensors: bool = True, verbose: bool = True, tx_template: optax.GradientTransformation | None = None, **kwargs)

Loads an EasyDeLState from a saved checkpoint directory.

This class method reconstructs the model configuration, loads the model parameters, and optionally loads the optimizer state from files saved previously using save_state. It handles various configurations for device placement, data types, sharding, and quantization.

Parameters
  • load_directory – Path to the directory containing the saved state (configuration, model weights, and potentially optimizer state).

  • device – The JAX device (e.g., ‘cpu’, ‘gpu’, ‘tpu’) to load the model onto. Defaults to ‘cpu’.

  • dtype – The data type to use for computation (e.g., jnp.bfloat16). Defaults to jnp.bfloat16.

  • param_dtype – The data type for the model parameters (e.g., jnp.bfloat16). Defaults to jnp.bfloat16.

  • precision – The JAX precision level (e.g., jax.lax.Precision.HIGHEST). Defaults to None.

  • sharding_axis_dims – A sequence defining the dimensions of the device mesh for sharding (e.g., (1, -1, 1, 1, 1)). Defaults to (1, -1, 1, 1, 1).

  • sharding_dcn_axis_dims – Optional mesh dimensions along the data-center network (DCN), for hybrid device meshes that span multiple slices or hosts. Defaults to None.

  • sharding_axis_names – Names corresponding to the sharding axes (e.g., (“dp”, “fsdp”, “ep”, “tp”, “sp”)). Defaults to (“dp”, “fsdp”, “ep”, “tp”, “sp”).

  • partition_axis – Configuration object for partitioning specific axes. Defaults to None.

  • shard_fns – Optional mapping of parameter path tuples to custom sharding functions. Defaults to None.

  • backend – The backend framework to use (e.g., EasyDeLBackends.JAX). Defaults to None (auto-detected).

  • platform – The hardware platform (e.g., EasyDeLPlatforms.TPU). Defaults to None (auto-detected).

  • config_kwargs – Optional dictionary of keyword arguments to override in the loaded model configuration. Defaults to None.

  • model_task – The specific task type for the model (e.g., TaskType.CAUSAL_LM). Defaults to TaskType.AUTO_BIND.

  • auto_shard_model – If True, automatically shards the loaded model and optimizer state based on the provided sharding configuration. Defaults to True.

  • partition_rules – Optional tuple of partition rules (regex, PartitionSpec) to explicitly define sharding. Defaults to None (uses model config).

  • quantization_config – Quantization configuration. Pass None to disable quantization.

  • quantize_tensors – If True, applies quantization to the loaded tensors. Defaults to True.

  • verbose – If True, logs detailed information during loading. Defaults to True.

  • **kwargs – Additional keyword arguments passed directly to the underlying EasyDeLBaseModule.from_pretrained method.

Returns

An EasyDeLState instance containing the loaded model, optimizer state (if found and loaded), and associated configuration.

Raises
  • FileNotFoundError – If the load_directory or essential files within it (like configuration or model weights) are not found.

  • ValueError – If there are inconsistencies in the provided arguments or loaded configuration.
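In the default sharding_axis_dims=(1, -1, 1, 1, 1), the -1 entry absorbs whatever devices remain. The sketch below assumes numpy-reshape-style inference (at most one -1, and the product of all dimensions must equal the device count); resolve_axis_dims is a hypothetical helper, not part of EasyDeL.

```python
import math
from typing import Sequence, Tuple


def resolve_axis_dims(axis_dims: Sequence[int], num_devices: int) -> Tuple[int, ...]:
    """Resolve a mesh shape, inferring at most one -1 entry from the device count."""
    dims = tuple(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("At most one axis may be -1.")
    known = math.prod(d for d in dims if d != -1)
    if -1 not in dims:
        if known != num_devices:
            raise ValueError(f"Mesh {dims} needs {known} devices, got {num_devices}.")
        return dims
    if num_devices % known != 0:
        raise ValueError(f"{num_devices} devices cannot be split evenly by {known}.")
    return tuple(num_devices // known if d == -1 else d for d in dims)


axis_names = ("dp", "fsdp", "ep", "tp", "sp")
shape = resolve_axis_dims((1, -1, 1, 1, 1), num_devices=8)
print(dict(zip(axis_names, shape)))  # {'dp': 1, 'fsdp': 8, 'ep': 1, 'tp': 1, 'sp': 1}
```

Under that assumption, eight devices with the default settings place the whole mesh on the fsdp axis.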

merge(tree) → EasyDeLBaseModule

Merges a given state tree (usually parameters) with the graph definition and other state components to reconstruct the full model module.

Parameters

tree – The pytree (e.g., nn.GraphState) containing the parameters to merge.

Returns

The reconstructed model module.

Return type

EasyDeLBaseModule

merge_to_state(tree) → Self

Creates a new EasyDeLState by replacing the current graphstate with the provided tree.

Parameters

tree – The pytree (e.g., nn.GraphState) containing the new parameters.

Returns

A new state object with the updated graphstate.

Return type

EasyDeLState

property mesh

Gets the JAX device mesh from the model.

Returns

The JAX device mesh used for sharding.

property model: EasyDeLBaseModule

Reconstructs and returns the full EasyDeL model module from the state components.

Returns

The model module instance.

Return type

EasyDeLBaseModule

opt_state: optax.OptState | None
replace(**updates)

Returns a new object replacing the specified fields with new values.

save_optimizer(save_directory: str | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath, float_dtype: numpy.dtype | None = None, checkpointer: eformer.serialization.checkpointer.Checkpointer | None = None, step: int | None = None)

Saves the optimizer state to a directory.

Saves optimizer state using AsyncCheckpointManager with optional dtype conversion. The state is saved as a pytree with metadata including the current step.

Parameters
  • save_directory – Directory path to save the optimizer state.

  • float_dtype – Optional dtype to convert floating-point values to before saving. Useful for reducing checkpoint size.
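As a rough illustration of why float_dtype shrinks checkpoints, the sketch below uses Python’s array module as a stand-in for device arrays: floating-point leaves are narrowed from 8-byte to 4-byte floats while integer leaves, such as a step counter, are left alone. The helpers downcast_floats and nbytes are hypothetical, and whether save_optimizer skips non-float leaves in exactly this way is an assumption based on the parameter description.

```python
from array import array


def downcast_floats(tree: dict) -> dict:
    """Convert float64 ('d') leaves to float32 ('f'); pass other leaves through."""
    out = {}
    for name, leaf in tree.items():
        if isinstance(leaf, array) and leaf.typecode == "d":
            out[name] = array("f", leaf.tolist())
        else:
            out[name] = leaf
    return out


def nbytes(tree: dict) -> int:
    # Payload size of all array leaves: element count times bytes per element.
    return sum(len(v) * v.itemsize for v in tree.values() if isinstance(v, array))


opt_state = {"mu": array("d", [0.1] * 1000), "count": array("q", [42])}
small = downcast_floats(opt_state)
print(nbytes(opt_state), nbytes(small))  # 8008 4008
```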

save_state(save_directory: str | os.PathLike | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath, float_dtype: numpy.dtype | None = None, save_optimizer: bool = True, step: int | None = None)

Saves the entire EasyDeLState to a directory.

This includes saving the model parameters (using model.save_pretrained) and optionally the optimizer state.

Parameters
  • save_directory (tp.Union[str, os.PathLike]) – The directory to save the state to.

  • float_dtype (tp.Optional[jnp.dtype]) – Optional dtype to cast floating-point parameters to before saving. Defaults to None.

  • save_optimizer (bool) – If True, saves the optimizer state. Defaults to True.

shard_model(partition_rules: PartitionLike = None, mesh: Mesh | None = None) → Self

Shards the model parameters (graphstate and graphother) based on partition rules.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses model config rules. Defaults to None.

  • mesh (tp.Optional[Mesh], optional) – The JAX device mesh to shard across. If None, uses model’s mesh. Defaults to None.

Returns

A new state object with sharded graphstate and graphother.

Return type

EasyDeLState
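When partition_rules is a tuple of (regex, PartitionSpec) pairs, each parameter path has to be mapped to exactly one spec. A common convention, assumed here rather than confirmed for EasyDeL, is that the first regex found in the "/"-joined path wins, with a catch-all rule at the end. The rules, paths, and match_rule below are hypothetical, and plain tuples of axis names stand in for jax.sharding.PartitionSpec.

```python
import re
from typing import Sequence, Tuple

# Hypothetical rules: (regex searched in the "/"-joined parameter path, spec).
RULES: Sequence[Tuple[str, tuple]] = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"norm/scale", ()),
    (r".*", ()),  # catch-all: replicate anything not matched above
)


def match_rule(path: str, rules: Sequence[Tuple[str, tuple]] = RULES) -> tuple:
    """Return the spec of the first rule whose regex is found in the path."""
    for pattern, spec in rules:
        if re.search(pattern, path):
            return spec
    raise ValueError(f"No partition rule matches {path!r}")


print(match_rule("model/layers/0/self_attn/q_proj/kernel"))  # ('fsdp', 'tp')
print(match_rule("model/layers/0/input_norm/scale"))          # ()
```

Ordering matters under this convention: a broad pattern placed early would shadow the more specific rules after it.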

shard_optimizer_state(opt_state: tp.Any | None = None, partition_rules: PartitionLike = None) → tp.Any

Applies sharding to the optimizer state based on partition rules.

Parameters
  • opt_state (tp.Optional[tp.Any]) – The optimizer state pytree to shard. If None, uses self.opt_state. Defaults to None.

  • partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses rules from the model’s config. Defaults to None.

Returns

A new state object with the sharded opt_state.

Return type

EasyDeLState

Raises

ValueError – If optimizer state is not initialized (neither opt_state argument nor self.opt_state is available).

shard_state(partition_rules: PartitionLike = None, mesh: Mesh = None) → Self

Shards the entire state (model parameters and optimizer state) based on partition rules.

This is a convenience method that calls shard_model and shard_optimizer_state.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses rules from the model’s config. Defaults to None.

  • mesh (Mesh, optional) – The JAX device mesh to shard across. If None, uses model’s mesh. Defaults to None.

Returns

A new state object with both model and optimizer states sharded.

Return type

EasyDeLState

shard_with_shape(shape) → Self

Applies sharding constraints to the entire state based on a reference shape pytree.

This method takes a pytree shape which has the same structure as the EasyDeLState but contains sharding annotations (e.g., NamedSharding) instead of actual array data. It applies these shardings as constraints to the corresponding arrays in the current state.

Parameters

shape – A pytree with the same structure as self, containing sharding annotations.

Returns

A new state object with sharding constraints applied.

Return type

EasyDeLState
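The same-structure requirement is what makes a leaf-by-leaf pairing possible. In JAX this is typically done with jax.tree_util.tree_map over the state and the sharding tree, applying jax.lax.with_sharding_constraint to each pair; whether shard_with_shape does exactly that is not stated here. The stdlib sketch below only illustrates the pairing, with hypothetical map_parallel and tag functions (tag stands in for the constraint).

```python
from typing import Any, Callable


def map_parallel(fn: Callable[[Any, Any], Any], tree: Any, spec_tree: Any) -> Any:
    """Apply fn(leaf, spec) to corresponding leaves of two identically shaped trees."""
    if isinstance(tree, dict):
        if tree.keys() != spec_tree.keys():
            raise ValueError(f"Key mismatch: {sorted(tree)} vs {sorted(spec_tree)}")
        return {k: map_parallel(fn, tree[k], spec_tree[k]) for k in tree}
    if isinstance(tree, tuple) and hasattr(tree, "_fields"):  # NamedTuple, e.g. optax states
        return type(tree)(*(map_parallel(fn, a, b) for a, b in zip(tree, spec_tree)))
    if isinstance(tree, (list, tuple)):
        if len(tree) != len(spec_tree):
            raise ValueError("Length mismatch between tree and spec_tree.")
        return type(tree)(map_parallel(fn, a, b) for a, b in zip(tree, spec_tree))
    return fn(tree, spec_tree)  # leaf: pair it with its annotation


def tag(leaf: Any, spec: Any) -> tuple:
    # Hypothetical stand-in for jax.lax.with_sharding_constraint(leaf, spec).
    return (leaf, spec)


state = {"params": {"w": 1.0, "b": 0.0}, "step": 3}
shape = {"params": {"w": "sharded(fsdp)", "b": "replicated"}, "step": "replicated"}
result = map_parallel(tag, state, shape)
print(result)
```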

property shardings

Retrieves the sharding annotations (e.g., NamedSharding) for all components of the EasyDeLState pytree.

Returns

A pytree with the same structure as self, containing sharding annotations or None for components without sharding.

property size: int

Calculates the total size in bytes of the model parameters (graphstate) and the optimizer state (opt_state).

Returns

The total size in bytes.

Return type

int
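The byte count is the sum of element count times itemsize over every leaf of graphstate and opt_state. The sketch below applies that formula to hypothetical shapes (one 4096x4096 bfloat16 kernel and two float32 moment buffers of the same shape); tree_nbytes is an illustrative helper, and real optimizer states also carry small scalars such as step counters, which are omitted here.

```python
import math
from typing import Iterable, Tuple

# Each leaf is described as (shape, bytes_per_element), a stand-in for real arrays.
Leaf = Tuple[Tuple[int, ...], int]


def tree_nbytes(leaves: Iterable[Leaf]) -> int:
    """Sum prod(shape) * itemsize over all leaves."""
    return sum(math.prod(shape) * itemsize for shape, itemsize in leaves)


params = [((4096, 4096), 2)]                         # bfloat16 kernel: 2 bytes/element
opt_state = [((4096, 4096), 4), ((4096, 4096), 4)]   # two float32 moment buffers
total = tree_nbytes(params) + tree_nbytes(opt_state)
print(total)  # 167772160
```

That is 160 MiB, of which the two moment buffers account for 128 MiB.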

step: int | jax.Array
tx: optax.GradientTransformation