utils.trainer
Module containing the Trainer class and related functions.
Functions
| Name | Description |
|---|---|
| add_pose_position_ids | Use the PoSE technique to extend the context length by randomly skipping positions in the context. |
| add_position_ids | Handle both single-example and batched data. |
| drop_long_seq | Drop samples whose sequence length is either too long (> sequence_len) or too short (< min_sequence_len). |
| setup_trainer | Helper method for instantiating and building a (causal or RLHF) trainer. |
add_pose_position_ids
utils.trainer.add_pose_position_ids(
    sample,
    max_context_len=32768,
    split_on_token_ids=None,
    chunks=2,
)
Use the PoSE technique to extend the context length by randomly skipping positions in the context. We only want to skip right before tokens in the split_on_token_ids list. The skips should be distributed randomly, but the final position_ids do not need to reach the full context_len. There may be multiple turns in the context, so we take into account the maximum possible number of skips remaining in each sample.
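The idea can be illustrated with a self-contained sketch. This is not the library's implementation; the helper name `pose_position_ids_sketch` and the exact way the skip budget is distributed are assumptions for illustration only.

```python
import random


def pose_position_ids_sketch(input_ids, max_context_len=32768,
                             split_on_token_ids=None, chunks=2):
    """Illustrative sketch: assign monotonically increasing position ids,
    inserting a random jump right before tokens in `split_on_token_ids`,
    so the short sample 'covers' a longer context window."""
    split_tokens = set(split_on_token_ids or [])
    seq_len = len(input_ids)
    # Total number of positions we are allowed to skip overall.
    skip_budget = max(max_context_len - seq_len, 0)
    # Candidate indices where a skip may be inserted (right before a split token).
    candidates = [i for i, tok in enumerate(input_ids)
                  if tok in split_tokens and i > 0]
    # Spread the budget over up to `chunks` of the candidate points.
    skip_points = set(random.sample(candidates, k=min(chunks, len(candidates))))
    position_ids, pos = [], 0
    for i in range(seq_len):
        if i in skip_points and skip_budget > 0:
            jump = random.randint(0, skip_budget)
            pos += jump
            skip_budget -= jump
        position_ids.append(pos)
        pos += 1
    return position_ids


# Example: two "turns" separated by an EOS-like token id 2.
sample_ids = [1, 5, 6, 7, 2, 8, 9, 2]
print(pose_position_ids_sketch(sample_ids, max_context_len=64, split_on_token_ids=[2]))
```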
add_position_ids
utils.trainer.add_position_ids(sample)
Handle both single-example and batched data:

- single example: sample['input_ids'] is a list[int]
- batched data: sample['input_ids'] is a list[list[int]]
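A hedged sketch of how a helper might branch on those two shapes; the name `add_position_ids_sketch` and the exact fields it writes are assumptions, not the library's code.

```python
def add_position_ids_sketch(sample):
    """Illustrative sketch: attach 0..n-1 position ids,
    handling both a single example and a batch."""
    input_ids = sample["input_ids"]
    if input_ids and isinstance(input_ids[0], list):
        # Batched: sample["input_ids"] is a list[list[int]].
        sample["position_ids"] = [list(range(len(ids))) for ids in input_ids]
    else:
        # Single example: sample["input_ids"] is a list[int].
        sample["position_ids"] = list(range(len(input_ids)))
    return sample


# Single-example vs. batched call.
print(add_position_ids_sketch({"input_ids": [1, 2, 3]})["position_ids"])
print(add_position_ids_sketch({"input_ids": [[1, 2, 3], [4, 5]]})["position_ids"])
```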
drop_long_seq
utils.trainer.drop_long_seq(sample, sequence_len=2048, min_sequence_len=2)
Drop samples whose sequence length is either too long (> sequence_len) or too short (< min_sequence_len).
Works for both single-example (list[int]) and batched (list[list[int]]) inputs.
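A hedged sketch of the filtering predicate, with a typical way it could be wired into a Hugging Face `datasets` filter call; the sketch name and the filter wiring are illustrative assumptions, not the library's own code.

```python
def drop_long_seq_sketch(sample, sequence_len=2048, min_sequence_len=2):
    """Illustrative sketch: keep samples whose length falls within
    [min_sequence_len, sequence_len]; handles single and batched inputs."""
    input_ids = sample["input_ids"]
    if input_ids and isinstance(input_ids[0], list):
        # Batched: return one bool per example.
        return [min_sequence_len <= len(ids) <= sequence_len for ids in input_ids]
    return min_sequence_len <= len(input_ids) <= sequence_len


# Typical (assumed) use with a Hugging Face datasets Dataset:
# dataset = dataset.filter(drop_long_seq_sketch, fn_kwargs={"sequence_len": 4096})
```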
setup_trainer
utils.trainer.setup_trainer(
    cfg,
    train_dataset,
    eval_dataset,
    model,
    tokenizer,
    processor,
    total_num_steps,
    model_ref=None,
    peft_config=None,
)
Helper method for instantiating and building a (causal or RLHF) trainer.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| cfg |  | Axolotl config object containing training parameters. | required |
| train_dataset |  | Dataset to use for training. | required |
| eval_dataset |  | Dataset to use for evaluation. | required |
| model |  | The model to train. | required |
| tokenizer |  | Tokenizer for processing text input. | required |
| processor |  | Processor for data preparation. | required |
| total_num_steps |  | The total number of training steps. | required |
| model_ref |  | Optional reference model for RLHF training. | None |
| peft_config |  | Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. | None |
Returns
| Name | Type | Description |
|---|---|---|
|  |  | A trainer instance (either HFRLTrainer or HFCausalTrainer) configured based on the provided parameters. |
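A hypothetical usage sketch follows. The import path axolotl.utils.trainer is assumed, and cfg, train_dataset, eval_dataset, model, tokenizer, and processor are assumed to have been prepared earlier by the usual config-loading and tokenization steps; none of this is prescribed by the reference above.

```python
# Hypothetical usage sketch: the objects below are assumed to exist already.
from axolotl.utils.trainer import setup_trainer  # assumed import path

trainer = setup_trainer(
    cfg,                 # Axolotl config object
    train_dataset,       # tokenized training split
    eval_dataset,        # tokenized evaluation split
    model,
    tokenizer,
    processor,           # may be None for text-only models (assumption)
    total_num_steps=1000,
    model_ref=None,      # only needed for RLHF-style training
    peft_config=None,    # pass a PEFT config when fine-tuning with LoRA/QLoRA
)
trainer.train()          # the returned trainer exposes the usual HF-style train()
```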