easydel.trainers.training_utils
- easydel.trainers.training_utils.make_assertions_and_get_sizes(batch: Dict, gradient_accumulation_steps: int, batch_partition_spec: Optional[PartitionSpec] = None) -> Tuple[int, int, PartitionSpec]
Validates the input parameters and computes the batch size, minibatch size, and batch partition specification.
- Parameters
batch (tp.Dict) – A dictionary containing the batch data. The batch size is inferred from the first element’s shape.
gradient_accumulation_steps (int) – The number of gradient accumulation steps. Must be greater than 0.
batch_partition_spec (tp.Optional[PartitionSpec], optional) – The partition specification for the batch. Defaults to None.
- Returns
A tuple containing:
batch_size (int): The size of the batch.
minibatch_size (int): The size of the minibatch.
batch_partition_spec (PartitionSpec): The partition specification for the batch.
- Return type
tp.Tuple[int, int, PartitionSpec]
- Raises
ValueError – If gradient_accumulation_steps is not greater than 0.
ValueError – If the batch size is not divisible by the gradient accumulation steps.
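The checks described above can be sketched in a few lines. This is a minimal illustration, not EasyDeL's implementation: the name `get_batch_sizes` is hypothetical, the batch size is assumed to be the leading dimension of the first entry, and the minibatch size is assumed to be the batch size divided by the number of accumulation steps. The partition spec is left out because its default value is not stated here.

```python
def get_batch_sizes(batch, gradient_accumulation_steps):
    """Hypothetical re-implementation of the documented checks (not EasyDeL's code)."""
    if gradient_accumulation_steps <= 0:
        raise ValueError("`gradient_accumulation_steps` must be greater than 0.")

    # Assumption: the batch size is the leading dimension of the first entry.
    first_array = next(iter(batch.values()))
    batch_size = first_array.shape[0]

    if batch_size % gradient_accumulation_steps != 0:
        raise ValueError(
            f"Batch size ({batch_size}) is not divisible by "
            f"gradient_accumulation_steps ({gradient_accumulation_steps})."
        )

    # Assumption: each accumulation step sees an equal slice of the batch.
    minibatch_size = batch_size // gradient_accumulation_steps
    return batch_size, minibatch_size
```

The function works with any array type that exposes `.shape`. For a batch whose arrays have shape `(8, 128)`, four accumulation steps give a batch size of 8 and a minibatch size of 2, while three accumulation steps raise `ValueError`.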
- easydel.trainers.training_utils.minibatch_call(state: EasyDeLState, batch: Dict, minibatch_size: int, grad_fn: Callable[[Array, Dict], Tuple[Array, LossMetrics]]) -> Tuple[Array, LossMetrics]
Processes the batch in chunks of minibatch_size for gradient accumulation, iterating over the chunks with jax.lax.scan. Uses jax.eval_shape to build the accumulator structures from the output shapes and dtypes of grad_fn, without running the actual computation.
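The general pattern of scan-based gradient accumulation with an accumulator built from `jax.eval_shape` can be sketched as follows. This is an illustrative sketch under stated assumptions, not the body of `minibatch_call`: `loss_fn`, `accumulate_gradients`, the plain `params` dictionary, and the averaging of results over minibatches are all hypothetical choices made for the example.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, minibatch):
    # Hypothetical mean-squared-error loss for a linear model.
    preds = minibatch["x"] @ params["w"]
    return jnp.mean((preds - minibatch["y"]) ** 2)


def accumulate_gradients(params, batch, minibatch_size):
    """Average loss and gradients over minibatches with jax.lax.scan."""
    batch_size = batch["x"].shape[0]
    num_minibatches = batch_size // minibatch_size

    # Reshape every leaf from (batch, ...) to (num_minibatches, minibatch, ...).
    stacked = jax.tree_util.tree_map(
        lambda a: a.reshape((num_minibatches, minibatch_size) + a.shape[1:]), batch
    )
    grad_fn = jax.value_and_grad(loss_fn)

    # eval_shape traces grad_fn abstractly, so the zero-filled accumulator is
    # built from shapes and dtypes without a real forward/backward pass.
    first = jax.tree_util.tree_map(lambda a: a[0], stacked)
    loss_shape, grads_shape = jax.eval_shape(grad_fn, params, first)
    init = (
        jnp.zeros(loss_shape.shape, loss_shape.dtype),
        jax.tree_util.tree_map(lambda s: jnp.zeros(s.shape, s.dtype), grads_shape),
    )

    def step(carry, minibatch):
        loss_acc, grads_acc = carry
        loss, grads = grad_fn(params, minibatch)
        grads_acc = jax.tree_util.tree_map(jnp.add, grads_acc, grads)
        return (loss_acc + loss, grads_acc), None

    (loss_sum, grads_sum), _ = jax.lax.scan(step, init, stacked)
    scale = 1.0 / num_minibatches
    return loss_sum * scale, jax.tree_util.tree_map(lambda g: g * scale, grads_sum)
```

Because the loss here is a mean over examples and all minibatches have equal size, the averaged result matches a single full-batch gradient up to floating-point error, which makes the sketch easy to check.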
- easydel.trainers.training_utils.update_metrics(metrics: LossMetrics, learning_rate_fn: Callable, step: int | jax.Array, gradients: Optional[Array]) -> LossMetrics
Updates the given metrics with the current learning rate and gradient norms.
- Parameters
metrics (LossMetrics) – An instance of LossMetrics to be updated.
learning_rate_fn (tp.Callable) – A callable that returns the learning rate given the current step.
step (int | jax.Array) – The current training step.
gradients (Optional[jax.Array]) – The gradients to compute norms from.
- Returns
The updated metrics with learning rate and gradient norms.
- Return type
LossMetrics
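As a rough illustration of the two quantities involved, the sketch below evaluates a learning-rate callable at the current step and computes a global L2 norm over a gradient pytree. It is written under assumptions: `warmup_schedule` and `summarize_step` are hypothetical names, a plain dictionary stands in for `LossMetrics`, and which norms the library actually records is not stated here.

```python
import jax
import jax.numpy as jnp


def warmup_schedule(step, peak_lr=1e-3, warmup_steps=100):
    # Hypothetical schedule: linear warmup to peak_lr, then constant.
    return peak_lr * jnp.minimum(1.0, step / warmup_steps)


def summarize_step(learning_rate_fn, step, gradients=None):
    """Collect the learning rate and, if gradients are given, their global L2 norm."""
    summary = {"learning_rate": learning_rate_fn(step)}
    if gradients is not None:
        # Global norm: square root of the summed squares of every leaf.
        leaves = jax.tree_util.tree_leaves(gradients)
        squared = sum(jnp.sum(jnp.square(g)) for g in leaves)
        summary["grad_norm"] = jnp.sqrt(squared)
    return summary
```

Treating `gradients` as optional mirrors the signature above; in this sketch the norm entry is simply omitted when no gradients are passed.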
- easydel.trainers.training_utils.update_state_respectfully(state: EasyDeLState, gradients: Array, loss_config: LossConfig, metrics: LossMetrics) -> EasyDeLState
Updates the state of the model respectfully based on the provided gradients, loss configuration, and metrics.
- Parameters
state (EasyDeLState) – The current state of the model.
gradients (jax.Array) – The gradients to be applied to the model’s parameters.
loss_config (LossConfig) – Configuration for the loss, including conditions for breaking on NaN values.
metrics (LossMetrics) – Metrics containing the loss value.
- Returns
The updated state of the model.
- Return type
EasyDeLState
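One way to make an update conditional on the loss being finite is `jax.lax.cond`, sketched below. This rests on an assumption about what "respectfully" means: that the update is skipped when the loss is NaN and the NaN check is enabled. The function `apply_update_unless_nan`, the `break_on_nan` flag, and the plain SGD step are hypothetical stand-ins for the library's state, `LossConfig`, and optimizer.

```python
import jax
import jax.numpy as jnp


def apply_update_unless_nan(params, grads, loss, learning_rate=0.1, break_on_nan=True):
    """Apply an SGD step, or return params unchanged when the loss is NaN."""

    def apply(p):
        # Hypothetical plain SGD update standing in for the real optimizer.
        return jax.tree_util.tree_map(lambda w, g: w - learning_rate * g, p, grads)

    def skip(p):
        return p

    # Assumption: a NaN loss only blocks the update when break_on_nan is set.
    should_skip = jnp.logical_and(break_on_nan, jnp.isnan(jnp.asarray(loss)))
    return jax.lax.cond(should_skip, skip, apply, params)
```

Using `jax.lax.cond` rather than a Python `if` keeps the branch valid inside a jitted training step, where the loss is a traced value.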