Source code for easydel.trainers.training_utils

import typing as tp

import jax
from jax import lax
from jax import numpy as jnp
from jax import tree_util as tu
from jax.sharding import PartitionSpec

from easydel.infra.base_state import EasyDeLState
from easydel.infra.loss_utils import LossConfig, LossMetrics
from easydel.utils.helpers import check_bool_flag

# Module-level boolean flags, resolved once at import time through
# `check_bool_flag` (presumably read from environment variables of the same
# name -- confirm in easydel.utils.helpers).
# NOTE(review): SCAN_TRAINER presumably toggles a scan-based trainer path;
# verify at its call sites.
SCAN_TRAINER = check_bool_flag("SCAN_TRAINER")
# When true, `update_state_respectfully` applies gradients unconditionally and
# skips its NaN-guarding `lax.cond` branches.
FAST_COMPILE = check_bool_flag("FAST_COMPILE")


def make_assertions_and_get_sizes(
    batch: tp.Dict,
    gradient_accumulation_steps: int,
    batch_partition_spec: tp.Optional[PartitionSpec] = None,
) -> tp.Tuple[int, int, PartitionSpec]:
    """
    Validates the input parameters and computes the batch size, minibatch size,
    and batch partition specification.

    Args:
        batch (tp.Dict): A dictionary containing the batch data. The batch size
            is inferred from the leading dimension of the first value.
        gradient_accumulation_steps (int): The number of gradient accumulation
            steps. Must be greater than 0.
        batch_partition_spec (tp.Optional[PartitionSpec], optional): The
            partition specification for the batch. When None, it defaults to
            `PartitionSpec(("dp", "fsdp"), "sp")`.

    Returns:
        tp.Tuple[int, int, PartitionSpec]: A tuple containing:
            - batch_size (int): The size of the batch.
            - minibatch_size (int): `batch_size // gradient_accumulation_steps`.
            - batch_partition_spec (PartitionSpec): The partition specification
              for the batch.

    Raises:
        ValueError: If `gradient_accumulation_steps` is not greater than 0.
        ValueError: If `batch` is empty.
        ValueError: If the batch size is not divisible by the gradient
            accumulation steps.
    """
    # Validate the divisor *before* dividing: a zero step count would otherwise
    # surface as ZeroDivisionError and a negative one would slip through with a
    # negative minibatch size.
    if gradient_accumulation_steps <= 0:
        raise ValueError("`gradient_accumulation_steps` must be greater than 0.")
    if not batch:
        raise ValueError("`batch` must contain at least one array.")

    # The leading dimension of the first entry defines the batch size.
    batch_size = next(iter(batch.values())).shape[0]
    minibatch_size, remainder = divmod(batch_size, gradient_accumulation_steps)
    if remainder != 0:
        raise ValueError("Batch size must be divisible by gradient accumulation steps.")

    if batch_partition_spec is None:
        batch_partition_spec = PartitionSpec(("dp", "fsdp"), "sp")
    return batch_size, minibatch_size, batch_partition_spec
def update_metrics(
    metrics: LossMetrics,
    learning_rate_fn: tp.Callable,
    step: int | jax.Array,
    gradients: tp.Optional[jax.Array],
) -> LossMetrics:
    """
    Records the current learning rate and gradient-norm statistics on `metrics`.

    Args:
        metrics (LossMetrics): The metrics object to update in place.
        learning_rate_fn (tp.Callable): Called with `step` to obtain the
            learning rate. Skipped when None.
        step (int | jax.Array): The current training step.
        gradients (Optional(jax.Array)): Gradient pytree to compute norms from.
            Skipped when None.

    Returns:
        LossMetrics: The same `metrics` object, with `learning_rate` set when a
        schedule is given and `max_grad_norm`, `mean_grad_norm` and
        `grad_norms` set when gradients are given.
    """
    if learning_rate_fn is not None:
        metrics.learning_rate = learning_rate_fn(step)

    if gradients is None:
        return metrics

    # One norm per gradient leaf; the pytree structure of `gradients` is kept.
    per_leaf_norms = tu.tree_map(jnp.linalg.norm, gradients)

    # Mean over leaves: total of the per-leaf norms divided by their element count.
    norm_count = tu.tree_reduce(jnp.add, tu.tree_map(jnp.size, per_leaf_norms))
    norm_total = tu.tree_reduce(jnp.add, tu.tree_map(jnp.sum, per_leaf_norms))

    metrics.max_grad_norm = tu.tree_reduce(jnp.maximum, per_leaf_norms)
    metrics.mean_grad_norm = norm_total / norm_count
    metrics.grad_norms = per_leaf_norms
    return metrics
