utils.trainer

Module containing the Trainer class and related functions

Functions

| Name | Description |
| --- | --- |
| add_pose_position_ids | Use the PoSE technique to extend the context length by randomly skipping positions. |
| add_position_ids | Handle both single-example and batched data. |
| drop_long_seq | Drop samples whose sequence length is either too long (> sequence_len) or too short (< min_sequence_len). |
| setup_trainer | Helper method for instantiating and building a (causal or RLHF) trainer. |

add_pose_position_ids

utils.trainer.add_pose_position_ids(
    sample,
    max_context_len=32768,
    split_on_token_ids=None,
    chunks=2,
)

Use the PoSE technique to extend the context length by randomly skipping positions in the context. We only want to skip right before tokens in the split_on_token_ids list. We should attempt to randomly distribute the skips, but we don't need the final position_ids to span the full max_context_len. There may be multiple turns in the context, so we want to make sure we take into account the maximum possible number of skips remaining in each sample.
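The skipping logic described above can be sketched as follows. This is a simplified illustration, not the actual implementation: it ignores the chunks parameter, takes a bare token list rather than a dataset sample, and spreads the skip budget evenly over the remaining skip points so positions never exceed max_context_len - 1.

```python
import random


def pose_position_ids_sketch(input_ids, max_context_len=32768,
                             split_on_token_ids=None, seed=None):
    """Simplified sketch of PoSE-style position-id skipping.

    Inserts a random jump in the position ids right before each token
    found in split_on_token_ids, dividing the remaining skip budget by
    the number of skip points still ahead.
    """
    rng = random.Random(seed)
    split_ids = set(split_on_token_ids or [])
    budget = max_context_len - len(input_ids)  # total positions we may skip
    remaining = sum(tok in split_ids for tok in input_ids) or 1
    position_ids, pos = [], 0
    for i, tok in enumerate(input_ids):
        # jump right before a split token (but never before the first token)
        if i > 0 and tok in split_ids and budget > 0:
            jump = rng.randint(0, budget // remaining)
            pos += jump
            budget -= jump
        if tok in split_ids:
            remaining = max(remaining - 1, 1)
        position_ids.append(pos)
        pos += 1
    return position_ids
```

The result is strictly increasing and bounded by max_context_len, so a short sample "covers" positions from the longer target context.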

add_position_ids

utils.trainer.add_position_ids(sample)

Handle both single-example and batched data.

- single example: sample['input_ids'] is a list[int]
- batched data: sample['input_ids'] is a list[list[int]]
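A minimal sketch of this dispatch, assuming position ids are simply sequential per sequence (the real function may do more):

```python
def add_position_ids(sample):
    """Attach sequential position ids, handling single or batched samples."""
    input_ids = sample["input_ids"]
    if input_ids and isinstance(input_ids[0], list):
        # batched data: one range per inner sequence
        sample["position_ids"] = [list(range(len(ids))) for ids in input_ids]
    else:
        # single example: one flat range
        sample["position_ids"] = list(range(len(input_ids)))
    return sample
```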

drop_long_seq

utils.trainer.drop_long_seq(sample, sequence_len=2048, min_sequence_len=2)

Drop samples whose sequence length is either too long (> sequence_len) or too short (< min_sequence_len).

Works for both single-example (list[int]) or batched (list[list[int]]).
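A sketch of the filter predicate, assuming it returns a bool per example (a list of bools in the batched case) as a datasets-style filter function would:

```python
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
    """Keep samples with min_sequence_len <= length <= sequence_len.

    Returns a bool for a single example, or a list of bools for a batch.
    """
    input_ids = sample["input_ids"]
    if input_ids and isinstance(input_ids[0], list):  # batched
        return [min_sequence_len <= len(ids) <= sequence_len
                for ids in input_ids]
    return min_sequence_len <= len(input_ids) <= sequence_len
```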

setup_trainer

utils.trainer.setup_trainer(
    cfg,
    train_dataset,
    eval_dataset,
    model,
    tokenizer,
    processor,
    total_num_steps,
    model_ref=None,
    peft_config=None,
)

Helper method for instantiating and building a (causal or RLHF) trainer.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | | Axolotl config object containing training parameters. | required |
| train_dataset | | Dataset to use for training. | required |
| eval_dataset | | Dataset to use for evaluation. | required |
| model | | The model to train. | required |
| tokenizer | | Tokenizer for processing text input. | required |
| processor | | Processor for data preparation. | required |
| total_num_steps | | The total number of training steps. | required |
| model_ref | | Optional reference model for RLHF training. Default is None. | None |
| peft_config | | Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None. | None |

Returns

| Name | Type | Description |
| --- | --- | --- |
| | | A trainer instance (either HFRLTrainer or HFCausalTrainer) configured based on the provided parameters. |