easydel.trainers.trainer._fn
- easydel.trainers.trainer._fn.evaluation_step(state: EasyDeLState, batch: Mapping[str, Array], loss_config: Optional[LossConfig] = None, partition_spec: Optional[PartitionSpec] = None) → Tuple[Any, LossMetrics]
Performs a single evaluation step by computing loss metrics for the input batch.
The function determines the required partitioning for the batch, applies the corresponding sharding constraints, and defines an inner loss function. This inner function merges the model's graph state back into the current state to rebuild the module, switches the module to evaluation mode, and computes the loss and metrics with the model's compute_loss method. The computed LossMetrics are then returned.
- Parameters
state (EasyDeLState) – The current model state.
batch (tp.Mapping[str, jax.Array]) – A mapping of input arrays for evaluation.
loss_config (tp.Optional[LossConfig], optional) – Configuration for loss computation. Defaults to None.
partition_spec (tp.Optional[PartitionSpec], optional) – Specification for sharding the batch. Defaults to None.
- Returns
- A tuple containing:
(Any): An additional output from loss computation (if any).
LossMetrics: The computed loss metrics for the evaluation batch.
- Return type
tp.Tuple[tp.Any, LossMetrics]
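The evaluation flow described above (constrain the batch sharding, run the loss without taking gradients, return the metrics) can be sketched in plain JAX. This is a minimal sketch, not EasyDeL's implementation: `compute_loss`, the toy linear model, the `{"inputs", "labels"}` batch keys, and the `"dp"` mesh axis name are hypothetical stand-ins for the model's compute_loss method and the trainer's partition_spec, and the sketch returns only the metrics dictionary.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical stand-in for the model's compute_loss: a linear model with an
# MSE loss that returns (extra_output, metrics).
def compute_loss(params, batch):
    preds = batch["inputs"] @ params["w"] + params["b"]
    loss = jnp.mean((preds - batch["labels"]) ** 2)
    return preds, {"loss": loss}

# One-axis mesh over the local devices; "dp" (data parallel) is an assumed axis name.
mesh = Mesh(np.array(jax.devices()), ("dp",))
batch_sharding = NamedSharding(mesh, PartitionSpec("dp"))

@jax.jit
def evaluation_step(params, batch):
    # Constrain every batch array to be split along its leading (batch) dimension.
    batch = jax.tree_util.tree_map(
        lambda x: jax.lax.with_sharding_constraint(x, batch_sharding), batch
    )
    # Forward pass only: no gradients are taken during evaluation.
    # (This toy model has no dropout, so there is no train/eval mode to switch.)
    _, metrics = compute_loss(params, batch)
    return metrics
```

Because the sketch shards the leading dimension over every local device, the batch size must be divisible by `jax.device_count()`; with a single CPU device the constraint is trivially satisfied.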
- easydel.trainers.trainer._fn.training_step(state: EasyDeLState, batch: Mapping[str, Array], loss_config: Optional[LossConfig] = None, learning_rate_fn: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]] = None, partition_spec: Optional[PartitionSpec] = None, gradient_accumulation_steps: int = 1) → Tuple[EasyDeLState, LossMetrics]
Performs a single training step by computing gradients via minibatch processing, updating the model state, and returning the updated state and loss metrics.
The function first computes the batch and minibatch sizes, asserting that they are consistent. It then applies sharding constraints to the batch. The loss is defined as an inner function that merges the tree being differentiated with the current model state, prepares the inputs, and computes the loss with the model's compute_loss method. Gradients are computed with jax.value_and_grad over the minibatches. The state is then updated with the computed gradients, and the metrics are updated accordingly.
- Parameters
state (EasyDeLState) – The current model state, which includes the parameters and the model graph.
batch (tp.Mapping[str, jax.Array]) – A mapping of input arrays for the current batch.
loss_config (tp.Optional[LossConfig], optional) – Configuration settings for the loss computation. Defaults to None.
learning_rate_fn (optax.Schedule, optional) – A schedule function for the learning rate. Defaults to None.
partition_spec (tp.Optional[PartitionSpec], optional) – Specification for data sharding. Defaults to None.
gradient_accumulation_steps (int, optional) – Number of minibatches the batch is split into; gradients are accumulated across these minibatches before the state is updated. Defaults to 1.
- Returns
- A tuple containing:
The updated EasyDeLState after applying gradients.
LossMetrics containing computed loss and other related metrics.
- Return type
tp.Tuple[EasyDeLState, LossMetrics]
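The minibatch gradient accumulation described above can be sketched in plain JAX. This is a minimal sketch under stated assumptions, not EasyDeL's implementation: the toy `loss_fn`, the helper `split_into_minibatches`, the `{"inputs", "labels"}` batch keys, and the plain SGD update are all hypothetical, and a params dictionary plus a step counter stand in for EasyDeLState. It splits the batch into `gradient_accumulation_steps` equal minibatches, runs `jax.value_and_grad` on each with `jax.lax.scan`, averages the gradients, and applies one update scaled by `learning_rate_fn(step)`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: linear regression with an MSE loss.
def loss_fn(params, minibatch):
    preds = minibatch["inputs"] @ params["w"] + params["b"]
    return jnp.mean((preds - minibatch["labels"]) ** 2)

def split_into_minibatches(batch, gradient_accumulation_steps):
    batch_size = jax.tree_util.tree_leaves(batch)[0].shape[0]
    # The batch must split evenly into the requested number of minibatches.
    assert batch_size % gradient_accumulation_steps == 0, (
        "batch size must be divisible by gradient_accumulation_steps"
    )
    minibatch_size = batch_size // gradient_accumulation_steps
    # (batch, ...) -> (steps, minibatch, ...) so lax.scan iterates over minibatches.
    return jax.tree_util.tree_map(
        lambda x: x.reshape((gradient_accumulation_steps, minibatch_size) + x.shape[1:]),
        batch,
    )

def training_step(params, step, batch, learning_rate_fn, gradient_accumulation_steps=1):
    minibatches = split_into_minibatches(batch, gradient_accumulation_steps)
    grad_fn = jax.value_and_grad(loss_fn)

    def accumulate(carry, minibatch):
        # Add this minibatch's loss and gradients to the running sums.
        loss_sum, grad_sum = carry
        loss, grads = grad_fn(params, minibatch)
        grad_sum = jax.tree_util.tree_map(jnp.add, grad_sum, grads)
        return (loss_sum + loss, grad_sum), None

    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    (loss_sum, grad_sum), _ = jax.lax.scan(
        accumulate, (jnp.zeros(()), zero_grads), minibatches
    )
    # Averaging over equal-sized minibatches reproduces the full-batch gradient.
    grads = jax.tree_util.tree_map(lambda g: g / gradient_accumulation_steps, grad_sum)
    lr = learning_rate_fn(step)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    metrics = {"loss": loss_sum / gradient_accumulation_steps, "learning_rate": lr}
    return new_params, step + 1, metrics
```

Because every minibatch has the same size, averaging the per-minibatch losses and gradients gives the same update as a single full-batch step, while only one minibatch's activations need to be live at a time. To compile it, `training_step` can be wrapped with `jax.jit(..., static_argnames=("learning_rate_fn", "gradient_accumulation_steps"))` so the reshape sees a static step count.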