easydel.infra.base_state

class easydel.infra.base_state.EasyDeLState(step: int | jax.Array, graphdef: nn.GraphDef, graphstate: nn.GraphState, graphother: nn.GraphState, tx: optax.GradientTransformation, opt_state: tp.Optional[optax.OptState], apply_fn: tp.Optional[tp.Callable] = None)

Bases: PyTreeNode

Represents the state of an EasyDeL model during training or inference.

This class encapsulates the model’s parameters, optimizer state, training step, and potentially other metadata. It provides methods for applying gradients, managing sharding, saving, and loading the state.

step

The current training step count.

Type

int | jax.Array

graphdef

The definition of the model’s computation graph (its structure).

Type

nn.GraphDef

graphstate

The state of the model’s parameters.

Type

nn.GraphState

graphother

The state of non-parameter variables within the model.

Type

nn.GraphState

tx

The optimizer transformation (e.g., AdamW, SGD). Marked as a non-pytree node.

Type

optax.GradientTransformation

opt_state

The state of the optimizer (e.g., moments). Marked as a pytree node.

Type

tp.Optional[optax.OptState]

apply_fn

A function that applies the model (often model.__call__). It is typically not part of the state itself, but it can be attached to it.

Type

tp.Optional[tp.Callable]

apply_fn: tp.Optional[tp.Callable] = None
apply_gradients(*, grads)

Updates the model’s parameters and optimizer state based on calculated gradients.

Parameters

grads – A pytree with the same structure as self.graphstate, containing the gradients.

Returns

A new state object with the updated parameters (graphstate), optimizer state (opt_state), and incremented step count.

Return type

EasyDeLState

Raises

AssertionError – If opt_state or tx is not initialized.
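As a rough illustration of the contract described above, the sketch below models apply_gradients in plain Python. It is not EasyDeL’s implementation. ToyTx, sgd_momentum, and ToyState are hypothetical, parameters are a flat dict of floats rather than an nn.GraphState, and tx is an optax-style (init, update) pair. What the sketch does mirror is the documented behavior: it asserts that tx and opt_state exist, computes updates from the gradients, applies them, increments step, and returns a new object instead of mutating the old one.

```python
from dataclasses import dataclass, replace
from typing import Callable, Dict, NamedTuple, Optional, Tuple

Params = Dict[str, float]


class ToyTx(NamedTuple):
    """Hypothetical optax-style transformation: an init and an update function."""
    init: Callable[[Params], Params]
    update: Callable[[Params, Params, Params], Tuple[Params, Params]]


def sgd_momentum(lr: float, beta: float) -> ToyTx:
    def init(params: Params) -> Params:
        # One momentum buffer per parameter, same structure as the parameters.
        return {name: 0.0 for name in params}

    def update(grads: Params, state: Params, params: Params) -> Tuple[Params, Params]:
        new_state = {name: beta * state[name] + g for name, g in grads.items()}
        updates = {name: -lr * m for name, m in new_state.items()}
        return updates, new_state

    return ToyTx(init, update)


@dataclass(frozen=True)
class ToyState:
    step: int
    graphstate: Params
    tx: Optional[ToyTx] = None
    opt_state: Optional[Params] = None

    def apply_gradients(self, *, grads: Params) -> "ToyState":
        # Same precondition as the documented AssertionError.
        assert self.tx is not None and self.opt_state is not None, "optimizer not initialized"
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.graphstate)
        new_params = {name: p + updates[name] for name, p in self.graphstate.items()}
        # Build a new state; `self` is left unchanged.
        return replace(self, step=self.step + 1, graphstate=new_params, opt_state=new_opt_state)


tx = sgd_momentum(lr=0.1, beta=0.9)
state = ToyState(step=0, graphstate={"w": 1.0}, tx=tx, opt_state=tx.init({"w": 1.0}))
new_state = state.apply_gradients(grads={"w": 0.5})
print(new_state.step, round(new_state.graphstate["w"], 6))  # 1 0.95
```

Because EasyDeLState is a PyTreeNode, each update likewise produces a fresh object. This keeps the state usable as an input to, and output of, jax.jit-compiled training steps.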

classmethod create(*, step: tp.Optional[int] = None, graphdef: tp.Optional[nn.GraphDef] = None, graphstate: tp.Optional[nn.GraphState] = None, graphother: tp.Optional[nn.GraphState] = None, model: tp.Optional[nn.Module] = None, tx: tp.Optional[optax.GradientTransformation] = None, opt_state: tp.Optional[optax.OptState] = None, init_opt_state: bool = False) → EasyDeLState

Creates a new EasyDeLState instance.

This class method provides a flexible way to initialize the state, either from an existing nn.Module or by providing the graph components (graphdef, graphstate, graphother) directly. It also handles optimizer state initialization.

Parameters
  • step (tp.Optional[int]) – The initial training step. Defaults to 0.

  • graphdef (tp.Optional[nn.GraphDef]) – The model’s graph definition.

  • graphstate (tp.Optional[nn.GraphState]) – The model’s parameter state.

  • graphother (tp.Optional[nn.GraphState]) – The model’s non-parameter state.

  • model (tp.Optional[nn.Module]) – An EasyDeL module instance. If provided, graphdef, graphstate, and graphother are derived from it. Cannot be provided together with the graph components.

  • tx (tp.Optional[optax.GradientTransformation]) – The optimizer transformation.

  • opt_state (tp.Optional[optax.OptState]) – The initial optimizer state. Cannot be provided if init_opt_state is True.

  • init_opt_state (bool) – If True, initializes the optimizer state using tx.init(graphstate). Requires tx to be provided. Defaults to False.

Returns

A new instance of the state.

Return type

EasyDeLState

Raises
  • ValueError – If model and graph components are provided simultaneously.

  • ValueError – If only some of the graph components are provided.

  • ValueError – If init_opt_state is True and opt_state is also provided.

  • ValueError – If init_opt_state is True but tx is not provided.
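The four ValueError conditions above amount to a small argument check. The sketch below restates them in plain Python. validate_create_args is a hypothetical helper, and the order of the checks and the exact messages in EasyDeL may differ.

```python
from typing import Any, Optional


def validate_create_args(
    *,
    graphdef: Optional[Any] = None,
    graphstate: Optional[Any] = None,
    graphother: Optional[Any] = None,
    model: Optional[Any] = None,
    tx: Optional[Any] = None,
    opt_state: Optional[Any] = None,
    init_opt_state: bool = False,
) -> None:
    """Hypothetical restatement of the documented create() argument rules."""
    provided = [c is not None for c in (graphdef, graphstate, graphother)]
    if model is not None and any(provided):
        raise ValueError("pass either `model` or the graph components, not both")
    if any(provided) and not all(provided):
        raise ValueError("graphdef, graphstate and graphother must be passed together")
    if init_opt_state and opt_state is not None:
        raise ValueError("`opt_state` cannot be passed when `init_opt_state=True`")
    if init_opt_state and tx is None:
        raise ValueError("`init_opt_state=True` requires `tx`")


validate_create_args(graphdef="g", graphstate={}, graphother={})  # OK: all three components
try:
    validate_create_args(graphdef="g", graphstate={})  # only two of three
except ValueError as err:
    print(err)  # graphdef, graphstate and graphother must be passed together
```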

gather_model(partition_rules: Any = None, mesh: Optional[Any] = None) → EasyDeLState

Gathers the model parameters (graphstate and graphother) from distributed devices.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules used for the original sharding. If None, uses the model config’s rules. Defaults to None.

  • mesh (tp.Optional[Mesh], optional) – The JAX device mesh to gather from. If None, uses the model’s mesh. Defaults to None.

Returns

A new state object with gathered graphstate and graphother.

Return type

EasyDeLState

gather_optimizer_state(partition_rules=None)

Gathers the optimizer state from potentially distributed devices.

Parameters

partition_rules (PartitionLike, optional) – Partitioning rules used to determine how the state was sharded. If None, uses the rules from the model’s config. Defaults to None.

Returns

A new state object with the gathered opt_state.

Return type

EasyDeLState

Raises

AssertionError – If opt_state is not initialized.

gather_state()

Gathers the entire state (model parameters and optimizer state) from distributed devices.

This is a convenience method that calls gather_model and gather_optimizer_state.

Returns

A new state object with both model and optimizer states gathered.

Return type

EasyDeLState
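Conceptually, sharding splits each array across the devices of a mesh, and gathering reassembles the full value. The following sketch models a single mesh axis that shards dimension 0 of a matrix, using plain Python lists. shard_rows and gather_rows are hypothetical helpers. EasyDeL’s own methods operate on jax.Array leaves using the mesh and partition rules rather than copying lists, and the even-split requirement here is a simplification of the sketch.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def shard_rows(matrix: Matrix, num_devices: int) -> List[Matrix]:
    """Split rows evenly across devices, like sharding dim 0 over one mesh axis."""
    if len(matrix) % num_devices != 0:
        raise ValueError("row count must be divisible by the number of devices")
    per_device = len(matrix) // num_devices
    return [matrix[i * per_device:(i + 1) * per_device] for i in range(num_devices)]


def gather_rows(shards: Sequence[Matrix]) -> Matrix:
    """Concatenate the per-device shards back into the full matrix."""
    return [row for shard in shards for row in shard]


full = [[float(r * 2 + c) for c in range(2)] for r in range(4)]
shards = shard_rows(full, num_devices=2)
print([len(s) for s in shards])  # [2, 2]
assert gather_rows(shards) == full  # a shard/gather round trip is lossless
```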

graphdef: nn.GraphDef
graphother: nn.GraphState
graphstate: nn.GraphState
init_tx(tx: GradientTransformation, partition_rules: Any = None) → EasyDeLState

Initializes the optimizer state (opt_state) for the current graphstate using the provided optimizer transformation (tx). It automatically handles sharding based on the model’s partition rules.

Parameters
  • tx (optax.GradientTransformation) – The optimizer transformation to initialize with.

  • partition_rules (PartitionLike, optional) – Partitioning rules for the optimizer state. If None, uses the rules from the associated model’s config. Defaults to None.

Returns

A new state object with the initialized and potentially sharded opt_state and the provided tx.

Return type

EasyDeLState
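Why can the model’s partition rules also shard the optimizer state? The sketch below offers one plausible reading, under stated assumptions. An Adam-style tx.init(graphstate) returns moment trees with the same structure and shapes as the parameters (optax’s Adam state, for example, holds count, mu, and nu). As a result, each moment leaf’s path ends with its parameter’s path, so a regex rule written for the parameter also matches the moment. The helpers tree_paths and adam_like_init are hypothetical and work on nested dicts of shapes instead of real arrays.

```python
import re
from typing import Dict, Tuple, Union

Shape = Tuple[int, ...]
Tree = Dict[str, Union["Tree", Shape]]


def tree_paths(tree: Tree, prefix: str = "") -> Dict[str, Shape]:
    """Flatten a nested dict of shapes into {"a/b/c": shape}."""
    out: Dict[str, Shape] = {}
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            out.update(tree_paths(value, path))
        else:
            out[path] = value
    return out


def adam_like_init(params: Tree) -> Tree:
    """Hypothetical Adam-style init: a scalar count plus two moment trees shaped like params."""
    return {"count": (), "mu": params, "nu": params}


params = {"attn": {"q_proj": {"kernel": (4096, 4096)}}, "norm": {"scale": (4096,)}}
opt_paths = tree_paths(adam_like_init(params))
kernel_rule = re.compile(r"q_proj/kernel")  # a rule written with the parameter in mind
print(sorted(p for p in opt_paths if kernel_rule.search(p)))
# ['mu/attn/q_proj/kernel', 'nu/attn/q_proj/kernel']
print(opt_paths["mu/attn/q_proj/kernel"])  # (4096, 4096)
```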

load_optimizer(load_directory: Union[str, PathLike])

Loads the optimizer state from saved files.

Reads the optimizer state structure from a pickle file (OPTIMIZER_STRUCT_NAME) and the tensor data from a SafeTensors file (OPTIMIZER_NAME) within the specified directory.

Parameters

load_directory (tp.Union[str, os.PathLike]) – The directory containing the saved optimizer state files.

Returns

A new state object with the loaded opt_state.

Return type

EasyDeLState

Raises
  • FileNotFoundError – If the required optimizer files are not found.

  • Exception – If any error occurs during loading or deserialization.
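The two-file layout described above can be sketched with the standard library: the structure goes into a pickle, and the tensor data goes into a flat name-to-tensor file. This is a hypothetical stand-in, not EasyDeL’s code. JSON replaces SafeTensors, the file names are placeholders for OPTIMIZER_STRUCT_NAME and OPTIMIZER_NAME, and split_tree, join_tree, save_opt_state, and load_opt_state are illustrative helpers.

```python
import json
import os
import pickle
import tempfile
from typing import Any, Dict, Tuple

STRUCT_FILE = "opt_struct.pkl"    # placeholder for OPTIMIZER_STRUCT_NAME
TENSOR_FILE = "opt_tensors.json"  # placeholder for OPTIMIZER_NAME (SafeTensors in EasyDeL)


def split_tree(tree: Any, prefix: str = "") -> Tuple[Any, Dict[str, Any]]:
    """Replace each leaf with its flat key; collect leaves into a flat name -> data map."""
    if isinstance(tree, dict):
        structure, leaves = {}, {}
        for key, value in tree.items():
            sub_struct, sub_leaves = split_tree(value, f"{prefix}.{key}" if prefix else key)
            structure[key] = sub_struct
            leaves.update(sub_leaves)
        return structure, leaves
    return prefix, {prefix: tree}


def join_tree(structure: Any, leaves: Dict[str, Any]) -> Any:
    """Inverse of split_tree: look each leaf key up in the flat map."""
    if isinstance(structure, dict):
        return {key: join_tree(value, leaves) for key, value in structure.items()}
    return leaves[structure]


def save_opt_state(opt_state: Any, directory: str) -> None:
    structure, leaves = split_tree(opt_state)
    with open(os.path.join(directory, STRUCT_FILE), "wb") as f:
        pickle.dump(structure, f)
    with open(os.path.join(directory, TENSOR_FILE), "w") as f:
        json.dump(leaves, f)


def load_opt_state(directory: str) -> Any:
    struct_path = os.path.join(directory, STRUCT_FILE)
    tensor_path = os.path.join(directory, TENSOR_FILE)
    for path in (struct_path, tensor_path):
        if not os.path.exists(path):
            raise FileNotFoundError(path)
    with open(struct_path, "rb") as f:
        structure = pickle.load(f)  # only unpickle files from trusted sources
    with open(tensor_path) as f:
        leaves = json.load(f)
    return join_tree(structure, leaves)


opt_state = {"count": [3.0], "mu": {"w": [0.1, 0.2]}, "nu": {"w": [0.01, 0.04]}}
with tempfile.TemporaryDirectory() as tmp:
    save_opt_state(opt_state, tmp)
    restored = load_opt_state(tmp)
print(restored == opt_state)  # True
```

Keeping the structure separate lets the tensor file remain a flat name-to-array map, which is the layout SafeTensors stores. Unpickling can execute arbitrary code, so optimizer structures should only be loaded from trusted directories.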

classmethod load_state(load_directory: Union[str, PathLike], device: Optional[Device] = 'cpu', dtype: dtype = jax.numpy.float32, param_dtype: dtype = jax.numpy.float32, precision: Optional[Precision] = None, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, shard_fns: Optional[Union[Mapping[tuple, Callable], dict]] = None, backend: Optional[Any] = None, platform: Optional[Any] = None, config_kwargs: Optional[Any] = None, model_task: TaskType = TaskType.AUTO_BIND, auto_shard_model: bool = False, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec], ...]] = None, quantization_platform: Optional[Any] = None, quantization_method: Optional[Any] = None, quantization_block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: bool = True, **kwargs)

Loads an EasyDeLState from a saved checkpoint directory.

This class method reconstructs the model configuration, loads the model parameters, and optionally loads the optimizer state from files saved previously using save_state. It handles various configurations for device placement, data types, sharding, and quantization.

Parameters
  • load_directory – Path to the directory containing the saved state (configuration, model weights, and possibly optimizer state).

  • device – The JAX device (e.g., ‘cpu’, ‘gpu’, ‘tpu’) to load the model onto. Defaults to ‘cpu’.

  • dtype – The data type to use for computation (e.g., jax.numpy.float32). Defaults to jax.numpy.float32.

  • param_dtype – The data type for the model parameters (e.g., jax.numpy.bfloat16). Defaults to jax.numpy.float32.

  • precision – The JAX precision level (e.g., jax.lax.Precision.HIGHEST). Defaults to None.

  • sharding_axis_dims – A sequence defining the dimensions of the device mesh for sharding (e.g., (1, -1, 1, 1)). Defaults to (1, -1, 1, 1).

  • sharding_dcn_axis_dims – Optional sequence of mesh dimensions across the data-center network (DCN), i.e., between hosts or slices in a multi-slice setup. Defaults to None.

  • sharding_axis_names – Names corresponding to the sharding axes (e.g., ("dp", "fsdp", "tp", "sp")). Defaults to ("dp", "fsdp", "tp", "sp").

  • partition_axis – Configuration object for partitioning specific axes. Defaults to None.

  • shard_attention_computation – If True, shards the attention computation across devices. Defaults to True.

  • shard_fns – Optional mapping from parameter path tuples to custom sharding functions. Defaults to None.

  • backend – The backend framework to use (e.g., EasyDeLBackends.JAX). Defaults to None (auto-detected).

  • platform – The hardware platform (e.g., EasyDeLPlatforms.TPU). Defaults to None (auto-detected).

  • config_kwargs – Optional dictionary of keyword arguments that override values in the loaded model configuration. Defaults to None.

  • model_task – The task type for the model (e.g., TaskType.CAUSAL_LM). Defaults to TaskType.AUTO_BIND.

  • auto_shard_model – If True, automatically shards the loaded model and optimizer state according to the provided sharding configuration. Defaults to False.

  • partition_rules – Optional tuple of (regex, PartitionSpec) partition rules that explicitly define sharding. Defaults to None (uses the model config).

  • quantization_platform – Platform for quantization (e.g., EasyDeLPlatforms.TPU). Defaults to None.

  • quantization_method – Quantization method (e.g., EasyDeLQuantizationMethods.AQT). Defaults to None.

  • quantization_block_size – Block size for quantization methods such as GPTQ. Defaults to 128.

  • quantization_pattern – Regex pattern matching the names of tensors to quantize. Defaults to None.

  • quantize_tensors – If True, applies quantization to the loaded tensors. Defaults to True.

  • verbose – If True, logs detailed information during loading. Defaults to True.

  • **kwargs – Additional keyword arguments passed directly to the underlying EasyDeLBaseModule.from_pretrained method.

Returns

An EasyDeLState instance containing the loaded model, optimizer state (if found and loaded), and associated configuration.

Raises
  • FileNotFoundError – If load_directory or essential files within it (such as the configuration or model weights) are not found.

  • ValueError – If the provided arguments or the loaded configuration are inconsistent.

Note: other exceptions raised by underlying calls, such as AutoEasyDeLConfig or EasyDeLBaseModule.from_pretrained, may also propagate.
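The default sharding_axis_dims of (1, -1, 1, 1), paired with the axis names ('dp', 'fsdp', 'tp', 'sp'), places every device on the fsdp axis. The sketch below assumes that -1 follows NumPy reshape semantics, meaning it absorbs whatever device count remains after the fixed axes. resolve_axis_dims and describe_mesh are hypothetical helpers, and the device count of 8 is an example value.

```python
import math
from typing import Dict, Sequence, Tuple


def resolve_axis_dims(dims: Sequence[int], num_devices: int) -> Tuple[int, ...]:
    """Replace a single -1 with whatever factor makes the product equal num_devices."""
    if list(dims).count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        if num_devices % known != 0:
            raise ValueError(f"{num_devices} devices cannot be split by fixed axes of size {known}")
        return tuple(num_devices // known if d == -1 else d for d in dims)
    if known != num_devices:
        raise ValueError(f"mesh size {known} != device count {num_devices}")
    return tuple(dims)


def describe_mesh(
    dims: Sequence[int],
    names: Sequence[str] = ("dp", "fsdp", "tp", "sp"),
    num_devices: int = 8,
) -> Dict[str, int]:
    return dict(zip(names, resolve_axis_dims(dims, num_devices)))


print(describe_mesh((1, -1, 1, 1)))  # {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1}
print(describe_mesh((1, -1, 2, 1)))  # {'dp': 1, 'fsdp': 4, 'tp': 2, 'sp': 1}
```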

merge(tree) → Any

Merges a given state tree (usually parameters) with the graph definition and other state components to reconstruct the full model module.

Parameters

tree – The pytree (e.g., nn.GraphState) containing the parameters to merge.

Returns

The reconstructed model module.

Return type

EasyDeLBaseModule

merge_to_state(tree) → EasyDeLState

Creates a new EasyDeLState by replacing the current graphstate with the provided tree.

Parameters

tree – The pytree (e.g., nn.GraphState) containing the new parameters.

Returns

A new state object with the updated graphstate.

Return type

EasyDeLState
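The split into graphdef, graphstate, and graphother corresponds to the split/merge pattern that Flax NNX exposes as nnx.split and nnx.merge. The sketch below imitates that pattern with plain dicts so that the roles of merge, merge_to_state, and the model property are easy to see. split_module, merge_module, and SplitState are hypothetical, and variables are tagged "param" or another kind in place of NNX variable types such as nnx.Param.

```python
from dataclasses import dataclass, replace
from typing import Dict, Tuple

# Toy "module": variable name -> (kind, value); "param" plays the role of nnx.Param.
Module = Dict[str, Tuple[str, float]]
GraphDef = Tuple[Tuple[str, str], ...]


def split_module(module: Module) -> Tuple[GraphDef, Dict[str, float], Dict[str, float]]:
    graphdef = tuple((name, kind) for name, (kind, _) in module.items())
    graphstate = {name: v for name, (kind, v) in module.items() if kind == "param"}
    graphother = {name: v for name, (kind, v) in module.items() if kind != "param"}
    return graphdef, graphstate, graphother


def merge_module(graphdef: GraphDef, graphstate: Dict[str, float], graphother: Dict[str, float]) -> Module:
    return {
        name: (kind, graphstate[name] if kind == "param" else graphother[name])
        for name, kind in graphdef
    }


@dataclass(frozen=True)
class SplitState:
    graphdef: GraphDef
    graphstate: Dict[str, float]
    graphother: Dict[str, float]

    def merge(self, tree: Dict[str, float]) -> Module:
        # Rebuild a module using `tree` as the parameters; the state itself is unchanged.
        return merge_module(self.graphdef, tree, self.graphother)

    def merge_to_state(self, tree: Dict[str, float]) -> "SplitState":
        # Stay in split form; only swap the parameter tree.
        return replace(self, graphstate=tree)

    @property
    def model(self) -> Module:
        return merge_module(self.graphdef, self.graphstate, self.graphother)


module = {"w": ("param", 1.0), "b": ("param", 0.0), "running_mean": ("batch_stat", 0.5)}
state = SplitState(*split_module(module))
updated = state.merge_to_state({"w": 2.0, "b": 0.1})
print(updated.model["w"], updated.model["running_mean"])  # ('param', 2.0) ('batch_stat', 0.5)
```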

property model: Any

Reconstructs and returns the full EasyDeL model module from the state components.

Returns

The model module instance.

Return type

EasyDeLBaseModule

opt_state: tp.Optional[optax.OptState]
replace(**updates)

Returns a new object replacing the specified fields with new values.
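EasyDeLState derives from PyTreeNode (presumably flax.struct.PyTreeNode, a frozen dataclass). Because of this, replace builds a modified copy instead of mutating the state. The minimal analogue below uses only the standard library, and MiniState is hypothetical.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class MiniState:
    step: int
    learning_rate: float

    def replace(self, **updates) -> "MiniState":
        # Like flax.struct.PyTreeNode.replace: a copy with the given fields swapped.
        return dataclasses.replace(self, **updates)


s0 = MiniState(step=0, learning_rate=1e-3)
s1 = s0.replace(step=s0.step + 1)
print(s0.step, s1.step, s1.learning_rate)  # 0 1 0.001

try:
    s0.step = 5  # frozen: in-place assignment is rejected
except dataclasses.FrozenInstanceError as err:
    print(type(err).__name__)  # FrozenInstanceError
```

With EasyDeLState, the equivalent call is state.replace(step=0), for example to reset the step counter.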

save_state(save_directory: Union[str, PathLike], float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, save_optimizer: bool = True, enable: Optional[bool] = None)

Saves the entire EasyDeLState to a directory.

This includes saving the model parameters (using model.save_pretrained) and, optionally, the optimizer state.

Parameters
  • save_directory (tp.Union[str, os.PathLike]) – The directory to save the state to.

  • float_dtype (tp.Optional[jax.numpy.dtype]) – Optional dtype to cast floating-point parameters to before saving. Defaults to None.

  • verbose (bool) – If True, logs information during saving. Defaults to True.

  • mismatch_allowed (bool) – Passed to model.save_pretrained; allows saving even if the model structure differs slightly from what is expected. Defaults to True.

  • save_optimizer (bool) – If True, saves the optimizer state. Defaults to True.

  • enable (tp.Optional[bool]) – If set, controls whether saving happens (True) or is skipped (False). If None, saving typically occurs only on JAX process index 0. Defaults to None.
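The enable rule reduces to a small decision, restated below as a hypothetical helper named should_save. In a real multi-host JAX job, the process index would come from jax.process_index(). Here it is passed in explicitly so that the behavior can be simulated for several hosts.

```python
from typing import List, Optional


def should_save(enable: Optional[bool], process_index: int) -> bool:
    """Hypothetical restatement of the `enable` rule for save_state."""
    if enable is not None:
        return enable  # an explicit True/False always wins
    return process_index == 0  # default: only the first process writes


def simulate(enable: Optional[bool], num_processes: int = 4) -> List[bool]:
    return [should_save(enable, i) for i in range(num_processes)]


print(simulate(None))   # [True, False, False, False]
print(simulate(False))  # [False, False, False, False]
```

Letting a single process write by default avoids several hosts writing the same checkpoint files concurrently.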

shard_model(partition_rules: Any = None, mesh: Optional[Any] = None) → EasyDeLState

Shards the model parameters (graphstate and graphother) based on partition rules.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses the model config’s rules. Defaults to None.

  • mesh (tp.Optional[Mesh], optional) – The JAX device mesh to shard across. If None, uses the model’s mesh. Defaults to None.

Returns

A new state object with sharded graphstate and graphother.

Return type

EasyDeLState
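Partition rules, as described for load_state, are (regex, PartitionSpec) pairs matched against parameter paths. The sketch below assumes first-match-wins semantics and represents a PartitionSpec as a plain tuple of mesh-axis names. An empty tuple means fully replicated, like PartitionSpec(). assign_specs is a hypothetical helper, and the rule set and parameter paths are made up for the example.

```python
import re
from typing import Dict, Optional, Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec: one mesh-axis name (or None) per array dimension.
Spec = Tuple[Optional[str], ...]


def assign_specs(rules: Sequence[Tuple[str, Spec]], paths: Sequence[str]) -> Dict[str, Spec]:
    """Give each parameter path the spec of the first rule whose regex matches it."""
    specs: Dict[str, Spec] = {}
    for path in paths:
        for pattern, spec in rules:
            if re.search(pattern, path):
                specs[path] = spec
                break
        else:  # no rule matched this path
            raise ValueError(f"no partition rule matches {path!r}")
    return specs


rules = (
    (r"(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"o_proj/kernel", ("tp", "fsdp")),
    (r".*", ()),  # catch-all: fully replicated, like PartitionSpec()
)
paths = ["layers/0/attn/q_proj/kernel", "layers/0/attn/o_proj/kernel", "layers/0/norm/scale"]
for path, spec in assign_specs(rules, paths).items():
    print(path, spec)
# layers/0/attn/q_proj/kernel ('fsdp', 'tp')
# layers/0/attn/o_proj/kernel ('tp', 'fsdp')
# layers/0/norm/scale ()
```

Under this assumption, rule order matters: placing the catch-all first would replicate every parameter.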

shard_optimizer_state(opt_state: Optional[Any] = None, partition_rules: Any = None) → Any

Applies sharding to the optimizer state based on partition rules.

Parameters
  • opt_state (tp.Optional[tp.Any]) – The optimizer state pytree to shard. If None, uses self.opt_state. Defaults to None.

  • partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses the rules from the model’s config. Defaults to None.

Returns

A new state object with the sharded opt_state.

Return type

EasyDeLState

Raises

ValueError – If the optimizer state is not initialized (neither the opt_state argument nor self.opt_state is available).

shard_state(partition_rules: Any = None) → EasyDeLState

Shards the entire state (model parameters and optimizer state) based on partition rules.

This is a convenience method that calls shard_model and shard_optimizer_state.

Parameters

partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses the rules from the model’s config. Defaults to None.

Returns

A new state object with both model and optimizer states sharded.

Return type

EasyDeLState

shard_with_shape(shape) → EasyDeLState

Applies sharding constraints to the entire state based on a reference shape pytree.

This method takes a pytree shape that has the same structure as the EasyDeLState but contains sharding annotations (e.g., NamedSharding) instead of actual array data. It applies these shardings as constraints to the corresponding arrays in the current state.

Parameters

shape – A pytree with the same structure as self, containing sharding annotations.

Returns

A new state object with sharding constraints applied.

Return type

EasyDeLState
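Applying a shape pytree amounts to walking the state and the annotation tree together and handling each pair of leaves. The sketch below shows that pairing in plain Python. It assumes nested dicts for both trees and uses a hypothetical constrain function that tags a leaf with its annotation, or leaves it unchanged when the annotation is None (mirroring how the shardings property reports components without sharding). In JAX, pairing leaves across trees in this way is what jax.tree_util.tree_map(f, tree, *rest) provides.

```python
from typing import Any, Callable, Optional


def map_with_shape(fn: Callable[[Any, Any], Any], tree: Any, shape: Any) -> Any:
    """Walk `tree` and `shape` together, calling fn(leaf, annotation) on matching leaves."""
    if isinstance(tree, dict):
        if not isinstance(shape, dict) or tree.keys() != shape.keys():
            raise ValueError("`shape` must have the same structure as the state")
        return {key: map_with_shape(fn, tree[key], shape[key]) for key in tree}
    return fn(tree, shape)


def constrain(leaf: Any, sharding: Optional[str]) -> Any:
    # Hypothetical stand-in for a sharding constraint; None leaves the leaf unchanged.
    return leaf if sharding is None else (leaf, sharding)


state = {"step": 3, "graphstate": {"w": [1.0, 2.0]}, "opt_state": {"mu": [0.0, 0.0]}}
shape = {"step": None, "graphstate": {"w": "fsdp"}, "opt_state": {"mu": "fsdp"}}
print(map_with_shape(constrain, state, shape))
# {'step': 3, 'graphstate': {'w': ([1.0, 2.0], 'fsdp')}, 'opt_state': {'mu': ([0.0, 0.0], 'fsdp')}}
```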

property shardings

Retrieves the sharding annotations (e.g., NamedSharding) for all components of the EasyDeLState pytree.

Returns

A pytree with the same structure as self, containing sharding annotations or None for components without sharding.

property size: int

Calculates the total size in bytes of the model parameters (graphstate) and the optimizer state (opt_state).

Returns

The total size in bytes.

Return type

int
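The size property sums element count times bytes per element over the leaves of graphstate and opt_state. The worked example below uses a hypothetical stand-alone helper that describes leaves as (shape, dtype) pairs. For real jax.Array leaves, the same per-leaf quantity is available as x.nbytes. The dtypes chosen (a bfloat16 weight with float32 Adam-style moments) are illustrative assumptions.

```python
import math
from typing import Dict, Iterable, Tuple

# Bytes per element for a few common dtypes.
ITEMSIZE: Dict[str, int] = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}


def tree_nbytes(leaves: Iterable[Tuple[Tuple[int, ...], str]]) -> int:
    """Sum of (number of elements * bytes per element) over (shape, dtype) leaves."""
    return sum(math.prod(shape) * ITEMSIZE[dtype] for shape, dtype in leaves)


params = [((4096, 4096), "bfloat16")]
opt_state = [((4096, 4096), "float32"), ((4096, 4096), "float32")]  # e.g. mu and nu

param_bytes = tree_nbytes(params)   # 16,777,216 elements * 2 bytes
opt_bytes = tree_nbytes(opt_state)  # 2 * 16,777,216 elements * 4 bytes
print(param_bytes, opt_bytes, (param_bytes + opt_bytes) // 2**20)  # 33554432 134217728 160
```

For this single 4096 × 4096 weight, the optimizer state is four times the size of the parameter, which illustrates how opt_state can dominate the total.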

step: int | jax.Array
tx: optax.GradientTransformation