RLHF (Beta)
Overview
Reinforcement Learning from Human Feedback (RLHF) is a method by which a language model is optimized using human feedback data. Related methods include, but are not limited to:
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)
- Kahneman-Tversky Optimization (KTO)
- Odds Ratio Preference Optimization (ORPO)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
RLHF using Axolotl
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
We rely on the TRL library for implementations of the various RL training methods, which we wrap to expose in axolotl. Each method has its own supported dataset formats and prompt strategies.
You can find what each method supports by looking in src/axolotl/prompt_strategies/{method}, where {method} is one of the supported methods. The type: for a dataset is given as {file_name}.{function_name}, where {file_name} is a strategy file in that directory and {function_name} is a loader function defined in it.
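For example, the layout looks roughly like this (illustrative; exact file and function names may differ between versions):

src/axolotl/prompt_strategies/
  dpo/
    chatml.py    # intel(), argilla(), ultra(), ...  ->  type: chatml.intel, chatml.argilla, chatml.ultra
    llama3.py    # intel(), ultra(), ...             ->  type: llama3.intel, llama3.ultra
  kto/
  orpo/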
DPO
Example config:
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml
DPO supports the following types, with their expected dataset formats:
chatml.argilla
{
"system": "...", // optional
"instruction": "...",
"chosen_response": "...",
"rejected_response": "..."
}
chatml.argilla_chat
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
chatml.icr
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
chatml.intel
{
"system": "...", // optional
"question": "...",
"chosen": "...",
"rejected": "..."
}
chatml.prompt_pairs
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
chatml.ultra
{
"system": "...", // optional
"prompt": "...",
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
llama3.argilla
{
"system": "...", // optional
"instruction": "...",
"chosen_response": "...",
"rejected_response": "..."
}
llama3.argilla_chat
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
llama3.icr
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
llama3.intel
{
"system": "...", // optional
"question": "...",
"chosen": "...",
"rejected": "..."
}
llama3.prompt_pairs
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
llama3.ultra
{
"system": "...", // optional
"prompt": "...",
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
zephyr.nectar
{
"prompt": "...",
"answers": [
{
"answer": "...",
"rank": 1
},
{
"answer": "...",
"rank": 2
}
// ... more answers with ranks
]
}
chat_template.default
rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]
Sample input format:
{
"messages": [
{
"role": "system",
"content": "..."
},
{
"role": "user",
"content": "..."
},
// ... more messages
],
"chosen": {
"role": "assistant",
"content": "..."
},
"rejected": {
"role": "assistant",
"content": "..."
}
}
user_defined.default
For custom behaviors, use the following config:
rl: dpo
datasets:
  - path: ...
    split: train
    type: user_defined.default
    field_prompt: "prompt"
    field_system: "system"
    field_chosen: "chosen"
    field_rejected: "rejected"
    prompt_format: "{prompt}"
    chosen_format: "{chosen}"
    rejected_format: "{rejected}"
The input format is simple JSON, with field names customizable via the config above.
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
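The *_format values are ordinary Python format strings. As a minimal sketch (illustrative only, not axolotl internals), a record is expanded roughly like this:

# Minimal sketch of how the user_defined.default templates above map one
# record to the prompt/chosen/rejected strings used for DPO.
# Values are hypothetical; axolotl does this internally.
record = {
    "system": "You are a helpful assistant.",
    "prompt": "What is 2 + 2?",
    "chosen": "4",
    "rejected": "5",
}

prompt_format = "{prompt}"
chosen_format = "{chosen}"
rejected_format = "{rejected}"

sample = {
    "prompt": prompt_format.format(prompt=record["prompt"]),
    "chosen": chosen_format.format(chosen=record["chosen"]),
    "rejected": rejected_format.format(rejected=record["rejected"]),
}
print(sample)  # {'prompt': 'What is 2 + 2?', 'chosen': '4', 'rejected': '5'}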
IPO
As IPO is just DPO with a different loss function, all supported dataset formats for DPO are also supported for IPO.
rl: ipo
ORPO
Paper: https://arxiv.org/abs/2403.07691
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false
chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla
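For reference (a sketch from the ORPO paper; orpo_alpha corresponds to the weighting λ, which TRL's ORPOConfig exposes as beta), the loss adds an odds-ratio term to the standard SFT loss:

$$\mathcal{L}_{\text{ORPO}} = \mathbb{E}\big[\mathcal{L}_{\text{SFT}} + \lambda \cdot \mathcal{L}_{\text{OR}}\big], \qquad \mathcal{L}_{\text{OR}} = -\log \sigma\!\left(\log \frac{\text{odds}_\theta(y_w \mid x)}{\text{odds}_\theta(y_l \mid x)}\right), \qquad \text{odds}_\theta(y \mid x) = \frac{P_\theta(y \mid x)}{1 - P_\theta(y \mid x)}$$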
ORPO supports the following types, with their expected dataset formats:
chat_template.argilla
{
"system": "...", // optional
"prompt": "...", // if available, will be taken as user message for single-turn instead of from list below
// chosen/rejected should be same till last content and only even-number of alternating user/assistant turns
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
KTO
rl: kto
rl_beta: 0.1 # default
kto_desirable_weight: 1.0 # default
kto_undesirable_weight: 1.0 # default
remove_unused_columns: false
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
KTO supports the following types, with their expected dataset formats:
chatml.argilla
{
"system": "...", // optional
"instruction": "...",
"completion": "..."
}
chatml.argilla_chat
{
"chosen": [
{"role": "user", "content": "..."}
],
"completion": [
{"role": "assistant", "content": "..."}
]
}
chatml.intel
{
"system": "...", // optional
"question": "...",
"completion": "..."
}
chatml.prompt_pairs
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
chatml.ultra
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
llama3.argilla
{
"system": "...", // optional
"instruction": "...",
"completion": "..."
}
llama3.argilla_chat
{
"completion": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
llama3.intel
{
"system": "...", // optional
"question": "...",
"completion": "..."
}
llama3.prompt_pairs
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
llama3.ultra
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
user_defined.default
For custom behaviors, use the following config:
rl: kto
datasets:
  - path: ...
    split: train
    type: user_defined.default
    field_prompt: "prompt"
    field_system: "system"
    field_completion: "completion"
    field_label: "label"
    prompt_format: "{prompt}"
    completion_format: "{completion}"
The input format is simple JSON, with field names customizable via the config above.
{
"system": "...", // optional
"prompt": "...",
"completion": "...",
"label": "..."
}
GRPO
Check out our GRPO cookbook.
GRPO uses custom reward functions and transformations. Please have them ready in a local Python file. For example, to load OpenAI’s GSM8K and use a random reward for completions:
# rewards.py
import random


def rand_reward_func(completions, **kwargs) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]


def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "answer": label,
        }

    return transform_fn, {"remove_columns": ["question"]}
rl: grpo
trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: True
  vllm_device: auto
  vllm_gpu_memory_utilization: 0.15
  num_generations: 4
  reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
  reward_weights: [1.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform # format: '{file_name}.{fn_name}'
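Beyond the random reward above, TRL's GRPO trainer also passes the remaining dataset columns (here, the answer field produced by the transform) to reward functions as keyword arguments, so rewards can be computed against ground truth. A minimal sketch (illustrative only; answer_match_reward_func is not part of axolotl):

# rewards.py (continued) - illustrative example
def answer_match_reward_func(completions, answer, **kwargs) -> list[float]:
    # Reward 1.0 when the gold answer appears in the generated text, else 0.0.
    # Assumes conversational completions, i.e. each completion is a list of
    # {"role": ..., "content": ...} messages.
    rewards = []
    for completion, gold in zip(completions, answer):
        text = completion[0]["content"] if isinstance(completion, list) else str(completion)
        rewards.append(1.0 if str(gold) in text else 0.0)
    return rewards

This could then be listed as rewards.answer_match_reward_func in reward_funcs above.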
To see other examples of custom reward functions, please see the TRL GRPO docs.
For a description of the config options, please see TRLConfig.
SimPO
SimPO uses the CPOTrainer, but with an alternative loss function.
rl: simpo
rl_beta: 0.1 # default in CPOTrainer
cpo_alpha: 1.0 # default in CPOTrainer
simpo_gamma: 0.5 # default in CPOTrainer
This method uses the same dataset format as DPO.
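For reference, the SimPO objective from the paper (arXiv:2405.14734); here β corresponds to rl_beta and γ to simpo_gamma:

$$\mathcal{L}_{\text{SimPO}} = -\mathbb{E}_{(x, y_w, y_l)}\left[\log \sigma\!\left(\frac{\beta}{|y_w|} \log \pi_\theta(y_w \mid x) - \frac{\beta}{|y_l|} \log \pi_\theta(y_l \mid x) - \gamma\right)\right]$$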
Using local dataset files
datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
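A line in orca_rlhf.jsonl then follows the chatml.intel format shown above, for example (illustrative values):

{"system": "You are a helpful assistant.", "question": "What is 2 + 2?", "chosen": "4", "rejected": "5"}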
TRL auto-unwrapping for PEFT
TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, as an additional reference model does not need to be loaded, and the reference model's log probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:
# load the ref model when training with an adapter
rl_adapter_ref_model: true