def update_state_respectfully(
    state: EasyDeLState,
    gradients: jax.Array,
    loss_config: LossConfig,
    metrics: LossMetrics,
) -> EasyDeLState:
    """
    Applies `gradients` to `state`, optionally skipping the update on a NaN loss.

    When `FAST_COMPILE` is set the gradients are applied unconditionally.
    Otherwise the update is skipped (the incoming state is returned) only when
    `loss_config` is given, `loss_config.break_on_nan` is true and
    `metrics.loss` is NaN.

    Args:
        state (EasyDeLState): The current state of the model.
        gradients (jax.Array): The gradients passed to `state.apply_gradients`.
        loss_config (LossConfig): Loss configuration; only `break_on_nan` is
            read. May be None, in which case the update is always applied.
        metrics (LossMetrics): Metrics whose `loss` is checked for NaN.

    Returns:
        EasyDeLState: The updated state, or the original state when skipped.
    """
    if FAST_COMPILE:
        return state.apply_gradients(grads=gradients)

    def _apply_update(operands):
        current_state, grads = operands
        return current_state.apply_gradients(grads=grads)

    def _keep_state(operands):
        current_state, _ = operands
        return current_state

    def _update_unless_nan(_):
        # False (skip) when the loss is NaN, True (update) otherwise.
        return lax.cond(
            jnp.isnan(metrics.loss),
            lambda _: False,
            lambda _: True,
            None,
        )

    def _always_update(_):
        return True

    if loss_config is None:
        should_update = True
    else:
        should_update = lax.cond(
            loss_config.break_on_nan,
            _update_unless_nan,
            _always_update,
            None,
        )

    return lax.cond(should_update, _apply_update, _keep_state, (state, gradients))
def minibatch_call(
    state: EasyDeLState,
    batch: tp.Dict,
    minibatch_size: int,
    grad_fn: tp.Callable[
        [jax.Array, tp.Dict],
        tp.Tuple[tp.Tuple[jax.Array, LossMetrics], jax.Array],
    ],
) -> tp.Tuple[jax.Array, LossMetrics]:
    """
    Processes batch in smaller chunks for gradient accumulation using
    jax.lax.scan. Uses eval_shape to initialize accumulator structures without
    running `grad_fn`.

    Args:
        state (EasyDeLState): Training state; only `state.graphstate` is read
            and passed as the first argument of `grad_fn`.
        batch (tp.Dict): Mapping of arrays sharing the same leading (batch)
            dimension.
        minibatch_size (int): Number of examples per accumulation step. The
            number of steps is `batch_size // minibatch_size`.
        grad_fn (tp.Callable): Called as `grad_fn(graphstate, minibatch)` and
            expected to return `((loss, metrics), gradients)`.

    Returns:
        tp.Tuple[jax.Array, LossMetrics]: `(gradients, metrics)`. With more
        than one accumulation step, both are the per-step results summed and
        divided by the number of steps; otherwise they are the direct outputs
        of a single `grad_fn` call on the whole batch.
    """
    num_accum_steps = len(next(iter(batch.values()))) // minibatch_size

    if num_accum_steps <= 1:
        # Nothing to accumulate: evaluate the whole batch in one call.
        (_, metrics), gradients = grad_fn(state.graphstate, batch)
        return gradients, metrics

    def reshape_to_minibatches(arr):
        """Reshape the batch into minibatches for accumulation."""
        batch_shape = (num_accum_steps, minibatch_size) + arr.shape[1:]
        return jnp.reshape(arr, batch_shape)

    batch = jax.tree_util.tree_map(reshape_to_minibatches, batch)

    # Trace grad_fn abstractly on the first minibatch to learn the output
    # shapes/dtypes needed for zero-initialized accumulators.
    (_, metrics_shape), grads_shape = jax.eval_shape(
        lambda: grad_fn(
            state.graphstate,
            jax.tree_util.tree_map(lambda x: x[0], batch),
        )
    )

    def zeros_for(spec):
        return jnp.zeros(spec.shape, spec.dtype)

    init_acc = {
        "grads": jax.tree_util.tree_map(zeros_for, grads_shape),
        "metrics": jax.tree_util.tree_map(zeros_for, metrics_shape),
    }

    def accumulate_gradients(acc, minibatch):
        """Accumulate gradients and metrics for each minibatch."""
        (_, step_aux), step_grads = grad_fn(state.graphstate, minibatch)
        new_acc = {
            "grads": jax.tree_util.tree_map(jnp.add, acc["grads"], step_grads),
            "metrics": jax.tree_util.tree_map(jnp.add, acc["metrics"], step_aux),
        }
        # Only the running sums are needed; emitting no per-step output keeps
        # scan from stacking every step's metrics into an unused array.
        return new_acc, None

    final_acc, _ = jax.lax.scan(
        accumulate_gradients,
        init_acc,
        batch,
        length=num_accum_steps,
    )

    # Turn the sums into means over the accumulation steps.
    gradients = jax.tree_util.tree_map(
        lambda x: x / num_accum_steps, final_acc["grads"]
    )
    metrics = jax.tree_util.tree_map(
        lambda x: x / num_accum_steps, final_acc["metrics"]
    )
    return gradients, metrics