core.trainers.trl

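Module for TRL-based trainers: axolotl subclasses of the TRL CPO, KTO, ORPO, PRM, and reward trainers, plus a wrapper around the TRL PPO trainer.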

Classes

| Name | Description |
|------|-------------|
| AxolotlCPOTrainer | Extend the base CPOTrainer with axolotl helpers |
| AxolotlKTOTrainer | Extend the base KTOTrainer with axolotl helpers |
| AxolotlORPOTrainer | Extend the base ORPOTrainer with axolotl helpers |
| AxolotlPRMTrainer | Extend the base trl.PRMTrainer with axolotl helpers |
| AxolotlRewardTrainer | Extend the base RewardTrainer with axolotl helpers |
| TRLPPOTrainer | Wrapper around the TRL PPO trainer to handle customizations |

AxolotlCPOTrainer

```python
core.trainers.trl.AxolotlCPOTrainer()
```

Extend the base CPOTrainer with axolotl helpers.

Methods

| Name | Description |
|------|-------------|
| get_batch_loss_metrics | Compute the CPO loss and other metrics for the given batch of inputs, for training or evaluation. |

get_batch_loss_metrics

```python
core.trainers.trl.AxolotlCPOTrainer.get_batch_loss_metrics(
    model,
    batch,
    train_eval='train',
)
```

Compute the CPO loss and other metrics for the given batch of inputs, for either training or evaluation as selected by train_eval.
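To make "the CPO loss and other metrics" concrete, here is a minimal, self-contained sketch of the sigmoid CPO objective as described in the CPO literature: a reference-free preference term, -log sigmoid(beta * (log p(chosen) - log p(rejected))), plus a negative log-likelihood term on the chosen response. It is not axolotl's or TRL's implementation. The function name cpo_batch_loss_metrics, its arguments, the defaults for beta and alpha, and the metric keys (loosely modeled on TRL-style "rewards/..." logging with an "eval_" prefix) are all assumptions for illustration. It works on plain per-example log-probabilities rather than a model and batch.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def cpo_batch_loss_metrics(chosen_logps, rejected_logps, chosen_nll,
                           beta=0.1, alpha=1.0, train_eval="train"):
    """Hypothetical sketch of a sigmoid CPO loss over a batch.

    chosen_logps / rejected_logps: per-example sequence log-probs under the policy.
    chosen_nll: per-example negative log-likelihood of the chosen response.
    beta, alpha, and the metric names are assumptions for this sketch.
    """
    n = len(chosen_logps)
    # Preference term: no reference model, just the policy's log-prob gap.
    pref = [-log_sigmoid(beta * (c - r))
            for c, r in zip(chosen_logps, rejected_logps)]
    # Behavior-cloning (NLL) term on the chosen responses, weighted by alpha.
    loss = sum(pref) / n + alpha * sum(chosen_nll) / n
    prefix = "eval_" if train_eval == "eval" else ""
    metrics = {
        f"{prefix}rewards/chosen": sum(beta * c for c in chosen_logps) / n,
        f"{prefix}rewards/rejected": sum(beta * r for r in rejected_logps) / n,
        # Fraction of pairs where the chosen response is more likely.
        f"{prefix}rewards/accuracies": sum(
            c > r for c, r in zip(chosen_logps, rejected_logps)) / n,
    }
    return loss, metrics
```

For example, cpo_batch_loss_metrics([-10.0, -12.0], [-14.0, -11.0], [2.0, 2.5]) reports a "rewards/accuracies" of 0.5, because only the first pair ranks the chosen response higher.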

AxolotlKTOTrainer

```python
core.trainers.trl.AxolotlKTOTrainer()
```

Extend the base KTOTrainer with axolotl helpers.
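For orientation, the objective optimized by a KTO trainer can be sketched as follows, assuming the published KTO formulation: each unpaired completion has a log-ratio r = log pi_theta(y|x) - log pi_ref(y|x), a batch-level KL estimate z0 (clamped at zero) acts as a reference point, and desirable and undesirable examples are pushed to opposite sides of it. This is an illustrative sketch, not the code inherited from TRL's KTOTrainer; the name kto_loss, its arguments, and the default weights are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def kto_loss(logratios, desirable, kl_estimate, beta=0.1,
             desirable_weight=1.0, undesirable_weight=1.0):
    """Hypothetical sketch of a KTO loss over a batch of unpaired examples.

    logratios: log pi_theta(y|x) - log pi_ref(y|x) for each completion.
    desirable: True if the completion is labelled desirable, else False.
    kl_estimate: batch-level estimate of KL(pi_theta || pi_ref).
    """
    # The KL reference point is clamped so it never goes negative.
    z0 = max(kl_estimate, 0.0)
    losses = []
    for r, good in zip(logratios, desirable):
        if good:
            # Desirable completions are rewarded for rising above z0.
            losses.append(desirable_weight * (1.0 - sigmoid(beta * (r - z0))))
        else:
            # Undesirable completions are rewarded for falling below z0.
            losses.append(undesirable_weight * (1.0 - sigmoid(beta * (z0 - r))))
    return sum(losses) / len(losses)
```

Because desirable and undesirable examples are scored independently, the data need not come in chosen/rejected pairs; the two weights let a caller rebalance imbalanced label counts.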

AxolotlORPOTrainer

```python
core.trainers.trl.AxolotlORPOTrainer()
```

Extend the base ORPOTrainer with axolotl helpers.

Methods

| Name | Description |
|------|-------------|
| get_batch_loss_metrics | Compute the ORPO loss and other metrics for the given batch of inputs, for training or evaluation. |

get_batch_loss_metrics

```python
core.trainers.trl.AxolotlORPOTrainer.get_batch_loss_metrics(
    model,
    batch,
    train_eval='train',
)
```

Compute the ORPO loss and other metrics for the given batch of inputs, for either training or evaluation as selected by train_eval.
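The ORPO loss combines a supervised NLL term on the chosen response with an odds-ratio term, -log sigmoid(log odds(chosen) - log odds(rejected)), where odds(y) = p / (1 - p) and p is computed from the average per-token log-probability. The sketch below illustrates that formula on plain numbers; it is not axolotl's or TRL's code. The name orpo_batch_loss, its arguments, and the default weight lam are assumptions for illustration.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def log_odds(avg_logp: float) -> float:
    """log(p / (1 - p)) given log p; requires avg_logp < 0."""
    return avg_logp - math.log1p(-math.exp(avg_logp))


def orpo_batch_loss(chosen_avg_logps, rejected_avg_logps, chosen_nll, lam=0.1):
    """Hypothetical sketch of the ORPO loss over a batch.

    chosen_avg_logps / rejected_avg_logps: average per-token log-probs (< 0).
    chosen_nll: per-example NLL of the chosen response (the SFT term).
    lam: weight of the odds-ratio term (an assumed default).
    """
    n = len(chosen_avg_logps)
    # log sigmoid of the log odds ratio; closer to 0 when chosen is favored.
    or_terms = [
        log_sigmoid(log_odds(c) - log_odds(r))
        for c, r in zip(chosen_avg_logps, rejected_avg_logps)
    ]
    # SFT term plus the negated (weighted) odds-ratio term.
    return sum(chosen_nll) / n - lam * sum(or_terms) / n
```

Like CPO, this needs no reference model: the preference signal comes entirely from the policy's own odds for the chosen and rejected responses.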

AxolotlPRMTrainer

```python
core.trainers.trl.AxolotlPRMTrainer()
```

Extend the base trl.PRMTrainer with axolotl helpers.
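A process reward model (PRM) scores each intermediate reasoning step rather than only the final answer. A common way to prepare such data, and roughly the scheme TRL's PRMTrainer documentation describes, is to append a separator after each step and place the step's label only on the step's last token, masking every other position with -100. The sketch below shows that labelling on pre-tokenized IDs; build_prm_example and its arguments are hypothetical, and the exact tokenization used by TRL or axolotl may differ.

```python
def build_prm_example(prompt_ids, step_ids, step_labels, separator_ids,
                      ignore_index=-100):
    """Hypothetical sketch: flatten steps into token IDs with step-level labels.

    prompt_ids: token IDs of the prompt (never labelled).
    step_ids: list of token-ID lists, one per reasoning step.
    step_labels: one bool per step (True = correct step).
    separator_ids: token IDs appended after every step (must be non-empty).
    """
    if not separator_ids:
        raise ValueError("separator_ids must be non-empty")
    input_ids = list(prompt_ids)
    labels = [ignore_index] * len(prompt_ids)
    for ids, label in zip(step_ids, step_labels):
        step = list(ids) + list(separator_ids)
        input_ids.extend(step)
        # Only the final token of each step (the separator) carries the label.
        labels.extend([ignore_index] * (len(step) - 1) + [int(label)])
    return input_ids, labels
```

For example, build_prm_example([1, 2], [[5, 6], [7]], [True, False], [9]) returns ([1, 2, 5, 6, 9, 7, 9], [-100, -100, -100, -100, 1, -100, 0]), so a token-classification loss sees exactly one label per step.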

AxolotlRewardTrainer

```python
core.trainers.trl.AxolotlRewardTrainer()
```

Extend the base RewardTrainer with axolotl helpers.
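Reward-model training on preference pairs typically uses the Bradley-Terry pairwise loss, -log sigmoid(r_chosen - r_rejected), which pushes the scalar reward of the chosen response above that of the rejected one. The sketch below computes it on plain reward values, with an optional per-pair margin similar to the margin support TRL's RewardTrainer offers (treat that correspondence as an assumption). The name pairwise_reward_loss and its arguments are hypothetical; this is not the inherited implementation.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def pairwise_reward_loss(chosen_rewards, rejected_rewards, margins=None):
    """Hypothetical sketch of a Bradley-Terry pairwise reward loss.

    chosen_rewards / rejected_rewards: scalar rewards for each pair.
    margins: optional per-pair margin the chosen reward must exceed.
    """
    n = len(chosen_rewards)
    if margins is None:
        margins = [0.0] * n
    losses = [
        -log_sigmoid(c - r - m)
        for c, r, m in zip(chosen_rewards, rejected_rewards, margins)
    ]
    return sum(losses) / n
```

When the two rewards are equal the loss is log 2; it approaches zero as the chosen reward pulls ahead of the rejected reward plus its margin.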

TRLPPOTrainer

```python
core.trainers.trl.TRLPPOTrainer()
```

Wrapper around the TRL PPO trainer to handle customizations.
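At the core of PPO is the clipped surrogate policy loss: with probability ratio rho = exp(log pi_new - log pi_old) and advantage A, each sample contributes -min(rho * A, clip(rho, 1 - eps, 1 + eps) * A). The sketch below computes that term on plain numbers to show what the wrapped trainer ultimately optimizes. A full RLHF PPO trainer also handles a value-function loss, a KL penalty against a reference model, and per-token masking, all omitted here. The name ppo_clipped_loss, its arguments, and the default clip_eps are assumptions for illustration.

```python
import math


def ppo_clipped_loss(logp_new, logp_old, advantages, clip_eps=0.2):
    """Hypothetical sketch of PPO's clipped surrogate policy loss.

    logp_new: log-probs of the taken actions under the current policy.
    logp_old: log-probs of the same actions under the rollout policy.
    advantages: advantage estimate for each action.
    """
    total = 0.0
    for ln, lo, a in zip(logp_new, logp_old, advantages):
        ratio = math.exp(ln - lo)
        unclipped = ratio * a
        # Clipping removes the incentive to move the ratio outside [1-eps, 1+eps].
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps) * a
        # Take the pessimistic (smaller) objective, negated to form a loss.
        total += -min(unclipped, clipped)
    return total / len(advantages)
```

With a ratio of 2 and a positive advantage of 1, the clipped term caps the objective at 1.2; with a negative advantage of -1, the unclipped term (-2) is kept, so large harmful updates are still penalized in full.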