train
Prepare and train a model on a dataset. Can also run inference from a model or merge LoRA adapters.
Functions
create_model_card
train.create_model_card(cfg, trainer)
Create a model card for the trained model if needed.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |
| trainer | Trainer | The trainer object with model card creation capabilities. | required |
determine_resume_checkpoint
train.determine_resume_checkpoint(cfg)
Determine the checkpoint to resume from based on configuration.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |

Returns

| Type | Description |
|---|---|
| str \| None | Path to the checkpoint to resume from, or None if not resuming. |
execute_training
train.execute_training(cfg, trainer, resume_from_checkpoint)
Execute the training process with appropriate SDP kernel configurations.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |
| trainer | Any | The configured trainer object. | required |
| resume_from_checkpoint | str \| None | Path to checkpoint to resume from, if applicable. | required |
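A minimal sketch of how the resume step feeds into training execution, assuming `cfg` is a validated DictDefault, `trainer` was built earlier (e.g. by setup_model_and_trainer), and that the functions are importable from `axolotl.train` (the import path is an assumption, not part of this page):

```python
from axolotl.train import determine_resume_checkpoint, execute_training  # import path assumed


def run(cfg, trainer):
    # Resolve the checkpoint to resume from (str or None), then run training.
    resume_from_checkpoint = determine_resume_checkpoint(cfg)
    execute_training(cfg, trainer, resume_from_checkpoint)
```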
handle_untrained_tokens_fix
train.handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset, safe_serialization)
Apply fixes for untrained tokens if configured.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |
| model | PreTrainedModel | The model to apply fixes to. | required |
| tokenizer | PreTrainedTokenizer | The tokenizer for token identification. | required |
| train_dataset | Dataset | The training dataset to use. | required |
| safe_serialization | bool | Whether to use safe serialization when saving. | required |
save_initial_configs
train.save_initial_configs(cfg, tokenizer, model, peft_config, processor)
Save initial configurations before training.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |
| tokenizer | PreTrainedTokenizer | The tokenizer to save. | required |
| model | PreTrainedModel | The model to save configuration for. | required |
| peft_config | PeftConfig \| None | The PEFT configuration to save if applicable. | required |
| processor | ProcessorMixin \| None | The processor to save if applicable (multimodal models). | required |
save_trained_model
train.save_trained_model(cfg, trainer, model, safe_serialization)
Save the trained model according to configuration and training setup.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |
| trainer | Any | The trainer object. | required |
| model | PreTrainedModel | The trained model to save. | required |
| safe_serialization | bool | Whether to use safe serialization. | required |
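A sketch of the typical post-training sequence, saving the model before writing its model card. The `axolotl.train` import path and the keyword call style are assumptions for illustration:

```python
from axolotl.train import create_model_card, save_trained_model  # import path assumed


def finalize(cfg, trainer, model):
    # Persist the trained weights, then write the model card.
    save_trained_model(cfg, trainer, model, safe_serialization=True)
    create_model_card(cfg, trainer)
```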
setup_model_and_tokenizer
train.setup_model_and_tokenizer(cfg)
Load the tokenizer, processor (for multimodal models), and model based on configuration.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |

Returns

| Type | Description |
|---|---|
| tuple[PreTrainedModel, PreTrainedTokenizer, PeftConfig \| None, ProcessorMixin \| None] | Tuple containing model, tokenizer, peft_config (if LoRA / QLoRA, else None), and processor (if multimodal, else None). |
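Because the return value is a 4-tuple, callers typically unpack it in one step. A hedged sketch (import path assumed):

```python
from axolotl.train import setup_model_and_tokenizer  # import path assumed


def load_components(cfg):
    # peft_config is None unless LoRA/QLoRA is configured;
    # processor is None unless the model is multimodal.
    model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
    return model, tokenizer, peft_config, processor
```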
setup_model_and_trainer
train.setup_model_and_trainer(cfg, dataset_meta)
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
trainer setup.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | DictDefault | The configuration dictionary with training parameters. | required |
| dataset_meta | TrainDatasetMeta | Object with training, validation datasets and metadata. | required |

Returns

| Type | Description |
|---|---|
| tuple[HFRLTrainerBuilder \| HFCausalTrainerBuilder, PeftModel \| PreTrainedModel, PreTrainedTokenizer, PeftConfig \| None, ProcessorMixin \| None] | Tuple of: trainer (causal or RLHF), model, tokenizer, PEFT config, and processor. |
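A sketch of unpacking the 5-tuple; `dataset_meta` is assumed to be a TrainDatasetMeta produced by axolotl's dataset-loading step, and the import path is an assumption:

```python
from axolotl.train import setup_model_and_trainer  # import path assumed


def build(cfg, dataset_meta):
    # Single call that returns the trainer plus everything it was built from.
    trainer, model, tokenizer, peft_config, processor = setup_model_and_trainer(
        cfg, dataset_meta
    )
    return trainer, model, tokenizer, peft_config, processor
```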
setup_model_card
train.setup_model_card(cfg)
Set up the Axolotl badge and add the Axolotl config to the model card if available.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |
setup_reference_model
train.setup_reference_model(cfg, tokenizer)
Set up the reference model for RL training if needed.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |
| tokenizer | PreTrainedTokenizer | The tokenizer to use for the reference model. | required |

Returns

| Type | Description |
|---|---|
| PreTrainedModel \| None | Reference model if needed for RL training, None otherwise. |
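Since the return value may be None, callers should branch on it. A minimal sketch (import path assumed):

```python
from axolotl.train import setup_reference_model  # import path assumed


def maybe_load_reference(cfg, tokenizer):
    # Returns a reference model only when the RL setup requires one
    # (e.g. DPO-style training); otherwise None.
    return setup_reference_model(cfg, tokenizer)
```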
setup_signal_handler
train.setup_signal_handler(cfg, model, safe_serialization)
Set up signal handler for graceful termination.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |
| model | PreTrainedModel | The model to save on termination. | required |
| safe_serialization | bool | Whether to use safe serialization when saving. | required |
train
train.train(cfg, dataset_meta)
Train a model on the given dataset.
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | DictDefault | The configuration dictionary with training parameters. | required |
| dataset_meta | TrainDatasetMeta | Object with training, validation datasets and metadata. | required |

Returns

| Type | Description |
|---|---|
| tuple[PeftModel \| PreTrainedModel, PreTrainedTokenizer, Trainer] | Tuple of (model, tokenizer, trainer) after training. |
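Putting it together, a minimal end-to-end sketch. Both arguments are assumed to be prepared by axolotl's config validation and dataset-loading steps; only the train() call itself is documented on this page, and the import path is an assumption:

```python
from axolotl.train import train  # import path assumed


def run_training(cfg, dataset_meta):
    """Train on the prepared datasets and return (model, tokenizer, trainer)."""
    model, tokenizer, trainer = train(cfg, dataset_meta)
    return model, tokenizer, trainer
```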