easydel.infra.base_state

class easydel.infra.base_state.EasyDeLState(step: int | jax.Array, graphdef: nn.GraphDef, graphstate: nn.GraphState, graphother: nn.GraphState, tx: optax.GradientTransformation, opt_state: tp.Optional[optax.OptState], apply_fn: tp.Optional[tp.Callable] = None)

Bases: PyTreeNode

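EasyDeLState: A Snapshot of Your EasyDeL Model

The EasyDeLState class is a container that holds the essential information about your EasyDeL model at a given point in training, that is, a snapshot of the model. It includes the training step (step), the graph definition (graphdef), the graph state (graphstate), the remaining graph state (graphother), the gradient transformation (tx), the optimizer state (opt_state), and an optional apply function (apply_fn).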

apply_fn: tp.Optional[tp.Callable] = None
apply_gradients(*, grads)

Applies gradients to the model parameters and updates the optimizer state. This function is typically called during training to update the model based on the computed gradients.

Parameters

grads – A dictionary of gradients, where keys correspond to model parameters.

Returns

An updated EasyDeLState object with modified parameters and optimizer state.

Return type

EasyDeLState
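As a rough illustration of the update pattern described for apply_gradients, the sketch below uses a hypothetical ToyState with plain Python floats and a fixed-step SGD rule in place of the optax gradient transformation. It is not EasyDeL's implementation. It only shows that the call returns a new state with updated parameters while the original stays untouched. The increment of step is an assumption of the sketch.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for EasyDeLState (not the real class)."""

    step: int
    params: dict
    learning_rate: float = 0.5

    def apply_gradients(self, *, grads):
        # Plain SGD stands in for the optax gradient transformation.
        new_params = {
            name: value - self.learning_rate * grads[name]
            for name, value in self.params.items()
        }
        # Build a fresh state; bumping step by one is an assumption here.
        return dataclasses.replace(self, step=self.step + 1, params=new_params)


state = ToyState(step=0, params={"w": 1.0, "b": -2.0})
new_state = state.apply_gradients(grads={"w": 0.5, "b": -1.0})
```

In a training loop the returned object would replace the previous one, for example state = state.apply_gradients(grads=grads).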

classmethod create(*, step: tp.Optional[int] = None, graphdef: tp.Optional[nn.GraphDef] = None, graphstate: tp.Optional[nn.GraphState] = None, graphother: tp.Optional[nn.GraphState] = None, model: tp.Optional[nn.Module] = None, tx: tp.Optional[optax.GradientTransformation] = None, opt_state: tp.Optional[optax.OptState] = None, init_opt_state: bool = False) -> EasyDeLState

Create an instance with flexible initialization options.

Parameters
  • step – Optional number of training steps.

  • graphdef – Optional graph definition.

  • graphstate – Optional graph state.

  • graphother – Optional graph state holding the "other" variables.

  • model – Optional neural network module.

  • tx – Optional gradient transformation.

  • opt_state – Optional optimizer state.

  • init_opt_state – Whether to initialize the optimizer state. Defaults to False.

Raises

ValueError – If initialization parameters are inconsistent.
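This page does not say which argument combinations count as inconsistent. The sketch below shows one plausible set of checks using a hypothetical helper, check_create_args. The three rules in it are assumptions made for illustration, not a description of what create actually enforces.

```python
def check_create_args(*, model=None, graphdef=None, graphstate=None,
                      graphother=None, tx=None, opt_state=None,
                      init_opt_state=False):
    """Hypothetical validation; every rule below is an assumption."""
    graph_parts = (graphdef, graphstate, graphother)
    # Assumed rule 1: a model and explicit graph pieces are alternatives.
    if model is not None and any(part is not None for part in graph_parts):
        raise ValueError("Pass either model or the graph pieces, not both.")
    # Assumed rule 2: do not ask for a fresh optimizer state and supply one.
    if init_opt_state and opt_state is not None:
        raise ValueError("opt_state was given, so init_opt_state must be False.")
    # Assumed rule 3: initializing the optimizer state needs a transformation.
    if init_opt_state and tx is None:
        raise ValueError("init_opt_state=True requires tx.")
    return True


ok = check_create_args(model=object(), tx=object(), init_opt_state=True)
```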

gather_model(partition_rules: Any = None, mesh: Optional[Any] = None) -> EasyDeLState

Gathers the model according to the provided partition rules.

Returns

An updated EasyDeLState object with the gathered model.

Return type

EasyDeLState

gather_optimizer_state(partition_rules=None)
gather_state()

Gathers the entire state.

Returns

An updated EasyDeLState object with the gathered state.

Return type

EasyDeLState

graphdef: nn.GraphDef
graphother: nn.GraphState
graphstate: nn.GraphState
init_tx(tx: GradientTransformation, partition_rules: Any = None) -> EasyDeLState

Initialize the optimizer state with the given gradient transformation.

Parameters
  • tx (optax.GradientTransformation) – A gradient transformation to initialize the optimizer state.

  • partition_rules (Optional[Any], optional) – Rules for partitioning the optimizer state. Defaults to None.

Returns

An updated EasyDeLState object with the new gradient transformation and sharded optimizer state.

Return type

EasyDeLState

load_optimizer(load_directory: Union[str, PathLike])
load_state(load_directory: Union[str, PathLike], verbose: bool = True)
merge(tree) -> Any
merge_to_state(tree) -> EasyDeLState
property model: Any
opt_state: tp.Optional[optax.OptState]
replace(**updates)

Returns a new object replacing the specified fields with new values.
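The behaviour of replace can be sketched with the standard library alone, assuming PyTreeNode is a frozen-dataclass base as in flax.struct. ToyState below is a hypothetical stand-in: the call leaves the original object unchanged and returns a copy with the named fields swapped.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for a PyTreeNode-style frozen dataclass."""

    step: int
    opt_state: object = None

    def replace(self, **updates):
        # Copy every field, overriding only the ones named in updates.
        return dataclasses.replace(self, **updates)


old = ToyState(step=10)
new = old.replace(step=11)
```

Because the object is immutable, field updates go through replace rather than attribute assignment.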

save_state(save_directory: Union[str, PathLike], float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, save_optimizer: bool = True, enable: Optional[bool] = None)
shard_model(partition_rules: Any = None, mesh: Optional[Any] = None) -> EasyDeLState

Shards the model according to the provided partition rules.

Parameters
  • partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.

  • mesh (Optional[Mesh]) – The mesh to be used for sharding. If None, the method will use the mesh from self.model.

Returns

An updated EasyDeLState object with the sharded model.

Return type

EasyDeLState
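The format of the partition rules is not defined on this page. One common convention, assumed here, is an ordered sequence of (regex, spec) pairs in which the first pattern that matches a parameter name decides its layout. The helper match_partition_rules, the parameter names, and the tuple-based specs below are all hypothetical, used only to show the matching step in plain Python.

```python
import re


def match_partition_rules(rules, names):
    """Map each parameter name to the spec of the first matching rule."""
    result = {}
    for name in names:
        for pattern, spec in rules:
            if re.search(pattern, name):
                result[name] = spec
                break
        else:
            # No rule matched this name.
            raise ValueError(f"No partition rule matches {name!r}")
    return result


# Tuples stand in for partition specs; "tp" is a made-up mesh axis name.
rules = (
    (r"embed/kernel", ("tp", None)),
    (r"attention/.*/kernel", (None, "tp")),
    (r".*", (None,)),  # catch-all: leave everything else unsharded
)
names = ["embed/kernel", "attention/q_proj/kernel", "norm/scale"]
specs = match_partition_rules(rules, names)
```

Rule order matters in this scheme: the catch-all pattern goes last so that more specific rules are tried first.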

shard_optimizer_state(opt_state: Optional[Any] = None, partition_rules: Any = None) -> Any

Shards the optimizer state according to the provided partition rules.

Parameters
  • opt_state (Optional[Any]) – The optimizer state to be sharded. If None, the method will use self.opt_state. Raises a ValueError if both opt_state and self.opt_state are None.

  • partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.

Returns

The sharded optimizer state.

Return type

Any

Raises

ValueError – If both opt_state and self.opt_state are None.
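The fallback described for opt_state can be written out as a small sketch. The helper name resolve_opt_state is hypothetical. It mirrors only the documented choice between the argument and the stored state, not the sharding itself.

```python
from typing import Any, Optional


def resolve_opt_state(explicit: Optional[Any], stored: Optional[Any]) -> Any:
    """Pick the optimizer state to shard (hypothetical helper)."""
    if explicit is not None:
        # An explicitly passed optimizer state takes precedence.
        return explicit
    if stored is not None:
        # Otherwise fall back to the state already held by the object.
        return stored
    raise ValueError("Both opt_state and self.opt_state are None.")


chosen = resolve_opt_state(None, {"mu": 0.0})
```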

shard_state(partition_rules: Any = None) -> EasyDeLState

Shards the entire state according to the provided partition rules.

Parameters

partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.

Returns

An updated EasyDeLState object with the sharded state.

Return type

EasyDeLState

shard_with_shape(shape) -> EasyDeLState

Shards the current state with a given shape.

property shardings

Returns the sharding information for the state.

Returns

The sharding information.

Return type

Any

property size: int

Calculates the total size of the optimizer state and model graph state.

Returns

The total size in bytes.

Return type

int
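A byte count of this kind is typically the sum, over every array in the state, of element count times bytes per element. The sketch below assumes that reading and uses the standard-library array module in place of JAX arrays. The function total_size_in_bytes and the dictionary layout are hypothetical.

```python
from array import array


def total_size_in_bytes(*trees):
    """Sum element_count * bytes_per_element over every leaf buffer."""
    return sum(
        len(leaf) * leaf.itemsize
        for tree in trees
        for leaf in tree.values()
    )


# Flat dictionaries stand in for the graph state and the optimizer state.
graphstate = {"kernel": array("f", [0.0] * 6), "bias": array("f", [0.0] * 2)}
opt_state = {"mu": array("d", [0.0] * 8)}
total = total_size_in_bytes(graphstate, opt_state)
```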

step: int | jax.Array
tx: optax.GradientTransformation