easydel.trainers.distillation_trainer.distillation_trainer
- class easydel.trainers.distillation_trainer.distillation_trainer.DistillationTrainer(arguments: DistillationConfig, processing_class: ProcessingClassType, student_model: EasyDeLBaseModule | EasyDeLState | None = None, teacher_model: EasyDeLBaseModule | EasyDeLState | None = None, train_dataset: Dataset | None = None, eval_dataset: Dataset | dict[str, Dataset] | None = None, data_collator: DataCollatorForCompletionOnlyLM | None = None)
Bases: Trainer
Knowledge distillation trainer for model compression.
Implements knowledge distillation to transfer knowledge from a larger teacher model to a smaller student model. The training combines distillation loss (KL divergence between teacher and student outputs) with standard supervised loss, controlled by the alpha parameter.
Key features:
- Temperature-scaled softmax for softer probability distributions
- Configurable balance between distillation and supervised loss
- Support for both language and multimodal models
- Efficient JAX-based implementation with JIT compilation
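- The combined training loss is computed as:
Loss = α * KL(student/T, teacher/T) * T² + (1-α) * CE(student, labels)
where T is the temperature parameter and α is the alpha parameter that weights the distillation term against the supervised cross-entropy term.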
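The formula above can be illustrated with a minimal JAX sketch. This is not EasyDeL's implementation: the function name `distillation_loss`, its argument names, and the mean reduction over examples are hypothetical choices made for illustration. It also assumes the conventional KL direction, KL(teacher || student), computed on temperature-softened distributions.

```python
import jax
import jax.numpy as jnp


def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=3.0, alpha=0.7):
    """Illustrative combined loss: alpha * KL * T^2 + (1 - alpha) * CE.

    Hypothetical helper, not part of EasyDeL.
    student_logits, teacher_logits: float arrays of shape (batch, vocab)
    labels: int array of shape (batch,)
    """
    # Temperature-softened log-probabilities for both models.
    s_logp = jax.nn.log_softmax(student_logits / temperature, axis=-1)
    t_logp = jax.nn.log_softmax(teacher_logits / temperature, axis=-1)
    t_prob = jnp.exp(t_logp)

    # Assumed direction: KL(teacher || student), averaged over the batch.
    kl = jnp.sum(t_prob * (t_logp - s_logp), axis=-1).mean()

    # Standard cross-entropy of the student against the hard labels (T = 1).
    logp = jax.nn.log_softmax(student_logits, axis=-1)
    picked = jnp.take_along_axis(logp, labels[..., None], axis=-1)
    ce = -picked.squeeze(-1).mean()

    # The T^2 factor compensates for the 1/T^2 scaling of soft-target gradients.
    return alpha * kl * temperature**2 + (1.0 - alpha) * ce
```

With alpha=1.0 only the distillation term remains, and it vanishes when student and teacher logits are identical; with alpha=0.0 the loss reduces to plain cross-entropy.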
- teacher_state
State of the teacher model (frozen during training)
- arguments
DistillationConfig with training hyperparameters
Example
>>> config = DistillationConfig(
...     temperature=3.0,
...     alpha=0.7,
...     learning_rate=2e-5
... )
>>> trainer = DistillationTrainer(
...     arguments=config,
...     student_model=student,
...     teacher_model=teacher,
...     train_dataset=dataset,
...     processing_class=tokenizer
... )
>>> trainer.train()
- arguments: DistillationConfig
- configure_functions() → TrainerConfigureFunctionOutput
Configures and JIT-compiles the training and evaluation step functions.
- This method sets up the necessary functions for training and evaluation, including:
  - Initialization of the model state.
  - Sharding of the model parameters and optimizer state.
  - JIT-compilation of the training and evaluation step functions.
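The three steps listed above follow a common JAX pattern, sketched below on a toy linear model. This is a generic illustration under stated assumptions, not the body of `configure_functions`: `init_state`, `loss_fn`, `train_step`, `eval_step`, and the fully replicated sharding are all hypothetical stand-ins, and a plain SGD update replaces the real optimizer.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def init_state(key, dim):
    # Step 1: initialize a (toy) model state as a pytree.
    return {"w": jax.random.normal(key, (dim,)) * 0.01,
            "step": jnp.zeros((), jnp.int32)}


def loss_fn(w, x, y):
    # Mean squared error of a linear model, standing in for the real loss.
    return jnp.mean((x @ w - y) ** 2)


def train_step(state, x, y, lr=0.1):
    # One SGD update; returns the new state and the loss before the update.
    loss, grad = jax.value_and_grad(loss_fn)(state["w"], x, y)
    new_state = {"w": state["w"] - lr * grad, "step": state["step"] + 1}
    return new_state, loss


def eval_step(state, x, y):
    return loss_fn(state["w"], x, y)


# Step 2: place the state on a device mesh. An empty PartitionSpec
# replicates every array on all devices (assumed layout for this sketch).
mesh = Mesh(np.array(jax.devices()), ("dp",))
replicated = NamedSharding(mesh, PartitionSpec())
state = jax.device_put(init_state(jax.random.PRNGKey(0), 3), replicated)

# Step 3: JIT-compile the step functions once and reuse them every step.
jitted_train_step = jax.jit(train_step)
jitted_eval_step = jax.jit(eval_step)
```

Calling `jitted_train_step(state, x, y)` in a loop then reuses the compiled function as long as the input shapes and dtypes stay the same.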
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- teacher_state: EasyDeLState