easydel.infra.base_state#
- class easydel.infra.base_state.EasyDeLState(step: int | jax.Array, graphdef: nn.GraphDef, graphstate: nn.GraphState, graphother: nn.GraphState, tx: optax.GradientTransformation, opt_state: tp.Optional[optax.OptState], apply_fn: tp.Optional[tp.Callable] = None)[source]#
Bases:
PyTreeNode
Represents the state of an EasyDeL model during training or inference.
This class encapsulates the model's parameters, optimizer state, training step, and potentially other metadata. It provides methods for applying gradients, managing sharding, saving, and loading the state.
- graphdef#
The definition of the model's computation graph (structure).
- Type
nn.GraphDef
- graphstate#
The state of the model's parameters.
- Type
nn.GraphState
- graphother#
The state of non-parameter variables within the model.
- Type
nn.GraphState
- tx#
The optimizer transformation (e.g., AdamW, SGD). Marked as a non-pytree node.
- Type
optax.GradientTransformation
- opt_state#
The state of the optimizer (e.g., moments). Marked as a pytree node.
- Type
tp.Optional[optax.OptState]
- apply_fn#
A function to apply the model (often model.__call__). Typically not directly part of the state but can be associated.
- Type
tp.Optional[tp.Callable]
- apply_fn: tp.Optional[tp.Callable] = None#
- apply_gradients(*, grads)[source]#
Updates the model's parameters and optimizer state based on calculated gradients.
- Parameters
grads - A pytree matching the structure of self.graphstate containing the gradients.
- Returns
A new state object with the updated parameters (graphstate), optimizer state (opt_state), and incremented step count.
- Return type
- Raises
AssertionError - If opt_state or tx is not initialized.
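The update described above is functional: the existing state is not mutated, and a new state is returned. A minimal sketch of that pattern, using a hypothetical `ToyState` dataclass and plain SGD on a dict of floats as stand-ins for `graphstate` and `tx`. This is an illustration of the idea under those assumptions, not EasyDeL's implementation:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for a training state; not the EasyDeL class."""

    step: int
    params: dict          # plays the role of graphstate
    learning_rate: float  # plays the role of tx (plain SGD here)

    def apply_gradients(self, *, grads: dict) -> "ToyState":
        # grads must mirror the structure of params, as the docs require.
        assert grads.keys() == self.params.keys(), "grads must match params"
        new_params = {
            name: value - self.learning_rate * grads[name]
            for name, value in self.params.items()
        }
        # Return a new object; the original state is left untouched.
        return replace(self, step=self.step + 1, params=new_params)


state = ToyState(step=0, params={"w": 1.0, "b": 0.5}, learning_rate=0.1)
new_state = state.apply_gradients(grads={"w": 2.0, "b": -1.0})
```

After the call, `state.step` is still 0 and `new_state.step` is 1, which is the property that makes such a state safe to pass through jitted training steps.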
- classmethod create(*, step: tp.Optional[int] = None, graphdef: tp.Optional[nn.GraphDef] = None, graphstate: tp.Optional[nn.GraphState] = None, graphother: tp.Optional[nn.GraphState] = None, model: tp.Optional[nn.Module] = None, tx: tp.Optional[optax.GradientTransformation] = None, opt_state: tp.Optional[optax.OptState] = None, init_opt_state: bool = False) EasyDeLState[source]#
Creates a new EasyDeLState instance.
This class method provides a flexible way to initialize the state, either from an existing nn.Module or by providing the graph components (graphdef, graphstate, graphother) directly. It also handles optimizer state initialization.
- Parameters
step (tp.Optional[int]) - The initial training step. Defaults to 0.
graphdef (tp.Optional[nn.GraphDef]) - The model's graph definition.
graphstate (tp.Optional[nn.GraphState]) - The model's parameter state.
graphother (tp.Optional[nn.GraphState]) - The model's non-parameter state.
model (tp.Optional[nn.Module]) - An EasyDeL module instance. If provided, graphdef, graphstate, and graphother are derived from it. Cannot be provided simultaneously with graph components.
tx (tp.Optional[optax.GradientTransformation]) - The optimizer transformation.
opt_state (tp.Optional[optax.OptState]) - The initial optimizer state. Cannot be provided if init_opt_state is True.
init_opt_state (bool) - If True, initializes the optimizer state using tx.init(graphstate). Requires tx to be provided. Defaults to False.
- Returns
A new instance of the state.
- Return type
- Raises
ValueError - If model and graph components are provided simultaneously.
ValueError - If graph components are provided partially.
ValueError - If init_opt_state is True and opt_state is also provided.
ValueError - If init_opt_state is True but tx is not provided.
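The four documented error conditions can be restated as a small validation routine. The sketch below uses hypothetical helpers `check_create_args` and `is_valid`; it only mirrors the rules listed above and is not taken from EasyDeL's source. The case where neither a model nor graph components are given is not covered by the documented errors, so the sketch accepts it.

```python
from typing import Any, Optional


def check_create_args(
    *,
    graphdef: Optional[Any] = None,
    graphstate: Optional[Any] = None,
    graphother: Optional[Any] = None,
    model: Optional[Any] = None,
    tx: Optional[Any] = None,
    opt_state: Optional[Any] = None,
    init_opt_state: bool = False,
) -> None:
    """Hypothetical helper restating the documented ValueError conditions."""
    given = [c is not None for c in (graphdef, graphstate, graphother)]
    if model is not None and any(given):
        raise ValueError("pass either model or graph components, not both")
    if any(given) and not all(given):
        raise ValueError("graph components must be provided together")
    if init_opt_state and opt_state is not None:
        raise ValueError("opt_state conflicts with init_opt_state=True")
    if init_opt_state and tx is None:
        raise ValueError("init_opt_state=True requires tx")


def is_valid(**kwargs: Any) -> bool:
    """Return True when the argument combination passes every check."""
    try:
        check_create_args(**kwargs)
    except ValueError:
        return False
    return True


# object() stands in for a real module, optimizer, or graph component.
ok_from_model = is_valid(model=object(), tx=object(), init_opt_state=True)
bad_mixed = is_valid(model=object(), graphdef=object())
bad_partial = is_valid(graphdef=object(), graphstate=object())
bad_no_tx = is_valid(model=object(), init_opt_state=True)
```

Here `ok_from_model` is True and the other three results are False, one for each of the mixed, partial, and missing-`tx` cases.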
- gather_model(partition_rules: Any = None, mesh: Optional[Any] = None) EasyDeLState[source]#
Gathers the model parameters (graphstate and graphother) from distributed devices.
- Parameters
partition_rules (PartitionLike, optional) - Partitioning rules used for the original sharding. If None, uses model config rules. Defaults to None.
mesh (tp.Optional[Mesh], optional) - The JAX device mesh to gather from. If None, uses the model's mesh. Defaults to None.
- Returns
A new state object with gathered graphstate and graphother.
- Return type
- gather_optimizer_state(partition_rules=None)[source]#
Gathers the optimizer state from potentially distributed devices.
- Parameters
partition_rules (PartitionLike, optional) - Partitioning rules used to determine how the state was sharded. If None, uses rules from the model's config. Defaults to None.
- Returns
A new state object with the gathered opt_state.
- Return type
- Raises
AssertionError - If opt_state is not initialized.
- gather_state()[source]#
Gathers the entire state (model parameters and optimizer state) from distributed devices.
This is a convenience method that calls gather_model and gather_optimizer_state.
- Returns
A new state object with both model and optimizer states gathered.
- Return type
- graphdef: nn.GraphDef#
- graphother: nn.GraphState#
- graphstate: nn.GraphState#
- init_tx(tx: GradientTransformation, partition_rules: Any = None) EasyDeLState[source]#
Initializes the optimizer state (opt_state) for the current graphstate using the provided optimizer transformation (tx). It automatically handles sharding based on the model's partition rules.
- Parameters
tx (optax.GradientTransformation) - The optimizer transformation to initialize with.
partition_rules (PartitionLike, optional) - Partitioning rules for the optimizer state. If None, uses the rules from the associated model's config. Defaults to None.
- Returns
A new state object with the initialized and potentially sharded opt_state and the provided tx.
- Return type
- load_optimizer(load_directory: Union[str, PathLike])[source]#
Loads the optimizer state from saved files.
Reads the optimizer state structure from a pickle file (OPTIMIZER_STRUCT_NAME) and the tensor data from a SafeTensors file (OPTIMIZER_NAME) within the specified directory.
- Parameters
load_directory (tp.Union[str, os.PathLike]) - The directory containing the saved optimizer state files.
- Returns
A new state object with the loaded opt_state.
- Return type
- Raises
FileNotFoundError - If the required optimizer files are not found.
Exception - If any error occurs during loading or deserialization.
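The method above reads the optimizer state in two parts: its structure from a pickle file and its tensor data from a SafeTensors file. A minimal sketch of that split using only the standard library, where a plain dict stands in for the SafeTensors file and lists of floats stand in for arrays. The helper names `split_tree` and `join_tree` are hypothetical and not part of EasyDeL:

```python
import pickle


def split_tree(tree, leaves=None):
    """Replace every non-dict leaf with an index key; collect leaves by key."""
    if leaves is None:
        leaves = {}
    if isinstance(tree, dict):
        skeleton = {k: split_tree(v, leaves)[0] for k, v in tree.items()}
        return skeleton, leaves
    key = str(len(leaves))
    leaves[key] = tree
    return key, leaves


def join_tree(skeleton, leaves):
    """Inverse of split_tree: put each leaf back where its index key sits."""
    if isinstance(skeleton, dict):
        return {k: join_tree(v, leaves) for k, v in skeleton.items()}
    return leaves[skeleton]


# Toy optimizer state: a step counter plus first and second moments.
opt_state = {"count": 3, "mu": {"w": [0.1, 0.2]}, "nu": {"w": [0.01, 0.04]}}

skeleton, tensors = split_tree(opt_state)
structure_bytes = pickle.dumps(skeleton)  # what a structure file would hold
restored = join_tree(pickle.loads(structure_bytes), tensors)
```

Keeping the structure and the tensor payload separate means the bulk data can live in a format designed for arrays, while the pickle only has to describe how the pieces fit together.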
- classmethod load_state(load_directory: Union[str, PathLike], device: Optional[Device] = 'cpu', dtype: dtype = jax.numpy.float32, param_dtype: dtype = jax.numpy.float32, precision: Optional[Precision] = None, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, shard_fns: Optional[Union[Mapping[tuple, Callable], dict]] = None, backend: Optional[Any] = None, platform: Optional[Any] = None, config_kwargs: Optional[Any] = None, model_task: TaskType = TaskType.AUTO_BIND, auto_shard_model: bool = False, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec], ...]] = None, quantization_platform: Optional[Any] = None, quantization_method: Optional[Any] = None, quantization_block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: bool = True, **kwargs)[source]#
Loads an EasyDeLState from a saved checkpoint directory.
This class method reconstructs the model configuration, loads the model parameters, and optionally loads the optimizer state from files saved previously using save_state. It handles various configurations for device placement, data types, sharding, and quantization.
- Parameters
load_directory - Path to the directory containing the saved state (configuration, model weights, and potentially optimizer state).
device - The JAX device (e.g., 'cpu', 'gpu', 'tpu') to load the model onto. Defaults to 'cpu'.
dtype - The data type to use for computation (e.g., jax.numpy.float32). Defaults to jax.numpy.float32.
param_dtype - The data type for the model parameters (e.g., jax.numpy.bfloat16). Defaults to jax.numpy.float32.
precision - The JAX precision level (e.g., jax.lax.Precision.HIGHEST). Defaults to None.
sharding_axis_dims - A sequence defining the dimensions of the device mesh for sharding (e.g., (1, -1, 1, 1)). Defaults to (1, -1, 1, 1).
sharding_dcn_axis_dims - Optional sequence of sharding dimensions across the data center network (DCN). Defaults to None.
sharding_axis_names - Names corresponding to the sharding axes (e.g., ('dp', 'fsdp', 'tp', 'sp')). Defaults to ('dp', 'fsdp', 'tp', 'sp').
partition_axis - Configuration object for partitioning specific axes. Defaults to None.
shard_attention_computation - If True, shards the attention computation across devices. Defaults to True.
shard_fns - Optional mapping of parameter path tuples to custom sharding functions. Defaults to None.
backend - The backend framework to use (e.g., EasyDeLBackends.JAX). Defaults to None (auto-detected).
platform - The hardware platform (e.g., EasyDeLPlatforms.TPU). Defaults to None (auto-detected).
config_kwargs - Optional dictionary of keyword arguments to override in the loaded model configuration. Defaults to None.
model_task - The specific task type for the model (e.g., TaskType.CAUSAL_LM). Defaults to TaskType.AUTO_BIND.
auto_shard_model - If True, automatically shards the loaded model and optimizer state based on the provided sharding configuration. Defaults to False.
partition_rules - Optional tuple of partition rules (regex, PartitionSpec) to explicitly define sharding. Defaults to None (uses model config).
quantization_platform - Platform for quantization (e.g., EasyDeLPlatforms.TPU). Defaults to None.
quantization_method - Quantization method (e.g., EasyDeLQuantizationMethods.AQT). Defaults to None.
quantization_block_size - Block size for quantization methods like GPTQ. Defaults to 128.
quantization_pattern - Regex pattern to match tensor names for quantization. Defaults to None.
quantize_tensors - If True, applies quantization to the loaded tensors. Defaults to True.
verbose - If True, logs detailed information during loading. Defaults to True.
**kwargs - Additional keyword arguments passed directly to the underlying EasyDeLBaseModule.from_pretrained method.
- Returns
An EasyDeLState instance containing the loaded model, optimizer state (if found and loaded), and associated configuration.
- Raises
FileNotFoundError - If the load_directory or essential files within it (like configuration or model weights) are not found.
ValueError - If there are inconsistencies in the provided arguments or loaded configuration.
Note: Other exceptions from underlying calls such as AutoEasyDeLConfig or EasyDeLBaseModule.from_pretrained might also be raised.
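The `partition_rules` parameter takes `(regex, PartitionSpec)` pairs. A minimal sketch of how such rules can map parameter names to specs, assuming first-match-wins lookup with `re.search`. Plain tuples stand in for `PartitionSpec`, and `match_rules`, the rule patterns, and the parameter names are all hypothetical; EasyDeL's own matcher may differ:

```python
import re

# Tuples stand in for jax.sharding.PartitionSpec; axis names follow the
# documented default sharding_axis_names ('dp', 'fsdp', 'tp', 'sp').
RULES = (
    (r"embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    (r"(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    (r".*", (None,)),  # catch-all: replicate anything not matched above
)


def match_rules(rules, param_names):
    """Return {name: spec} using the first rule whose regex matches."""
    result = {}
    for name in param_names:
        for pattern, spec in rules:
            if re.search(pattern, name):
                result[name] = spec
                break
        else:
            # Reached only when no rule matched this parameter name.
            raise ValueError(f"no partition rule matches {name!r}")
    return result


names = [
    "model/embed_tokens/embedding",
    "model/layers/0/self_attn/q_proj/kernel",
    "model/norm/kernel",
]
specs = match_rules(RULES, names)
```

Under first-match semantics the order of the rules matters, which is why the catch-all pattern goes last in this sketch.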
- merge(tree) Any[source]#
Merges a given state tree (usually parameters) with the graph definition and other state components to reconstruct the full model module.
- Parameters
tree - The pytree (e.g., nn.GraphState) containing the parameters to merge.
- Returns
The reconstructed model module.
- Return type
- merge_to_state(tree) EasyDeLState[source]#
Creates a new EasyDeLState by replacing the current graphstate with the provided tree.
- Parameters
tree - The pytree (e.g., nn.GraphState) containing the new parameters.
- Returns
A new state object with the updated graphstate.
- Return type
- property model: Any#
Reconstructs and returns the full EasyDeL model module from the state components.
- Returns
The model module instance.
- Return type
- opt_state: tp.Optional[optax.OptState]#
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- save_state(save_directory: Union[str, PathLike], float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, save_optimizer: bool = True, enable: Optional[bool] = None)[source]#
Saves the entire EasyDeLState to a directory.
This includes saving the model parameters (using model.save_pretrained) and optionally the optimizer state.
- Parameters
save_directory (tp.Union[str, os.PathLike]) - The directory to save the state to.
float_dtype (tp.Optional[jax.numpy.dtype]) - Optional dtype to cast floating-point parameters to before saving. Defaults to None.
verbose (bool) - If True, logs information during saving. Defaults to True.
mismatch_allowed (bool) - Passed to model.save_pretrained, allows saving even if the model structure differs slightly from expected. Defaults to True.
save_optimizer (bool) - If True, saves the optimizer state. Defaults to True.
enable (tp.Optional[bool]) - If set, controls whether saving happens (True) or is skipped (False). If None, saving typically occurs only on JAX process index 0. Defaults to None.
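The `enable` semantics can be read as a three-way decision. A minimal sketch with a hypothetical helper `should_save`, which only restates the documented behaviour and is not EasyDeL's code. In a real multi-host run the process index would come from `jax.process_index()`; here it is passed in explicitly so the sketch has no dependencies:

```python
from typing import Optional


def should_save(enable: Optional[bool], process_index: int) -> bool:
    """Hypothetical helper mirroring the documented `enable` semantics."""
    if enable is None:
        # Documented default: typically only process index 0 writes files.
        return process_index == 0
    # An explicit True or False overrides the process-index default.
    return enable


default_on_host_0 = should_save(None, process_index=0)
default_on_host_3 = should_save(None, process_index=3)
forced_on_host_3 = should_save(True, process_index=3)
skipped_on_host_0 = should_save(False, process_index=0)
```

Restricting the default to a single process avoids several hosts writing the same checkpoint files at once.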
- shard_model(partition_rules: Any = None, mesh: Optional[Any] = None) EasyDeLState[source]#
Shards the model parameters (graphstate and graphother) based on partition rules.
- Parameters
partition_rules (PartitionLike, optional) - Partitioning rules. If None, uses model config rules. Defaults to None.
mesh (tp.Optional[Mesh], optional) - The JAX device mesh to shard across. If None, uses the model's mesh. Defaults to None.
- Returns
A new state object with sharded graphstate and graphother.
- Return type
- shard_optimizer_state(opt_state: Optional[Any] = None, partition_rules: Any = None) Any[source]#
Applies sharding to the optimizer state based on partition rules.
- Parameters
opt_state (tp.Optional[tp.Any]) - The optimizer state pytree to shard. If None, uses self.opt_state. Defaults to None.
partition_rules (PartitionLike, optional) - Partitioning rules. If None, uses rules from the model's config. Defaults to None.
- Returns
A new state object with the sharded opt_state.
- Return type
- Raises
ValueError - If optimizer state is not initialized (neither opt_state argument nor self.opt_state is available).
- shard_state(partition_rules: Any = None) EasyDeLState[source]#
Shards the entire state (model parameters and optimizer state) based on partition rules.
This is a convenience method that calls shard_model and shard_optimizer_state.
- Parameters
partition_rules (PartitionLike, optional) - Partitioning rules. If None, uses rules from the model's config. Defaults to None.
- Returns
A new state object with both model and optimizer states sharded.
- Return type
- shard_with_shape(shape) EasyDeLState[source]#
Applies sharding constraints to the entire state based on a reference shape pytree.
This method takes a pytree shape which has the same structure as the EasyDeLState but contains sharding annotations (e.g., NamedSharding) instead of actual array data. It applies these shardings as constraints to the corresponding arrays in the current state.
- Parameters
shape - A pytree with the same structure as self, containing sharding annotations.
- Returns
A new state object with sharding constraints applied.
- Return type
- property shardings#
Retrieves the sharding annotations (e.g., NamedSharding) for all components of the EasyDeLState pytree.
- Returns
A pytree with the same structure as self, containing sharding annotations or None for components without sharding.
- property size: int#
Calculates the total size in bytes of the model parameters (graphstate) and the optimizer state (opt_state).
- Returns
The total size in bytes.
- Return type
int
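A total byte count like this amounts to summing the size of every array leaf in both trees. A minimal sketch with standard-library `array` objects standing in for JAX arrays (with JAX arrays, each leaf's `nbytes` attribute would play the same role). The helper `tree_nbytes` and the toy trees are hypothetical, not EasyDeL's implementation:

```python
from array import array


def tree_nbytes(tree) -> int:
    """Sum the byte size of every array leaf in a nested dict.

    A missing tree (None) counts as zero bytes, which covers the case
    where the optimizer state has not been initialized.
    """
    if tree is None:
        return 0
    if isinstance(tree, dict):
        return sum(tree_nbytes(v) for v in tree.values())
    return len(tree) * tree.itemsize


# Toy parameters: a 6-element kernel and a 2-element bias.
graphstate = {
    "dense": {"kernel": array("f", [0.0] * 6), "bias": array("f", [0.0] * 2)}
}
# Toy optimizer state: one moment buffer with the same layout.
opt_state = {
    "mu": {"dense": {"kernel": array("f", [0.0] * 6), "bias": array("f", [0.0] * 2)}}
}

total = tree_nbytes(graphstate) + tree_nbytes(opt_state)
```

The total here covers 16 elements; with one moment buffer per parameter the optimizer state is as large as the parameters, and an optimizer that keeps two moments would double that share.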
- tx: optax.GradientTransformation#