# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as tp
from eformer.loggings import get_logger
from jax.sharding import NamedSharding, PartitionSpec
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.base_state import EasyDeLState
from easydel.infra.utils import ProcessingClassType
from easydel.utils import Registry
from easydel.utils.compiling_utils import ejit
from ..trainer import Trainer
from ..trainer_protocol import TrainerConfigureFunctionOutput
from ..utils import DataCollatorForCompletionOnlyLM
from ._fn import distillation_step
from .distillation_config import DistillationConfig
if tp.TYPE_CHECKING:
from datasets import Dataset
logger = get_logger(__name__)
@Registry.register("trainer", "distillation")
class DistillationTrainer(Trainer):
"""Knowledge distillation trainer for model compression.
Implements knowledge distillation to transfer knowledge from a larger
teacher model to a smaller student model. The training combines distillation
loss (KL divergence between teacher and student outputs) with standard
supervised loss, controlled by the alpha parameter.
Key features:
- Temperature-scaled softmax for softer probability distributions
- Configurable balance between distillation and supervised loss
- Support for both language and multimodal models
- Efficient JAX-based implementation with JIT compilation
The distillation loss is computed as:
Loss = α * KL(student/T, teacher/T) * T² + (1-α) * CE(student, labels)
where T is the temperature parameter.
Attributes:
teacher_state: State of the teacher model (frozen during training)
arguments: DistillationConfig with training hyperparameters
Example:
>>> config = DistillationConfig(
... temperature=3.0,
... alpha=0.7,
... learning_rate=2e-5
... )
>>> trainer = DistillationTrainer(
... arguments=config,
... student_model=student,
... teacher_model=teacher,
... train_dataset=dataset,
... processing_class=tokenizer
... )
>>> trainer.train()
""" # noqa
teacher_state: EasyDeLState
arguments: DistillationConfig # type hinting
def __init__(
    self,
    arguments: DistillationConfig,
    processing_class: ProcessingClassType,
    student_model: EasyDeLBaseModule | EasyDeLState | None = None,
    teacher_model: EasyDeLBaseModule | EasyDeLState | None = None,
    train_dataset: Dataset | None = None,
    eval_dataset: Dataset | dict[str, Dataset] | None = None,
    data_collator: DataCollatorForCompletionOnlyLM | None = None,
):
    """Set up the trainer with a student state and a teacher state.

    Args:
        arguments: Distillation configuration; must be a
            ``DistillationConfig`` instance.
        processing_class: Tokenizer, or a processor exposing a
            ``tokenizer`` attribute. If the resolved tokenizer has no
            ``pad_token`` but does have an ``eos_token`` attribute, its
            ``pad_token`` is set to ``eos_token`` in place.
        student_model: Model or state to train. Anything that is not
            already an ``EasyDeLState`` is converted with ``to_state()``.
        teacher_model: Model or state stored as ``self.teacher_state``.
            Converted with ``to_state()`` when it is not already an
            ``EasyDeLState``.
        train_dataset: Forwarded to the base trainer as ``dataset_train``.
        eval_dataset: Forwarded to the base trainer as ``dataset_eval``.
        data_collator: Forwarded to the base trainer unchanged.

    Raises:
        TypeError: If ``arguments`` is not a ``DistillationConfig``.
        ValueError: If ``student_model`` or ``teacher_model`` is ``None``.
    """
    # Validate everything before touching caller-owned objects so a failed
    # construction leaves no side effects behind. An explicit raise is used
    # instead of `assert`, which is stripped when Python runs with -O.
    if not isinstance(arguments, DistillationConfig):
        raise TypeError("passed argument must be a `DistillationConfig`.")
    # Both parameters default to None for signature compatibility, but a
    # model is required: fail clearly instead of with an AttributeError on
    # `None.to_state()`.
    if student_model is None:
        raise ValueError("`student_model` must be provided to `DistillationTrainer`, got `None`.")
    if teacher_model is None:
        raise ValueError("`teacher_model` must be provided to `DistillationTrainer`, got `None`.")

    # Processors wrap the real tokenizer in a `tokenizer` attribute.
    tokenizer = getattr(processing_class, "tokenizer", processing_class)
    if getattr(tokenizer, "pad_token", None) is None and hasattr(tokenizer, "eos_token"):
        # Fall back to the EOS token when no padding token is defined.
        tokenizer.pad_token = tokenizer.eos_token

    self.arguments = arguments

    if not isinstance(student_model, EasyDeLState):
        student_model = student_model.to_state()
    if not isinstance(teacher_model, EasyDeLState):
        teacher_model = teacher_model.to_state()
    # Assigned before the base-class constructor runs, matching the
    # original ordering.
    self.teacher_state = teacher_model

    super().__init__(
        arguments=arguments,
        dataset_train=train_dataset,
        dataset_eval=eval_dataset,
        model_state=student_model,
        data_collator=data_collator,
        processing_class=processing_class,
    )
@property
def _train_shared_fn_extra_args(self) -> tuple[tp.Any]:
    """Return a one-element tuple holding ``self.teacher_state`` for the training function."""
    teacher = self.teacher_state
    return (teacher,)
@property
def _eval_shared_fn_extra_args(self) -> tuple[tp.Any]:
    """Return a one-element tuple holding ``self.teacher_state`` for the evaluation function."""
    teacher = self.teacher_state
    return (teacher,)