integrations.kd.trainer

Knowledge Distillation (KD) trainer.

Classes

Name Description
AxolotlKDTrainer Custom trainer subclass for Knowledge Distillation (KD)

AxolotlKDTrainer

integrations.kd.trainer.AxolotlKDTrainer(
    self,
    *_args,
    bench_data_collator=None,
    eval_data_collator=None,
    dataset_tags=None,
    **kwargs,
)

Custom trainer subclass for Knowledge Distillation (KD)

Methods

Name Description
compute_loss Defines how the trainer computes the loss. By default, all models return the loss as the first element of their output.

compute_loss

integrations.kd.trainer.AxolotlKDTrainer.compute_loss(
    model,
    inputs,
    return_outputs=False,
    num_items_in_batch=None,
)

Defines how the trainer computes the loss. By default, all models return the loss as the first element of their output.

Subclass the trainer and override this method for custom behavior.
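
The override pattern can be illustrated with a minimal, dependency-free sketch. It is not the implementation used by AxolotlKDTrainer: `ToyTrainer`, `ToyKDTrainer`, `toy_model`, `kd_kl_loss`, and the `student_logits` / `teacher_logits` batch keys are all hypothetical names made up for this example, and the forward KL divergence shown here is only one common choice of distillation loss. The sketch only mirrors the `compute_loss` signature documented above.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of floats."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_kl_loss(student_logits, teacher_logits, temperature=1.0):
    """Forward KL(teacher || student) for a single token position."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)


class ToyTrainer:
    """Hypothetical stand-in for a base trainer (not a real library class)."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        outputs = model(inputs)
        loss = outputs[0]  # default behavior: the loss is the first element
        return (loss, outputs) if return_outputs else loss


class ToyKDTrainer(ToyTrainer):
    """Hypothetical subclass that replaces the default loss with a KD loss."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        outputs = model(inputs)
        student_logits = outputs[1]
        # One KL term per token position, compared against the teacher's logits.
        per_token = [
            kd_kl_loss(s, t)
            for s, t in zip(student_logits, inputs["teacher_logits"])
        ]
        # Normalize by the item count if one is supplied, else by the token count.
        denom = num_items_in_batch if num_items_in_batch is not None else len(per_token)
        loss = sum(per_token) / denom
        return (loss, outputs) if return_outputs else loss


def toy_model(inputs):
    """Hypothetical model: returns (loss, logits), with the loss first."""
    return (0.5, inputs["student_logits"])


batch = {
    "student_logits": [[2.0, 0.0], [0.0, 1.0]],
    "teacher_logits": [[2.0, 0.0], [1.0, 0.0]],
}
base_loss = ToyTrainer().compute_loss(toy_model, batch)
kd_loss = ToyKDTrainer().compute_loss(toy_model, batch)
kd_loss_with_outputs = ToyKDTrainer().compute_loss(toy_model, batch, return_outputs=True)
```

In this sketch the base class returns the model's own loss unchanged, while the subclass ignores it and derives a loss from the student and teacher logits. With `return_outputs=True`, both return a `(loss, outputs)` tuple, matching the convention implied by the signature above.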