easydel.trainers.distillation_trainer.distillation_trainer

class easydel.trainers.distillation_trainer.distillation_trainer.DistillationTrainer(arguments: DistillationConfig, processing_class: ProcessingClassType, student_model: EasyDeLBaseModule | EasyDeLState | None = None, teacher_model: EasyDeLBaseModule | EasyDeLState | None = None, train_dataset: Dataset | None = None, eval_dataset: Dataset | dict[str, Dataset] | None = None, data_collator: DataCollatorForCompletionOnlyLM | None = None)

Bases: Trainer

Knowledge distillation trainer for model compression.

Transfers knowledge from a larger teacher model to a smaller student model. The training objective combines a distillation loss (the KL divergence between the teacher's and the student's temperature-softened output distributions) with the standard supervised loss. The alpha parameter sets the balance between the two.

Key features:
  • Temperature-scaled softmax for softer probability distributions.

  • Configurable balance between the distillation and supervised losses.

  • Support for both language and multimodal models.

  • Efficient JAX-based implementation with JIT compilation.
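To illustrate what temperature scaling does, here is a minimal NumPy sketch (not EasyDeL code; `softmax_with_temperature` and `entropy` are hypothetical helpers). It divides a teacher's logits by T before the softmax and measures how much flatter the resulting distribution becomes.

```python
import numpy as np


def softmax_with_temperature(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Softmax over the last axis of logits / temperature (numerically stable)."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # subtract max to avoid overflow
    exp = np.exp(scaled)
    return exp / exp.sum(axis=-1, keepdims=True)


def entropy(probs: np.ndarray) -> float:
    """Shannon entropy (in nats) of a 1-D distribution; higher means softer."""
    probs = np.asarray(probs, dtype=np.float64)
    return float(-(probs * np.log(probs)).sum(axis=-1))


teacher_logits = np.array([5.0, 2.0, 0.5])
for t in (1.0, 3.0):
    probs = softmax_with_temperature(teacher_logits, t)
    # A higher T moves probability mass onto the non-argmax classes (higher entropy),
    # while the ranking of the classes stays the same.
    print(f"T={t}: probs={np.round(probs, 3)}, entropy={entropy(probs):.3f}")
```

The softened distribution carries information about how the teacher ranks the wrong classes, which is what the student learns from in the distillation term.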

The distillation loss is computed as:

Loss = α * KL(student/T, teacher/T) * T² + (1-α) * CE(student, labels)

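where T is the temperature parameter, student/T and teacher/T denote the softmax distributions obtained from the respective logits divided by T, and CE(student, labels) is the usual cross-entropy against the ground-truth labels. The T² factor keeps the gradient contribution of the softened term from shrinking as T grows, so changing the temperature does not silently reweight the two terms.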
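The formula above can be sketched in NumPy as follows. This is an illustrative re-implementation, not the trainer's JAX code. The helper names are hypothetical, the loss is averaged over the batch, and the KL term assumes the common Hinton et al. direction KL(p_teacher || p_student), which the notation above does not pin down.

```python
import numpy as np


def log_softmax(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Numerically stable log-softmax over the last axis of logits / temperature."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)
    return scaled - np.log(np.exp(scaled).sum(axis=-1, keepdims=True))


def distillation_loss(
    student_logits: np.ndarray,  # shape (batch, num_classes)
    teacher_logits: np.ndarray,  # shape (batch, num_classes)
    labels: np.ndarray,          # shape (batch,), integer class ids
    temperature: float = 2.0,
    alpha: float = 0.5,
) -> float:
    """alpha * T^2 * KL(p_teacher || p_student) + (1 - alpha) * CE(student, labels).

    Hypothetical sketch: assumes the Hinton-style KL direction and batch averaging.
    """
    # Soft targets: both distributions come from logits divided by T.
    log_p_s = log_softmax(student_logits, temperature)
    log_p_t = log_softmax(teacher_logits, temperature)
    p_t = np.exp(log_p_t)
    kl = (p_t * (log_p_t - log_p_s)).sum(axis=-1).mean()

    # Hard targets: ordinary cross-entropy at T = 1 against the ground-truth labels.
    log_p_s_hard = log_softmax(student_logits, 1.0)
    labels = np.asarray(labels)
    ce = -log_p_s_hard[np.arange(labels.shape[0]), labels].mean()

    return float(alpha * kl * temperature**2 + (1.0 - alpha) * ce)
```

When the student and teacher logits are identical the KL term vanishes, so alpha=1.0 gives zero loss; with alpha=0.0 only the supervised cross-entropy remains.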

teacher_state

State of the teacher model (frozen during training)

Type

easydel.infra.base_state.EasyDeLState

arguments

DistillationConfig with training hyperparameters

Type

easydel.trainers.distillation_trainer.distillation_config.DistillationConfig

Example

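>>> from easydel.trainers.distillation_trainer.distillation_config import DistillationConfig
>>> from easydel.trainers.distillation_trainer.distillation_trainer import DistillationTrainer
>>> # `student`, `teacher`, `dataset` and `tokenizer` are assumed to be loaded already.
>>> config = DistillationConfig(
...     temperature=3.0,  # softens both output distributions
...     alpha=0.7,        # 70% distillation loss, 30% supervised loss
...     learning_rate=2e-5,
... )
>>> trainer = DistillationTrainer(
...     arguments=config,
...     student_model=student,
...     teacher_model=teacher,
...     train_dataset=dataset,
...     processing_class=tokenizer,
... )
>>> trainer.train()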

arguments: DistillationConfig
configure_functions() → TrainerConfigureFunctionOutput

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

teacher_state: EasyDeLState