easydel.trainers.distillation_trainer._fn
Internal functions for knowledge distillation training.
This module contains the core computational functions used by the distillation trainer, including loss functions and training/evaluation step implementations. These functions implement knowledge distillation as described by Hinton et al., where a student model learns to mimic a teacher model’s output distributions.
The distillation process uses temperature scaling to soften probability distributions, allowing the student to learn from the teacher's confidence across all classes rather than just the hard labels. The loss combines the KL divergence between the teacher and student distributions with an optional supervised (cross-entropy) loss on the hard labels.
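The effect of temperature can be seen in a few lines of plain JAX. This sketch is independent of EasyDeL, and the logits are made-up values chosen only for illustration: dividing the logits by T before the softmax raises the entropy of the distribution (mass spreads over more tokens) while leaving the ranking of tokens unchanged.

```python
import jax
import jax.numpy as jnp


def softened_probs(logits: jax.Array, temperature: float) -> jax.Array:
    """Temperature-scaled softmax: softmax(logits / T)."""
    return jax.nn.softmax(logits / temperature, axis=-1)


def entropy(probs: jax.Array) -> jax.Array:
    """Shannon entropy (in nats) over the last axis."""
    return -jnp.sum(probs * jnp.log(probs), axis=-1)


# Illustrative teacher logits over a 4-token vocabulary.
teacher_logits = jnp.array([5.0, 2.0, 1.0, -1.0])

p_sharp = softened_probs(teacher_logits, temperature=1.0)  # nearly one-hot
p_soft = softened_probs(teacher_logits, temperature=4.0)   # same ordering, flatter

print(p_sharp, entropy(p_sharp))
print(p_soft, entropy(p_soft))
```

At T = 4 the non-argmax tokens receive noticeably more probability, which is the "dark knowledge" the student is meant to pick up.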
All functions are designed for JAX/Flax models and support distributed training.
- easydel.trainers.distillation_trainer._fn.distillation_loss(student_logits: Union[Array, ndarray, bool, number], teacher_logits: Union[Array, ndarray, bool, number], attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, labels: Optional[Union[Array, ndarray, bool, number]] = None, use_hard_labels: bool = False, temperature: float = 4.0, alpha: float = 0.9) -> tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]]
Compute knowledge distillation loss between student and teacher models.
This function implements the distillation loss as described in Hinton et al.’s “Distilling the Knowledge in a Neural Network”. It combines KL divergence loss between temperature-scaled teacher and student distributions with optional supervised learning loss on hard labels.
- Parameters
student_logits (chex.Array) – Raw logits from the student model. Shape: [batch_size, sequence_length, vocab_size]
teacher_logits (chex.Array) – Raw logits from the teacher model. Shape: [batch_size, sequence_length, vocab_size]
attention_mask (chex.Array | None) – Mask indicating valid tokens. 1 for valid tokens, 0 for padding. Shape: [batch_size, sequence_length]
labels (chex.Array | None) – Ground truth labels for supervised loss. Shape: [batch_size, sequence_length]
use_hard_labels (bool) – Whether to include supervised loss with hard labels. If True, combines distillation loss with cross-entropy loss.
temperature (float) – Temperature for softening probability distributions. Higher values create softer distributions. Default: 4.0
alpha (float) – Weight of the distillation loss relative to the supervised loss. 1.0 means pure distillation, 0.0 means pure supervised learning. Default: 0.9
- Returns
A tuple of the scalar loss (the distillation loss combined with the optional supervised loss) and a dictionary containing the individual loss components.
- Return type
tuple[chex.Array, dict[str, chex.Array]]
Note
When attention_mask is provided, padding tokens are masked out of the loss. Temperature scaling lets the student learn from the teacher's relative confidence across all classes, not only from its top prediction.
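The computation described above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not EasyDeL's implementation: the name `kd_loss_sketch` and its metric keys are hypothetical; it follows Hinton et al. in multiplying the KL term by T²; it takes a masked mean over valid tokens; it assumes `labels` are already aligned with the logits (no shift, no ignore index); and it returns the pure distillation loss when hard labels are not used. The actual `distillation_loss` may differ on any of these points.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def kd_loss_sketch(
    student_logits: jax.Array,                   # [batch, seq, vocab]
    teacher_logits: jax.Array,                   # [batch, seq, vocab]
    attention_mask: Optional[jax.Array] = None,  # [batch, seq], 1 = valid token
    labels: Optional[jax.Array] = None,          # [batch, seq] integer token ids
    use_hard_labels: bool = False,
    temperature: float = 4.0,
    alpha: float = 0.9,
) -> tuple[jax.Array, dict[str, jax.Array]]:
    # Hypothetical sketch: temperature-scaled log-probabilities for both models.
    s_logp = jax.nn.log_softmax(student_logits / temperature, axis=-1)
    t_logp = jax.nn.log_softmax(teacher_logits / temperature, axis=-1)

    # Per-token KL(teacher || student), summed over the vocabulary.
    kl = jnp.sum(jnp.exp(t_logp) * (t_logp - s_logp), axis=-1)  # [batch, seq]

    if attention_mask is None:
        mask = jnp.ones(kl.shape, dtype=kl.dtype)
    else:
        mask = attention_mask.astype(kl.dtype)
    denom = jnp.maximum(mask.sum(), 1.0)

    # Masked mean; the T^2 factor (assumed, per Hinton et al.) keeps gradient
    # magnitudes comparable across temperatures.
    kd = jnp.sum(kl * mask) / denom * temperature**2
    metrics = {"distillation_loss": kd}

    if use_hard_labels and labels is not None:
        # Ordinary (T = 1) cross-entropy of the student on the hard labels.
        logp = jax.nn.log_softmax(student_logits, axis=-1)
        nll = -jnp.take_along_axis(logp, labels[..., None], axis=-1)[..., 0]
        ce = jnp.sum(nll * mask) / denom
        metrics["supervised_loss"] = ce
        loss = alpha * kd + (1.0 - alpha) * ce
    else:
        loss = kd
    return loss, metrics


# Usage with random logits: batch 2, sequence 3, vocabulary 5.
key_s, key_t = jax.random.split(jax.random.PRNGKey(0))
student = jax.random.normal(key_s, (2, 3, 5))
teacher = jax.random.normal(key_t, (2, 3, 5))
mask = jnp.array([[1, 1, 0], [1, 1, 1]])    # last token of row 0 is padding
labels = jnp.array([[0, 1, 2], [3, 4, 0]])
loss, metrics = kd_loss_sketch(student, teacher, mask, labels, use_hard_labels=True)
```

Because the padded position is multiplied by zero before averaging, changing the student's logits there leaves the loss unchanged, which is the masking behaviour the note describes.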
- easydel.trainers.distillation_trainer._fn.distillation_step(student_state: EasyDeLState, batch: Mapping[str, Array], teacher_state: EasyDeLState, loss_config: easydel.infra.loss_utils.LossConfig | None = None, learning_rate_fn: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]] = None, partition_spec: jax.sharding.PartitionSpec | None = None, gradient_accumulation_steps: int = 1, is_training: bool = True, temperature: float = 4.0, alpha: float = 0.9, hidden_state_weight: float = 0.0, hidden_state_layers: tuple[int, ...] | None = None, hidden_state_loss: Literal['mse'] = 'mse', attention_weight: float = 0.0, attention_layers: tuple[int, ...] | None = None, attention_normalize: bool = False) -> tuple[easydel.infra.base_state.EasyDeLState, easydel.infra.loss_utils.LossMetrics]