RLHF (Beta)

Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback.

Overview

Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback. Various methods include, but are not limited to:

- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)
- Odds Ratio Preference Optimization (ORPO)
- Kahneman-Tversky Optimization (KTO)

RLHF using Axolotl

Important

This is a BETA feature and not all functionality is fully implemented. You are encouraged to open new PRs to improve the integration and functionality.

We rely on the TRL library for implementations of the various RL training methods, which we wrap to expose in axolotl. Each method has its own supported dataset-loading options and prompt formats.

Tip

You can find what each method supports by looking in src/axolotl/prompt_strategies/{method}, where {method} is one of the supported methods. The type: value corresponds to {file_name}.{function_name} within that directory; for example, with rl: dpo, type: chatml.intel resolves to the intel function in chatml.py.

DPO

Example config:

rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml

DPO supports the following types, each with the dataset format shown below:

chatml.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "chosen_response": "...",
    "rejected_response": "..."
}

chatml.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

chatml.icr

{
    "system": "...", // optional
    "input": "...",
    "chosen": "...",
    "rejected": "..."
}

chatml.intel

{
    "system": "...", // optional
    "question": "...",
    "chosen": "...",
    "rejected": "..."
}

chatml.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}

chatml.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

llama3.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "chosen_response": "...",
    "rejected_response": "..."
}

llama3.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

llama3.icr

{
    "system": "...", // optional
    "input": "...",
    "chosen": "...",
    "rejected": "..."
}

llama3.intel

{
    "system": "...", // optional
    "question": "...",
    "chosen": "...",
    "rejected": "..."
}

llama3.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}

llama3.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

zephyr.nectar

{
    "prompt": "...",
    "answers": [
        {
            "answer": "...",
            "rank": 1
        },
        {
            "answer": "...",
            "rank": 2
        }
        // ... more answers with ranks
    ]
}

chat_template.default

rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]

Sample input format:

{
    "messages": [
        {
            "role": "system",
            "content": "..."
        },
        {
            "role": "user",
            "content": "..."
        },
        // ... more messages
    ],
    "chosen": {
        "role": "assistant",
        "content": "..."
    },
    "rejected": {
        "role": "assistant",
        "content": "..."
    }
}

user_defined.default

For custom behaviors, use the user_defined type:

rl: dpo
datasets:
  - path: ...
    split: train
    type: user_defined.default

    field_prompt: "prompt"
    field_system: "system"
    field_chosen: "chosen"
    field_rejected: "rejected"
    prompt_format: "{prompt}"
    chosen_format: "{chosen}"
    rejected_format: "{rejected}"

The input format is simple JSON, with field names customizable via the config above.

{
    "system": "...",  // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}
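
The format strings can also wrap the raw fields in a template. The sketch below is hypothetical: the column names (question, good_answer, bad_answer) and the [INST] wrapper are illustrative assumptions, not values required by axolotl.

rl: dpo
datasets:
  - path: ...
    split: train
    type: user_defined.default

    # hypothetical column names from the dataset
    field_prompt: "question"
    field_system: "system"
    field_chosen: "good_answer"
    field_rejected: "bad_answer"
    # wrap the prompt in an instruction template; responses pass through unchanged
    prompt_format: "[INST] {prompt} [/INST]"
    chosen_format: "{chosen}"
    rejected_format: "{rejected}"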

IPO

As IPO is simply DPO with a different loss function, all options supported for DPO also work here.

rl: ipo
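
Beyond switching the rl value, an IPO run can reuse any of the DPO dataset types unchanged. A minimal sketch, reusing the dataset from the DPO example above:

rl: ipo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel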

ORPO

Paper: https://arxiv.org/abs/2403.07691

rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla

ORPO supports the following types, each with the dataset format shown below:

chat_template.argilla

{
    "system": "...",  // optional
    "prompt": "...",  // if available, will be taken as user message for single-turn instead of from list below

    // chosen and rejected should be identical up to the final assistant message, with an even number of alternating user/assistant turns (see the example after this block)
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}
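
As a concrete illustration of the comment above, a valid record might look like the following (the content strings are made up); chosen and rejected share every turn except the final assistant message.

{
    "system": "You are a helpful assistant.",
    "chosen": [
        {"role": "user", "content": "Summarize the article in one sentence."},
        {"role": "assistant", "content": "A faithful one-sentence summary."}
    ],
    "rejected": [
        {"role": "user", "content": "Summarize the article in one sentence."},
        {"role": "assistant", "content": "An off-topic or rambling reply."}
    ]
}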

KTO

rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2

remove_unused_columns: false

datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true

KTO supports the following types, each with the dataset format shown below:

chatml.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "completion": "..."
}

chatml.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."}
    ],
    "completion": [
        {"role": "assistant", "content": "..."}
    ]
}

chatml.intel

{
    "system": "...", // optional
    "question": "...",
    "completion": "..."
}

chatml.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

chatml.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

llama3.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "completion": "..."
}

llama3.argilla_chat

{
    "completion": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

llama3.intel

{
    "system": "...", // optional
    "question": "...",
    "completion": "..."
}

llama3.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

llama3.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

user_defined.default

For custom behaviors, use the user_defined type:

rl: kto
datasets:
  - path: ...
    split: train
    type: user_defined.default

    field_prompt: "prompt"
    field_system: "system"
    field_completion: "completion"
    field_label: "label"
    prompt_format: "{prompt}"
    completion_format: "{completion}"

The input format is simple JSON, with field names customizable via the config above.

{
    "system": "...",  // optional
    "prompt": "...",
    "completion": "...",
    "label": "..."
}
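
A concrete record might look like the following. Note that TRL's KTO trainer treats the label as a boolean marking the completion as desirable or undesirable; the values below are illustrative.

{
    "system": "You are a helpful assistant.",
    "prompt": "What is the capital of France?",
    "completion": "The capital of France is Paris.",
    "label": true
}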

Using local dataset files

datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
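
With this config, each line of orca_rlhf.jsonl is a JSON object matching the chatml.intel format shown earlier, for example (values are illustrative):

{"system": "...", "question": "...", "chosen": "...", "rejected": "..."}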

TRL auto-unwrapping for PEFT

TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, since an additional reference model does not need to be loaded; reference-model log probabilities are obtained by temporarily disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:

# load a separate reference model during adapter training (disables auto-unwrapping)
rl_adapter_ref_model: true
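
For reference, auto-unwrapping only comes into play when training with an adapter. A minimal sketch of such a setup, assuming standard axolotl LoRA options and leaving the default (auto-unwrapped) behavior in place:

# adapter-based DPO; with the default settings TRL disables the LoRA adapters
# to compute reference log probabilities, so no separate reference model is loaded
base_model: ...
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel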