easydel.trainers.distillation_trainer.distillation_config#

class easydel.trainers.distillation_trainer.distillation_config.DistillationConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: str | None = None, clip_grad: float | None = None, custom_scheduler: tp.Callable[[int], tp.Any] | None = None, dataloader_num_workers: int | None = 0, dataloader_pin_memory: bool | None = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: int | None = None, evaluation_steps: int | None = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: str | None = None, grain_shard_index: int | None = None, grain_shard_count: int | None = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: list[str] | None = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: dict | None = None, learning_rate: float = 5e-05, learning_rate_end: float | None = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: LossConfig | None = None, low_mem_usage: bool = True, max_evaluation_steps: int | None = None, max_sequence_length: int | None = 4096, max_training_steps: int | None = None, per_epoch_training_steps: int | None = None, per_epoch_evaluation_steps: int | None = None, model_name: str | None = None, model_parameters: dict | None = None, metrics_to_show_in_rich_pbar: list[str] | None = None, generation_top_p: float | None = None, generation_top_k: int | None = None, generation_temperature: float | None = None, generation_do_sample: bool | None = None, generation_num_return_sequences: int | None = None, generation_max_new_tokens: int | None = None, generation_shard_inputs: bool = True, generation_interval: int | None = None, generation_prompts: list[str | dict[str, tp.Any]] = <factory>, generation_use_train_prompts: bool = False, generation_num_prompts: int = 1, generation_dataset_prompt_field: str | None = 'prompt', generation_extra_kwargs: dict[str, 
tp.Any] | None = None, generation_config_overrides: dict[str, tp.Any] | None = None, generation_seed: int | None = None, generation_preview_print: bool = False, generation_log_to_wandb: bool = True, use_esurge_generation: bool = True, esurge_use_tqdm: bool = True, esurge_hbm_utilization: float | None = 0.45, esurge_max_num_seqs: int | None = None, esurge_min_input_pad: int | None = None, esurge_page_size: int | None = 32, esurge_silent_mode: bool = True, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_interval_minutes: float | None = None, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: int | None = None, save_total_limit: int | None = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, shuffle_seed_train: int = 64871, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: dict | None = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: int | None = None, resume_if_possible: bool = True, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: str | None = None, train_on_inputs: bool = True, trainer_prefix: str | None = 'distillationtrainer', truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: jnp.dtype | None = None, track_memory: bool | float = False, use_data_collactor: bool = True, use_grain: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: str | None = None, wandb_name: str | None = None, warmup_steps: int = 0, 
weight_decay: float = 0.01, weight_distribution_pattern: str = '.*', weight_distribution_log_steps: int = 50, _can_log_metrics: bool | None = None, _im_a_hidden_checkpoint_manager: Checkpointer | None = None, temperature: float = 2.0, alpha: float = 0.9, hidden_state_loss_weight: float = 0.0, hidden_state_layers: tuple[int, ...] | None = None, hidden_state_loss: tp.Literal['mse'] = 'mse', attention_loss_weight: float = 0.0, attention_layers: tuple[int, ...] | None = None, attention_normalize: bool = False)[source]#

Bases: TrainingArguments

Configuration class for knowledge distillation training.

This configuration extends TrainingArguments with parameters specific to knowledge distillation, where a smaller student model learns to mimic a larger teacher model’s behavior.

Knowledge distillation uses temperature scaling to soften the probability distributions from both models, allowing the student to learn from the teacher’s confidence across all classes rather than just hard labels.

trainer_prefix#

Prefix for trainer logs and checkpoints. Default: “distillationtrainer”

Type

str | None

temperature#

Temperature parameter for softening probability distributions. Higher values create softer distributions, revealing more information about the teacher’s relative confidence across classes. Typical values range from 2.0 to 10.0. Default: 2.0

Type

float

alpha#

Weight balancing distillation loss vs. supervised loss:

- alpha=1.0: Pure distillation (only learn from teacher)
- alpha=0.0: Pure supervised learning (only learn from labels)
- 0<alpha<1: Combination of both losses

Default: 0.9 (90% distillation, 10% supervised)

Type

float

Example

>>> config = DistillationConfig(
...     temperature=5.0,
...     alpha=0.7,
...     learning_rate=1e-4,
...     num_train_epochs=10
... )

Note

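The distillation loss is computed as:

    Loss = alpha * KL(student/T, teacher/T) + (1 - alpha) * CE(student, labels)

where T is the temperature parameter, KL is the Kullback–Leibler divergence between the temperature-scaled student and teacher distributions, and CE is the cross-entropy between the student’s predictions and the ground-truth labels.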

alpha: float = 0.9#
attention_layers: tuple[int, ...] | None = None#
attention_loss_weight: float = 0.0#
attention_normalize: bool = False#
classmethod from_dict(data: dict[str, Any]) → T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T#

Deserializes a JSON string into a PyTree object.

hidden_state_layers: tuple[int, ...] | None = None#
hidden_state_loss: tp.Literal['mse'] = 'mse'#
hidden_state_loss_weight: float = 0.0#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

temperature: float = 2.0#
to_dict() → dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str#

Serializes the PyTree object to a JSON string.

trainer_prefix: str | None = 'distillationtrainer'#