train

Prepare and train a model on a dataset. Can also run inference with a model or merge LoRA adapters.

Functions

Name Description
create_model_card Create a model card for the trained model if needed.
determine_resume_checkpoint Determine the checkpoint to resume from based on configuration.
execute_training Execute the training process with appropriate SDP kernel configurations.
handle_untrained_tokens_fix Apply fixes for untrained tokens if configured.
save_initial_configs Save initial configurations before training.
save_trained_model Save the trained model according to configuration and training setup.
setup_model_and_tokenizer Load the tokenizer, processor (for multimodal models), and model based on configuration.
setup_model_and_trainer Load model, tokenizer, trainer, etc. Helper function to encapsulate the full trainer setup.
setup_model_card Set up the Axolotl badge and add the Axolotl config to the model card if available.
setup_reference_model Set up the reference model for RL training if needed.
setup_signal_handler Set up signal handler for graceful termination.
train Train a model on the given dataset.
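
Taken together, these helpers compose the training flow roughly as sketched below. This is an illustrative ordering only, not the body of train() itself: the real implementation covers additional cases (RL trainer builders, plugins, multimodal processors), and attribute and flag names called out in the comments are assumptions.

```python
from axolotl import train

# Illustrative composition of the helpers documented on this page; not the actual train() source.
def sketch_train(cfg, dataset_meta):
    # Build trainer, model, and tokenizer (plus PEFT config / processor when applicable).
    trainer, model, tokenizer, peft_config, processor = train.setup_model_and_trainer(cfg, dataset_meta)

    # Persist configs, install a graceful-termination handler, and resolve the resume checkpoint.
    train.save_initial_configs(cfg, tokenizer, model, peft_config, processor)
    train.setup_signal_handler(cfg, model, safe_serialization=True)  # flag value is a placeholder
    resume_from_checkpoint = train.determine_resume_checkpoint(cfg)

    # Optional fix for untrained tokens before training starts.
    # `train_dataset` as an attribute of dataset_meta is an assumption.
    train.handle_untrained_tokens_fix(cfg, model, tokenizer, dataset_meta.train_dataset, True)

    # Run training, then save the model and write its model card.
    train.execute_training(cfg, trainer, resume_from_checkpoint)
    train.save_trained_model(cfg, trainer, model, safe_serialization=True)
    train.create_model_card(cfg, trainer)

    return model, tokenizer, trainer
```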

create_model_card

train.create_model_card(cfg, trainer)

Create a model card for the trained model if needed.

Parameters

Name Type Description Default
cfg DictDefault Dictionary mapping axolotl config keys to values. required
trainer Trainer The trainer object with model card creation capabilities. required

determine_resume_checkpoint

train.determine_resume_checkpoint(cfg)

Determine the checkpoint to resume from based on configuration.

Parameters

Name Type Description Default
cfg DictDefault Dictionary mapping axolotl config keys to values. required

Returns

Name Type Description
str | None Path to the checkpoint to resume from, or None if not resuming.
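
A minimal illustration of the config keys this typically reads; resume_from_checkpoint and auto_resume_from_checkpoints are the usual Axolotl options, but treat the exact resolution behavior as version-dependent.

```python
from axolotl import train
from axolotl.utils.dict import DictDefault  # assumed import path

cfg = DictDefault({
    "output_dir": "./outputs/my-run",
    "resume_from_checkpoint": "./outputs/my-run/checkpoint-500",  # explicit checkpoint path
    # "auto_resume_from_checkpoints": True,  # or resume from the most recent checkpoint
})

resume_path = train.determine_resume_checkpoint(cfg)  # str path, or None when not resuming
```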

execute_training

train.execute_training(cfg, trainer, resume_from_checkpoint)

Execute the training process with appropriate SDP kernel configurations.

Parameters

Name Type Description Default
cfg DictDefault Dictionary mapping axolotl config keys to values. required
trainer Any The configured trainer object. required
resume_from_checkpoint str | None Path to checkpoint to resume from, if applicable. required
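
For context, "SDP kernel configurations" refers to selecting scaled-dot-product attention backends around the underlying trainer.train() call. A simplified conceptual sketch, not the library's implementation; which backends Axolotl enables depends on cfg:

```python
import torch

# Conceptual only: wrap the training loop in an SDP kernel context so attention
# uses the desired backends (flash / memory-efficient / math).
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_mem_efficient=True, enable_math=True
):
    trainer.train(resume_from_checkpoint=None)  # `trainer` comes from setup_model_and_trainer
```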

handle_untrained_tokens_fix

train.handle_untrained_tokens_fix(
    cfg,
    model,
    tokenizer,
    train_dataset,
    safe_serialization,
)

Apply fixes for untrained tokens if configured.

Parameters

Name Type Description Default
cfg DictDefault Dictionary mapping axolotl config keys to values. required
model PreTrainedModel The model to apply fixes to. required
tokenizer PreTrainedTokenizer The tokenizer for token identification. required
train_dataset Dataset The training dataset to use. required
safe_serialization bool Whether to use safe serialization when saving. required
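
This step is typically gated by a config flag; fix_untrained_tokens is the usual Axolotl key, but treat the name and accepted values as an assumption for your version.

```python
from axolotl import train
from axolotl.utils.dict import DictDefault  # assumed import path

cfg = DictDefault({
    "fix_untrained_tokens": True,  # assumed config key gating the fix
    # ... remaining training config ...
})

# No-op unless the gating flag is set; model/tokenizer/train_dataset come from earlier setup steps.
train.handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset, safe_serialization=True)
```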

save_initial_configs

train.save_initial_configs(cfg, tokenizer, model, peft_config, processor)

Save initial configurations before training.

Parameters

Name Type Description Default
cfg DictDefault Dictionary mapping axolotl config keys to values. required
tokenizer PreTrainedTokenizer The tokenizer to save. required
model PreTrainedModel The model to save configuration for. required
peft_config PeftConfig | None The PEFT configuration to save if applicable. required
processor ProcessorMixin | None The processor to save for multimodal models, if applicable. required

save_trained_model

train.save_trained_model(cfg, trainer, model, safe_serialization)

Save the trained model according to configuration and training setup.

Parameters

Name Type Description Default
cfg DictDefault Dictionary mapping axolotl config keys to values. required
trainer Any The trainer object. required
model PreTrainedModel The trained model to save. required
safe_serialization bool Whether to use safe serialization. required
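
In practice safe_serialization usually mirrors the save_safetensors config option (treated as an assumption here):

```python
from axolotl import train

# save_safetensors is the usual Axolotl key controlling safetensors output (assumed).
safe = bool(cfg.get("save_safetensors", True))
train.save_trained_model(cfg, trainer, model, safe_serialization=safe)
```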

setup_model_and_tokenizer

train.setup_model_and_tokenizer(cfg)

Load the tokenizer, processor (for multimodal models), and model based on configuration.

Parameters

Name Type Description Default
cfg DictDefault Dictionary mapping axolotl config keys to values. required

Returns

Name Type Description
tuple[PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None] Tuple containing model, tokenizer, peft_config (if LoRA / QLoRA, else None), and processor (if multimodal, else None).
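
A small sketch of consuming the returned tuple; the last two elements are None for non-LoRA and non-multimodal runs respectively:

```python
from axolotl import train

model, tokenizer, peft_config, processor = train.setup_model_and_tokenizer(cfg)

if peft_config is not None:
    print("LoRA/QLoRA run:", peft_config.peft_type)
if processor is not None:
    print("Multimodal run: processor loaded alongside the tokenizer.")
```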

setup_model_and_trainer

train.setup_model_and_trainer(cfg, dataset_meta)

Load model, tokenizer, trainer, etc. Helper function to encapsulate the full trainer setup.

Parameters

Name Type Description Default
cfg DictDefault The configuration dictionary with training parameters. required
dataset_meta TrainDatasetMeta Object with training, validation datasets and metadata. required

Returns

Name Type Description
tuple[HFRLTrainerBuilder | HFCausalTrainerBuilder, PeftModel | PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None] Tuple of trainer (causal or RLHF), model, tokenizer, PEFT config, and processor.

setup_model_card

train.setup_model_card(cfg)

Set up the Axolotl badge and add the Axolotl config to the model card if available.

Parameters

Name Type Description Default
cfg DictDefault Dictionary mapping axolotl config keys to values. required

setup_reference_model

train.setup_reference_model(cfg, tokenizer)

Set up the reference model for RL training if needed.

Parameters

Name Type Description Default
cfg DictDefault Dictionary mapping axolotl config keys to values. required
tokenizer PreTrainedTokenizer The tokenizer to use for the reference model. required

Returns

Name Type Description
PreTrainedModel | None Reference model if needed for RL training, None otherwise.
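
A hedged sketch for RL-style runs; the rl config key with a value such as dpo is the common way these modes are selected, but verify against your version:

```python
from axolotl import train
from axolotl.utils.dict import DictDefault  # assumed import path

cfg = DictDefault({
    "rl": "dpo",  # assumed key/value selecting a preference-tuning (RL) run
    # ... remaining config ...
})

ref_model = train.setup_reference_model(cfg, tokenizer)
if ref_model is None:
    print("No separate reference model required for this configuration.")
```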

setup_signal_handler

train.setup_signal_handler(cfg, model, safe_serialization)

Set up signal handler for graceful termination.

Parameters

Name Type Description Default
cfg DictDefault Dictionary mapping axolotl config keys to values. required
model PreTrainedModel The model to save on termination. required
safe_serialization bool Whether to use safe serialization when saving. required
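
Conceptually, the registered handler saves the model when the process receives a termination signal. A simplified, self-contained illustration of the idea (not the library's implementation; the real helper wires paths and flags from cfg):

```python
import signal

def make_handler(model, output_dir, safe_serialization):
    def _handler(signum, frame):
        # Persist the (possibly partially trained) model before exiting.
        model.save_pretrained(output_dir, safe_serialization=safe_serialization)
        raise SystemExit(1)
    return _handler

# Register for Ctrl-C style interrupts.
signal.signal(signal.SIGINT, make_handler(model, "./outputs/my-run", True))
```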

train

train.train(cfg, dataset_meta)

Train a model on the given dataset.

Parameters

Name Type Description Default
cfg DictDefault The configuration dictionary with training parameters. required
dataset_meta TrainDatasetMeta Object with training, validation datasets and metadata. required

Returns

Name Type Description
tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer] Tuple of (model, tokenizer, trainer) after training.
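
An end-to-end sketch of calling train; the import paths for the config and dataset helpers are assumptions based on a typical Axolotl layout and differ between releases:

```python
from axolotl.train import train
# Assumed helper locations; adjust to your Axolotl version.
from axolotl.cli import load_cfg
from axolotl.common.datasets import load_datasets

cfg = load_cfg("examples/llama-3/lora.yml")  # DictDefault built from a YAML config
dataset_meta = load_datasets(cfg=cfg)        # TrainDatasetMeta with train/eval datasets

model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
```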