easydel.trainers.training_utils

easydel.trainers.training_utils.make_assertions_and_get_sizes(batch: Dict, gradient_accumulation_steps: int, batch_partition_spec: Optional[PartitionSpec] = None) → Tuple[int, int, PartitionSpec]

Validates the input parameters and computes the batch size, minibatch size, and batch partition specification.

Parameters
  • batch (tp.Dict) – A dictionary containing the batch data. The batch size is inferred from the first element’s shape.

  • gradient_accumulation_steps (int) – The number of gradient accumulation steps. Must be greater than 0.

  • batch_partition_spec (tp.Optional[PartitionSpec], optional) – The partition specification for the batch. Defaults to None.

Returns

A tuple containing:
  • batch_size (int): The size of the batch.

  • minibatch_size (int): The size of the minibatch.

  • batch_partition_spec (PartitionSpec): The partition specification for the batch.

Return type

tp.Tuple[int, int, PartitionSpec]

Raises
  • ValueError – If gradient_accumulation_steps is not greater than 0.

  • ValueError – If the batch size is not divisible by the gradient accumulation steps.
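
The checks and return values described above can be sketched as follows. This is a minimal illustration, not the library's implementation: the name `sketch_get_sizes`, the rule `minibatch_size = batch_size // gradient_accumulation_steps`, and the fallback `PartitionSpec("data")` are assumptions made for the example.

```python
from typing import Dict, Optional, Tuple

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec


def sketch_get_sizes(
    batch: Dict[str, jax.Array],
    gradient_accumulation_steps: int,
    batch_partition_spec: Optional[PartitionSpec] = None,
) -> Tuple[int, int, PartitionSpec]:
    """Hypothetical stand-in for make_assertions_and_get_sizes."""
    if gradient_accumulation_steps <= 0:
        raise ValueError("gradient_accumulation_steps must be greater than 0.")

    # Infer the batch size from the leading dimension of the first entry.
    batch_size = next(iter(batch.values())).shape[0]

    if batch_size % gradient_accumulation_steps != 0:
        raise ValueError(
            "Batch size must be divisible by gradient_accumulation_steps."
        )

    # Assumption: each accumulation step sees an equal slice of the batch.
    minibatch_size = batch_size // gradient_accumulation_steps

    if batch_partition_spec is None:
        # Assumption: a placeholder default; the real default is not documented here.
        batch_partition_spec = PartitionSpec("data")

    return batch_size, minibatch_size, batch_partition_spec
```

With a batch of 8 rows and 4 accumulation steps, this sketch yields a minibatch size of 2, while 3 accumulation steps would raise `ValueError` because 8 is not divisible by 3.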

easydel.trainers.training_utils.minibatch_call(state: EasyDeLState, batch: Dict, minibatch_size: int, grad_fn: Callable[[Array, Dict], Tuple[Array, LossMetrics]]) → Tuple[Array, LossMetrics]

Processes the batch in smaller chunks (minibatches) for gradient accumulation using jax.lax.scan. Uses eval_shape to initialize the accumulator structures efficiently, without running the computation.
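
The accumulation pattern can be sketched roughly as below. This is an illustration under stated assumptions rather than the library's code: `sketch_minibatch_call` is a hypothetical name, a plain parameter pytree stands in for `EasyDeLState`, a scalar loss stands in for `LossMetrics`, and averaging over minibatches is an assumed reduction.

```python
import jax
import jax.numpy as jnp


def sketch_minibatch_call(params, batch, minibatch_size, grad_fn):
    """Accumulate grad_fn(params, minibatch) -> (grads, loss) over minibatches."""
    batch_size = jax.tree_util.tree_leaves(batch)[0].shape[0]
    num_minibatches = batch_size // minibatch_size

    # Reshape every leaf from (batch, ...) to (num_minibatches, minibatch, ...).
    stacked = jax.tree_util.tree_map(
        lambda x: x.reshape((num_minibatches, minibatch_size) + x.shape[1:]),
        batch,
    )

    # jax.eval_shape traces grad_fn abstractly, so the accumulator's shapes and
    # dtypes are known without computing any gradients.
    first_minibatch = jax.tree_util.tree_map(lambda x: x[0], stacked)
    out_shapes = jax.eval_shape(grad_fn, params, first_minibatch)
    init = jax.tree_util.tree_map(
        lambda s: jnp.zeros(s.shape, s.dtype), out_shapes
    )

    def body(carry, minibatch):
        out = grad_fn(params, minibatch)
        # Element-wise sum of (grads, loss) into the running accumulator.
        return jax.tree_util.tree_map(jnp.add, carry, out), None

    total, _ = jax.lax.scan(body, init, stacked)

    # Assumption: report the mean over minibatches.
    return jax.tree_util.tree_map(lambda x: x / num_minibatches, total)
```

For a loss that is a mean over examples and minibatches of equal size, the averaged result matches a single full-batch gradient computation up to floating-point error, which is what makes this a drop-in way to trade memory for sequential steps.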

easydel.trainers.training_utils.update_metrics(metrics: LossMetrics, learning_rate_fn: Callable, step: int | jax.Array, gradients: Optional[Array]) → LossMetrics

Updates the given metrics with the current learning rate and gradient norms.

Parameters
  • metrics (LossMetrics) – An instance of LossMetrics to be updated.

  • learning_rate_fn (tp.Callable) – A callable that returns the learning rate given the current step.

  • step (int | jax.Array) – The current training step.

  • gradients (Optional[jax.Array]) – The gradients to compute norms from.

Returns

The updated metrics with learning rate and gradient norms.

Return type

LossMetrics
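
As a rough illustration of the kind of bookkeeping described above, the sketch below records a learning rate and a global gradient norm. It is not the library's implementation: a plain dict stands in for `LossMetrics`, the keys `learning_rate` and `grad_norm` are hypothetical, and the global L2 norm is an assumed choice of norm.

```python
import jax
import jax.numpy as jnp


def sketch_update_metrics(metrics, learning_rate_fn, step, gradients=None):
    """Return a copy of `metrics` with learning rate and gradient norm added."""
    updated = dict(metrics)

    # The schedule is evaluated at the current step.
    updated["learning_rate"] = learning_rate_fn(step)

    if gradients is not None:
        # Global L2 norm across every leaf of the gradient pytree.
        leaves = jax.tree_util.tree_leaves(gradients)
        squared_sum = sum(jnp.sum(jnp.square(g)) for g in leaves)
        updated["grad_norm"] = jnp.sqrt(squared_sum)

    return updated
```

Passing `gradients=None` leaves the norm out entirely in this sketch, which mirrors the optional type of the `gradients` parameter.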

easydel.trainers.training_utils.update_state_respectfully(state: EasyDeLState, gradients: Array, loss_config: LossConfig, metrics: LossMetrics) → EasyDeLState

Updates the state of the model respectfully based on the provided gradients, loss configuration, and metrics.

Parameters
  • state (EasyDeLState) – The current state of the model.

  • gradients (jax.Array) – The gradients to be applied to the model’s parameters.

  • loss_config (LossConfig) – Configuration for the loss, including conditions for breaking on NaN values.

  • metrics (LossMetrics) – Metrics containing the loss value.

Returns

The updated state of the model.

Return type

EasyDeLState
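
How this function reacts to a NaN loss is not specified on this page. The sketch below swaps in one simple policy, skipping the update when the loss is not finite, purely to illustrate a loss-aware parameter update. Everything in it is an assumption: the name `sketch_guarded_update`, the plain SGD step in place of the state's optimizer, and the `skip_on_nan` flag in place of `LossConfig`.

```python
import jax
import jax.numpy as jnp


def sketch_guarded_update(params, gradients, loss, learning_rate=0.1, skip_on_nan=True):
    """Apply an SGD step, optionally keeping the old params on a non-finite loss."""
    # Hypothetical optimizer: plain SGD instead of the state's real optimizer.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, gradients
    )
    if not skip_on_nan:
        return new_params

    # jnp.where keeps the function traceable under jit: no Python branching
    # on the value of the loss.
    loss_is_finite = jnp.isfinite(loss)
    return jax.tree_util.tree_map(
        lambda new, old: jnp.where(loss_is_finite, new, old), new_params, params
    )
```

Selecting between old and new parameters with `jnp.where` avoids a data-dependent Python `if`, so the same function works inside a jitted training step.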