core.trainers.base

Module for customized trainers

Classes

Name            Description
AxolotlTrainer  Extend the base Trainer for axolotl helpers

AxolotlTrainer

core.trainers.base.AxolotlTrainer(
    self,
    *_args,
    bench_data_collator=None,
    eval_data_collator=None,
    dataset_tags=None,
    **kwargs,
)

Extend the base Trainer for axolotl helpers

Methods

Name                  Description
get_eval_dataloader   Get dataloader for evaluation
get_train_dataloader  Get dataloader for training
log                   Log the values in logs on the various objects watching training, including stored metrics.
push_to_hub           Override push_to_hub to force-add the tags when pushing the model to the Hub.
training_step         Perform a training step on a batch of inputs.

get_eval_dataloader
core.trainers.base.AxolotlTrainer.get_eval_dataloader(eval_dataset=None)

Get dataloader for evaluation

get_train_dataloader
core.trainers.base.AxolotlTrainer.get_train_dataloader()

Get dataloader for training

log
core.trainers.base.AxolotlTrainer.log(logs, start_time=None)

Log the values in logs on the various objects watching training, including stored metrics.

Parameters
Name        Type              Description             Default
logs        dict[str, float]  The values to log.      required
start_time  float | None      The start of training.  None
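The "stored metrics" behavior can be pictured with a small, dependency-free sketch. Everything here is an assumption made for illustration: BaseTrainer is a hypothetical stand-in for the parent trainer, and the _stored_metrics buffer, the store_metric helper, and the averaging rule are not confirmed by this reference.

```python
from collections import defaultdict
from statistics import fmean


class BaseTrainer:
    """Hypothetical stand-in for the parent trainer; it only records log calls."""

    def __init__(self):
        self.logged = []

    def log(self, logs, start_time=None):
        self.logged.append(dict(logs))


class MetricsTrainer(BaseTrainer):
    """Illustrative subclass that folds buffered metrics into logs before delegating."""

    def __init__(self):
        super().__init__()
        # Assumed buffer layout: metric name -> values collected since the last log call.
        self._stored_metrics = defaultdict(list)

    def store_metric(self, name, value):
        # Hypothetical helper used during training to buffer a metric value.
        self._stored_metrics[name].append(float(value))

    def log(self, logs, start_time=None):
        # Copy the caller's dict, add the mean of each buffered metric, then reset the buffer.
        merged = dict(logs)
        for name, values in self._stored_metrics.items():
            merged[name] = fmean(values)
        self._stored_metrics.clear()
        return super().log(merged, start_time)


trainer = MetricsTrainer()
trainer.store_metric("rewards/accuracy", 1.0)
trainer.store_metric("rewards/accuracy", 0.0)
trainer.log({"loss": 0.25})
print(trainer.logged[-1])
```

Clearing the buffer after each call keeps every logged record limited to the values gathered since the previous one.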

push_to_hub
core.trainers.base.AxolotlTrainer.push_to_hub(*args, **kwargs)

Override the push_to_hub method to force-add the tags when pushing the model to the Hub. See transformers.Trainer.push_to_hub for more details.
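A minimal sketch of the force-add pattern, assuming the tags travel in a tags keyword argument. BaseTrainer is a hypothetical stand-in for the parent class, and forced_tags and its value are illustrative names, not taken from this reference.

```python
class BaseTrainer:
    """Hypothetical stand-in for the parent trainer; returns the kwargs it receives."""

    def push_to_hub(self, *args, **kwargs):
        return kwargs


class TaggedTrainer(BaseTrainer):
    # Illustrative tag list; the real tag values are not stated in this reference.
    forced_tags = ("axolotl",)

    def push_to_hub(self, *args, **kwargs):
        # Normalize the caller's tags: None -> [], a single string -> [string].
        tags = kwargs.get("tags") or []
        if isinstance(tags, str):
            tags = [tags]
        merged = list(tags)
        # Append each forced tag unless the caller already supplied it.
        for tag in self.forced_tags:
            if tag not in merged:
                merged.append(tag)
        kwargs["tags"] = merged
        return super().push_to_hub(*args, **kwargs)


print(TaggedTrainer().push_to_hub(tags="demo"))
```

Merging rather than replacing keeps any tags the caller passed while guaranteeing the forced ones are present.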

training_step
core.trainers.base.AxolotlTrainer.training_step(
    model,
    inputs,
    num_items_in_batch=None,
)

Perform a training step on a batch of inputs. Overrides the transformers.trainer.Trainer method to handle sequence parallelism if enabled.

Parameters
Name    Type                           Description                                Default
model   nn.Module                      Model to perform the training step for.    required
inputs  dict[str, torch.Tensor | Any]  Dictionary mapping input names to values.  required
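The override pattern described above (do extra work only when sequence parallelism is enabled, then delegate to the parent) might look like the following sketch. BaseTrainer, the sequence_parallel_enabled flag, and the _prepare_for_sequence_parallel hook are all hypothetical; this reference does not say what the real preparation step does, so the placeholder only counts calls.

```python
class BaseTrainer:
    """Hypothetical stand-in for the parent trainer."""

    def training_step(self, model, inputs, num_items_in_batch=None):
        # Stand-in for the real step: the "loss" is whatever the model returns.
        return model(inputs)


class SequenceParallelTrainer(BaseTrainer):
    def __init__(self, sequence_parallel_enabled=False):
        # Assumed configuration flag, named here for illustration only.
        self.sequence_parallel_enabled = sequence_parallel_enabled
        self.prepared_batches = 0

    def _prepare_for_sequence_parallel(self, inputs):
        # Placeholder hook: the actual per-batch preparation is not documented here.
        self.prepared_batches += 1

    def training_step(self, model, inputs, num_items_in_batch=None):
        # Run the extra preparation only when the feature is on, then delegate.
        if self.sequence_parallel_enabled:
            self._prepare_for_sequence_parallel(inputs)
        return super().training_step(model, inputs, num_items_in_batch)


def toy_model(batch):
    # Toy "model" that returns the batch length as a fake loss.
    return float(len(batch["input_ids"]))


trainer = SequenceParallelTrainer(sequence_parallel_enabled=True)
loss = trainer.training_step(toy_model, {"input_ids": [1, 2, 3]})
print(loss, trainer.prepared_batches)
```

With the flag off, the subclass behaves exactly like the parent, which keeps the default path unchanged.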