easydel.infra.base_state
- class easydel.infra.base_state.EasyDeLState(step: int | jax.Array, graphdef: nn.GraphDef, graphstate: nn.GraphState, graphother: nn.GraphState, tx: optax.GradientTransformation, opt_state: tp.Optional[optax.OptState], apply_fn: tp.Optional[tp.Callable] = None)
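Bases: PyTreeNode
EasyDeLState: A Snapshot of Your EasyDeL Model
The EasyDeLState class is a container for all the essential information about your EasyDeL model at a given point in time; think of it as a snapshot of your model. It includes the training step (step), the model's graph definition (graphdef), its parameter state (graphstate), any additional graph state (graphother), the optimizer (tx) and its state (opt_state), and an optional apply function (apply_fn).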
- apply_fn: tp.Optional[tp.Callable] = None
- apply_gradients(*, grads)
Applies gradients to the model parameters and updates the optimizer state. This function is typically called during training to update the model based on the computed gradients.
- Parameters
grads – Gradients for the model parameters, with the same tree structure as graphstate.
- Returns
An updated EasyDeLState object with modified parameters and optimizer state.
- Return type
EasyDeLState
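A minimal sketch of the update step, written in plain Python so it runs without EasyDeL. It assumes an optax-style optimizer, i.e. an (init, update) pair where update(grads, opt_state, params) returns (updates, new_opt_state). ToyState, sgd and the flat-dict parameters are hypothetical stand-ins for EasyDeLState, an optax optimizer and an nn.GraphState; in real JAX code the second step would be optax.apply_updates(params, updates).

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, NamedTuple


class GradientTransformation(NamedTuple):
    """Hypothetical stand-in for optax.GradientTransformation."""
    init: Callable[[dict], Any]
    update: Callable[..., tuple]


def sgd(learning_rate: float) -> GradientTransformation:
    """Stateless SGD: every update is -learning_rate * grad."""
    def init(params):
        return ()  # plain SGD keeps no optimizer state

    def update(grads, state, params=None):
        return {k: -learning_rate * g for k, g in grads.items()}, state

    return GradientTransformation(init, update)


@dataclass(frozen=True)
class ToyState:
    """Hypothetical, simplified stand-in for EasyDeLState."""
    step: int
    graphstate: dict  # parameters as a flat {name: value} dict
    tx: GradientTransformation
    opt_state: Any

    def apply_gradients(self, *, grads):
        # 1. The optimizer turns raw gradients into parameter updates.
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.graphstate)
        # 2. Add the updates to the parameters.
        new_params = {k: p + updates[k] for k, p in self.graphstate.items()}
        # 3. Return a new state with the step advanced; `self` is left unchanged.
        return replace(self, step=self.step + 1, graphstate=new_params, opt_state=new_opt_state)


tx = sgd(learning_rate=0.5)
params = {"w": 1.0, "b": 0.5}
state = ToyState(step=0, graphstate=params, tx=tx, opt_state=tx.init(params))
new_state = state.apply_gradients(grads={"w": 2.0, "b": -1.0})
```

Here new_state.graphstate is {"w": 0.0, "b": 1.0} and new_state.step is 1, while state still holds the original parameters and step 0.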
- classmethod create(*, step: tp.Optional[int] = None, graphdef: tp.Optional[nn.GraphDef] = None, graphstate: tp.Optional[nn.GraphState] = None, graphother: tp.Optional[nn.GraphState] = None, model: tp.Optional[nn.Module] = None, tx: tp.Optional[optax.GradientTransformation] = None, opt_state: tp.Optional[optax.OptState] = None, init_opt_state: bool = False) → EasyDeLState
Create an instance with flexible initialization options.
- Parameters
step – Optional current training step.
graphdef – Optional graph definition.
graphstate – Optional graph state (the model parameters).
graphother – Optional additional graph state (non-parameter variables).
model – Optional neural network module.
tx – Optional gradient transformation.
opt_state – Optional optimizer state.
init_opt_state – Whether to initialize the optimizer state from tx.
- Raises
ValueError – If initialization parameters are inconsistent.
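The parameters above allow two ways to build a state: from a whole model, or from already-split pieces. The sketch below is a hypothetical illustration of how such flexible initialization and the ValueError check can fit together. The exact consistency rules EasyDeL enforces are not listed here, so the ones in the code (a model and pre-split pieces are mutually exclusive; init_opt_state requires tx) are assumptions, and the dict-based "model" is a stand-in for splitting a real nn.Module.

```python
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Optional


class GradientTransformation(NamedTuple):
    """Hypothetical stand-in for optax.GradientTransformation."""
    init: Callable[[dict], Any]
    update: Callable[..., tuple]


@dataclass(frozen=True)
class ToyState:
    step: int
    graphstate: dict
    graphother: dict
    tx: Optional[GradientTransformation]
    opt_state: Any

    @classmethod
    def create(cls, *, step=None, graphstate=None, graphother=None, model=None,
               tx=None, opt_state=None, init_opt_state=False):
        if model is not None:
            # Assumed rule: a model and pre-split pieces are mutually exclusive.
            if graphstate is not None or graphother is not None:
                raise ValueError("pass either `model` or `graphstate`/`graphother`, not both")
            # Hypothetical split: a dict with "params" and optional "other" entries.
            graphstate, graphother = model["params"], model.get("other", {})
        if graphstate is None:
            raise ValueError("no parameters: pass `model` or `graphstate`")
        if init_opt_state:
            # Assumed rule: building optimizer state needs an optimizer.
            if tx is None:
                raise ValueError("`init_opt_state=True` requires `tx`")
            opt_state = tx.init(graphstate)
        return cls(step=0 if step is None else step, graphstate=graphstate,
                   graphother=graphother or {}, tx=tx, opt_state=opt_state)


# Usage: a toy optimizer whose init returns one zero per parameter.
zeros_tx = GradientTransformation(init=lambda p: {k: 0.0 for k in p},
                                  update=lambda g, s, p=None: (g, s))
state = ToyState.create(model={"params": {"w": 1.0}, "other": {"rng": 0}},
                        tx=zeros_tx, init_opt_state=True)
```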
- gather_model(partition_rules: Any = None, mesh: Optional[Any] = None) → EasyDeLState
Gathers the model according to the provided partition rules.
- Returns
An updated EasyDeLState object with the gathered model.
- Return type
EasyDeLState
- gather_state()
Gathers the entire state.
- Returns
An updated EasyDeLState object with the gathered state.
- Return type
EasyDeLState
- graphdef: nn.GraphDef
- graphother: nn.GraphState
- graphstate: nn.GraphState
- init_tx(tx: GradientTransformation, partition_rules: Any = None) → EasyDeLState
Initialize the optimizer state with the given gradient transformation.
- Parameters
tx (optax.GradientTransformation) – A gradient transformation to initialize the optimizer state.
partition_rules (Optional[Any], optional) – Rules for partitioning the optimizer state. Defaults to None.
- Returns
An updated EasyDeLState object with the new gradient transformation and sharded optimizer state.
- Return type
EasyDeLState
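A plain-Python sketch of what initializing the optimizer state involves: tx.init builds fresh state shaped like the parameters, and a new state object stores both tx and that opt_state. The momentum optimizer and ToyState are hypothetical stand-ins (in practice something like optax.sgd with its momentum argument would play this role), and the partition_rules sharding step is omitted.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, NamedTuple, Optional


class GradientTransformation(NamedTuple):
    """Hypothetical stand-in for optax.GradientTransformation."""
    init: Callable[[dict], Any]
    update: Callable[..., tuple]


def momentum(learning_rate: float, beta: float = 0.9) -> GradientTransformation:
    def init(params):
        # One zero velocity per parameter: the state mirrors the parameter structure.
        return {k: 0.0 for k in params}

    def update(grads, velocity, params=None):
        new_velocity = {k: beta * velocity[k] + g for k, g in grads.items()}
        return {k: -learning_rate * v for k, v in new_velocity.items()}, new_velocity

    return GradientTransformation(init, update)


@dataclass(frozen=True)
class ToyState:
    graphstate: dict
    tx: Optional[GradientTransformation] = None
    opt_state: Any = None

    def init_tx(self, tx: GradientTransformation) -> "ToyState":
        # Build optimizer state from the current parameters; sharding is omitted here.
        return replace(self, tx=tx, opt_state=tx.init(self.graphstate))


state = ToyState(graphstate={"w": 1.0, "b": 0.0})
state = state.init_tx(momentum(learning_rate=0.1))
```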
- merge_to_state(tree) → EasyDeLState
- property model: Any
- opt_state: tp.Optional[optax.OptState]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
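Because the state is immutable, every "update" produces a new snapshot rather than modifying the old one. The sketch below mimics that behaviour with a frozen standard-library dataclass, whose dataclasses.replace plays the role of replace(**updates); ToySnapshot is a hypothetical stand-in, and only the fields named in the call change.

```python
from dataclasses import FrozenInstanceError, dataclass, replace
from typing import Any


@dataclass(frozen=True)
class ToySnapshot:
    """Hypothetical stand-in for an immutable, PyTreeNode-style state."""
    step: int
    graphstate: dict
    opt_state: Any = None


s0 = ToySnapshot(step=0, graphstate={"w": 1.0})
s1 = replace(s0, step=1)  # new object with `step` replaced; s0 is untouched

try:
    s0.step = 5  # in-place mutation is rejected on a frozen dataclass
except FrozenInstanceError:
    mutation_blocked = True
else:
    mutation_blocked = False
```

Note that the copy is shallow: s1.graphstate is the same object as s0.graphstate, which is safe as long as the parameter tree is itself treated as immutable, as JAX arrays are.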
- save_state(save_directory: Union[str, PathLike], float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, save_optimizer: bool = True, enable: Optional[bool] = None)
- shard_model(partition_rules: Any = None, mesh: Optional[Any] = None) → EasyDeLState
Shards the model according to the provided partition rules.
- Parameters
partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.
mesh (Optional[Mesh]) – The mesh to be used for sharding. If None, the method will use the mesh from self.model.
- Returns
An updated EasyDeLState object with the sharded model.
- Return type
EasyDeLState
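Partition rules map parameter names to shardings. A common convention in JAX codebases, and as far as we can tell the one EasyDeL model configs follow, is an ordered list of (regex, PartitionSpec) pairs in which the first pattern matching a parameter's path decides its sharding. The sketch below shows only that matching step. match_partition_rules is a hypothetical helper, the paths and axis names are illustrative, and specs are plain tuples instead of jax.sharding.PartitionSpec objects so the code runs without JAX.

```python
import re
from typing import Dict, Iterable, Sequence, Tuple

# Hypothetical rule format: (regex, spec). In real JAX code `spec` would be
# a jax.sharding.PartitionSpec; plain tuples of axis names stand in here.
Rule = Tuple[str, tuple]


def match_partition_rules(rules: Sequence[Rule], param_paths: Iterable[str]) -> Dict[str, tuple]:
    specs = {}
    for path in param_paths:
        for pattern, spec in rules:
            if re.search(pattern, path):
                specs[path] = spec
                break  # first matching rule wins
        else:
            raise ValueError(f"no partition rule matches {path!r}")
    return specs


rules = [
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r".*", ()),  # catch-all: replicate everything else
]
paths = [
    "model/embed_tokens/embedding",
    "model/layers/0/self_attn/q_proj/kernel",
    "model/norm/scale",
]
specs = match_partition_rules(rules, paths)
```

Ordering matters: the catch-all `.*` has to come last, otherwise it would shadow the more specific rules.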
- shard_optimizer_state(opt_state: Optional[Any] = None, partition_rules: Any = None) → Any
Shards the optimizer state according to the provided partition rules.
- Parameters
opt_state (Optional[Any]) – The optimizer state to be sharded. If None, the method will use self.opt_state. Raises a ValueError if both opt_state and self.opt_state are None.
partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.
- Returns
The sharded optimizer state.
- Return type
Any
- Raises
ValueError – If both opt_state and self.opt_state are None.
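The argument fallbacks described above can be sketched as follows. Only the resolution logic is shown; ToyState, ToyConfig and _resolve_shard_args are hypothetical names, and the actual sharding of the resolved optimizer state is left out.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class ToyConfig:
    # Hypothetical default rules: replicate everything.
    partition_rules: tuple = ((".*", ()),)


@dataclass
class ToyState:
    opt_state: Optional[Any] = None
    config: ToyConfig = field(default_factory=ToyConfig)

    def _resolve_shard_args(self, opt_state=None, partition_rules=None):
        # An explicit argument wins; otherwise fall back to the stored optimizer state.
        if opt_state is None:
            opt_state = self.opt_state
        if opt_state is None:
            raise ValueError("no optimizer state: pass `opt_state` or initialize one first")
        # Partition rules default to the ones carried by the model config.
        if partition_rules is None:
            partition_rules = self.config.partition_rules
        return opt_state, partition_rules


opt_state, rules = ToyState(opt_state={"mu": 0.0})._resolve_shard_args()
```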
- shard_state(partition_rules: Any = None) → EasyDeLState
Shards the entire state according to the provided partition rules.
- Parameters
partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.
- Returns
An updated EasyDeLState object with the sharded state.
- Return type
EasyDeLState
- shard_with_shape(shape) → EasyDeLState
Shards the current state according to the given shape.
- property shardings
Returns the sharding information for the state.
- Returns
The sharding information.
- Return type
Any
- property size: int
Calculates the total size of the optimizer state and model graph state.
- Returns
The total size in bytes.
- Return type
int
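The size computation amounts to flattening both trees and summing the byte size of every leaf array. With JAX, jax.tree_util.tree_leaves does the flattening and each jax.Array exposes .nbytes; the sketch below uses the standard-library array module instead so it runs anywhere, and tree_nbytes is a hypothetical helper.

```python
from array import array


def tree_nbytes(tree) -> int:
    """Sum the byte size of every array leaf in a nested dict/list/tuple."""
    if isinstance(tree, dict):
        return sum(tree_nbytes(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(tree_nbytes(v) for v in tree)
    if isinstance(tree, array):
        return len(tree) * tree.itemsize
    return 0  # non-array leaves (ints, None, ...) are not counted in this sketch


graphstate = {"w": array("f", [0.0] * 10), "b": array("d", [0.0] * 3)}
opt_state = ({"mu": array("f", [0.0] * 10)}, {"count": 0})
total = tree_nbytes(graphstate) + tree_nbytes(opt_state)
```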
- tx: optax.GradientTransformation