core.trainers.trl
Module for TRL-based trainers (CPO, KTO, ORPO, PRM, reward, and PPO)
Classes
Name | Description |
---|---|
AxolotlCPOTrainer | Extend the base CPOTrainer for axolotl helpers |
AxolotlKTOTrainer | Extend the base KTOTrainer for axolotl helpers |
AxolotlORPOTrainer | Extend the base ORPOTrainer for axolotl helpers |
AxolotlPRMTrainer | Extend the base trl.PRMTrainer for axolotl helpers |
AxolotlRewardTrainer | Extend the base RewardTrainer for axolotl helpers |
TRLPPOTrainer | Wrapper for TRL PPO trainer to handle customizations |
AxolotlCPOTrainer
core.trainers.trl.AxolotlCPOTrainer()
Extend the base CPOTrainer for axolotl helpers
Methods
Name | Description |
---|---|
get_batch_loss_metrics | Compute the CPO loss and other metrics for the given batch of inputs for train or test. |
get_batch_loss_metrics
core.trainers.trl.AxolotlCPOTrainer.get_batch_loss_metrics(
    model,
    batch,
    train_eval='train',
)
Compute the CPO loss and other metrics for the given batch of inputs for train or test.
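The calling pattern can be illustrated with a minimal sketch. It assumes the method follows the upstream TRL convention of returning a loss together with a metrics dictionary whose keys gain an `eval_` prefix when `train_eval='eval'`. `SketchTrainer`, `prefix_metrics`, and the `"losses"` batch key are hypothetical stand-ins for illustration, not part of axolotl or TRL, and the placeholder loss is not the CPO loss.

```python
from typing import Dict, Literal, Tuple


def prefix_metrics(
    metrics: Dict[str, float],
    train_eval: Literal["train", "eval"] = "train",
) -> Dict[str, float]:
    """Hypothetical helper: prefix metric names with 'eval_' during evaluation."""
    prefix = "eval_" if train_eval == "eval" else ""
    return {f"{prefix}{name}": value for name, value in metrics.items()}


class SketchTrainer:
    """Stand-in trainer for illustration; not the real AxolotlCPOTrainer."""

    def get_batch_loss_metrics(
        self,
        model,
        batch,
        train_eval: Literal["train", "eval"] = "train",
    ) -> Tuple[float, Dict[str, float]]:
        # Placeholder loss: mean of per-example losses supplied in the batch.
        # A real trainer would run the model and compute the CPO objective here.
        losses = batch["losses"]
        loss = sum(losses) / len(losses)
        return loss, prefix_metrics({"loss_mean": loss}, train_eval)


trainer = SketchTrainer()
batch = {"losses": [0.5, 1.5]}

# Default mode is 'train': metric names are left unprefixed.
train_loss, train_metrics = trainer.get_batch_loss_metrics(None, batch)

# In 'eval' mode the same metrics are reported under 'eval_'-prefixed names.
eval_loss, eval_metrics = trainer.get_batch_loss_metrics(None, batch, train_eval="eval")
```

Keeping the mode as an explicit argument lets the same method serve both the training step and the evaluation loop while keeping the logged metric names distinct.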
AxolotlKTOTrainer
core.trainers.trl.AxolotlKTOTrainer()
Extend the base KTOTrainer for axolotl helpers
AxolotlORPOTrainer
core.trainers.trl.AxolotlORPOTrainer()
Extend the base ORPOTrainer for axolotl helpers
Methods
Name | Description |
---|---|
get_batch_loss_metrics | Compute the ORPO loss and other metrics for the given batch of inputs for train or test. |
get_batch_loss_metrics
core.trainers.trl.AxolotlORPOTrainer.get_batch_loss_metrics(
    model,
    batch,
    train_eval='train',
)
Compute the ORPO loss and other metrics for the given batch of inputs for train or test.
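For orientation, the ORPO objective combines a negative log-likelihood term on the chosen response with an odds-ratio penalty. The sketch below works on scalar log-probabilities and is only an illustration of that formula, not axolotl's or TRL's implementation. The function names, the `beta` weight, and its default value are assumptions, and both log-probabilities must be strictly negative.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def orpo_loss(
    chosen_logp: float,
    rejected_logp: float,
    nll_loss: float,
    beta: float = 0.1,  # assumed weight of the odds-ratio term
) -> float:
    """Illustrative ORPO loss for a single (chosen, rejected) pair."""
    # log odds(y) = log p - log(1 - p), where p = exp(logp) and logp < 0.
    log_odds = (chosen_logp - rejected_logp) - (
        math.log1p(-math.exp(chosen_logp)) - math.log1p(-math.exp(rejected_logp))
    )
    # The penalty is small when the chosen response has much higher odds.
    ratio_term = log_sigmoid(log_odds)
    return nll_loss - beta * ratio_term


# A pair where the chosen response is more likely than the rejected one
preferred = orpo_loss(chosen_logp=-0.5, rejected_logp=-2.0, nll_loss=2.0)
# The same pair with the roles swapped
swapped = orpo_loss(chosen_logp=-2.0, rejected_logp=-0.5, nll_loss=2.0)
```

With equal log-probabilities the log odds ratio is zero, so the penalty reduces to `beta * log(2)` on top of the NLL term.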
AxolotlPRMTrainer
core.trainers.trl.AxolotlPRMTrainer()
Extend the base trl.PRMTrainer for axolotl helpers
AxolotlRewardTrainer
core.trainers.trl.AxolotlRewardTrainer()
Extend the base RewardTrainer for axolotl helpers
TRLPPOTrainer
core.trainers.trl.TRLPPOTrainer()
Wrapper for TRL PPO trainer to handle customizations