LoRA Optimizations

Custom autograd functions and Triton kernels in Axolotl for optimized LoRA fine-tuning

Inspired by Unsloth, we've implemented two optimizations for LoRA and QLoRA fine-tuning, supporting both single-GPU and multi-GPU training (in the DDP and DeepSpeed settings): (1) SwiGLU and GeGLU activation function Triton kernels, and (2) custom autograd functions for the LoRA MLP and attention paths. Our goal is to leverage operator fusion and tensor re-use to improve speed and reduce memory usage during the forward and backward passes of these computations.

We currently support several common model architectures.

The set of models we support is currently limited by our attention patching strategy, which assumes (and replaces) specific code blocks for query / key / value and output projections:

ORIGINAL_QKV_CODE = """
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
    "\n"
)

ORIGINAL_O_CODE = """
    attn_output = self.o_proj(attn_output)
""".lstrip(
    "\n"
)

These blocks are replaced with:

PATCHED_QKV_CODE = """
    query_states, key_states, value_states = self.apply_qkv(hidden_states)
    query_states = query_states.view(hidden_shape).transpose(1, 2)
    key_states = key_states.view(hidden_shape).transpose(1, 2)
    value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
    "\n"
)

PATCHED_O_CODE = """
    attn_output = self.apply_o(attn_output)
""".lstrip(
    "\n"
)

Here, apply_qkv and apply_o are defined in the axolotl.kernels.lora module.
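To make the shape of this change concrete, the sketch below writes out what apply_qkv computes as three separate LoRA-adapted projections. The helper name and the per-projection weight tuples are illustrative only; the actual apply_qkv implementation fuses these matrix multiplications into a single autograd function and shares intermediate tensors between the three projections.

# Illustrative reference only, not the axolotl.kernels.lora implementation.
# Each LoRA-adapted projection is the base output plus a scaled low-rank
# update (no dropout, no bias).
def lora_projection(x, base_weight, lora_A, lora_B, scaling):
    return x @ base_weight.t() + scaling * ((x @ lora_A.t()) @ lora_B.t())

def apply_qkv_reference(self, hidden_states):
    # self.q_lora etc. are hypothetical (base_weight, lora_A, lora_B, scaling) tuples
    query_states = lora_projection(hidden_states, *self.q_lora)
    key_states = lora_projection(hidden_states, *self.k_lora)
    value_states = lora_projection(hidden_states, *self.v_lora)
    return query_states, key_states, value_states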

We welcome testing on other model architectures, as well as PRs that expand our patching logic to cover more of them.

Usage

These optimizations can be enabled in your Axolotl config YAML file. The lora_mlp_kernel option enables the optimized MLP path, while lora_qkv_kernel and lora_o_kernel enable the fused query-key-value projection and optimized output projection, respectively.

lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

Requirements

  • One or more NVIDIA or AMD GPUs (required for the Triton kernels)
  • Targeted LoRA adapters cannot use dropout
    • Disabling dropout may increase the risk of overfitting
  • Targeted LoRA adapters cannot have bias terms
    • Omitting bias terms may slightly limit model expressivity

Pre-existing LoRA adapters that were trained with dropout or bias terms may need to be fine-tuned again without these features before they can take advantage of these optimizations.

Implementation details

Custom autograd functions

The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the LoRA and base weight computations together and provides a single, efficient backward pass for the entire MLP block.
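As a point of reference, the sketch below spells out the unfused computation that this autograd function performs in a single fused forward / backward pass. The helper names and the (weight, lora_A, lora_B, scaling) tuples are illustrative, not Axolotl's API.

import torch.nn.functional as F

def lora_linear(x, weight, lora_A, lora_B, scaling):
    # base projection plus the scaled low-rank LoRA update (no dropout, no bias)
    return F.linear(x, weight) + scaling * F.linear(F.linear(x, lora_A), lora_B)

def lora_mlp_reference(x, gate, up, down):
    g = lora_linear(x, *gate)        # gate projection
    u = lora_linear(x, *up)          # up projection
    h = F.silu(g) * u                # SwiGLU activation (a Triton kernel in Axolotl)
    return lora_linear(h, *down)     # down projection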

For attention components, similar optimizations are provided through one function that handles the query, key, and value projections and another that handles the output projection. They are designed to work with the existing Hugging Face transformers attention implementation via the monkey-patching logic described above.
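The monkey-patching follows a source-rewriting pattern: fetch the source of the attention class's forward method, swap the original code blocks for their patched counterparts, recompile, and attach the result (apply_qkv and apply_o are bound to the attention modules separately). The sketch below shows the general idea only, with illustrative names; Axolotl's actual patching code handles more details.

import inspect
import sys
import textwrap

def patch_attention_forward(attn_cls):
    # Illustrative sketch of source-level monkey-patching, not Axolotl's exact code.
    source = textwrap.dedent(inspect.getsource(attn_cls.forward))
    if ORIGINAL_QKV_CODE not in source or ORIGINAL_O_CODE not in source:
        raise ValueError("Unsupported attention implementation; cannot patch")
    source = source.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
    source = source.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
    # Recompile in the namespace of the module that defines the attention class,
    # so the patched forward can still resolve the names it references.
    namespace = dict(sys.modules[attn_cls.__module__].__dict__)
    exec(source, namespace)
    attn_cls.forward = namespace["forward"]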

Triton kernels

Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for improved speed and memory performance. These kernels handle both the forward and backward passes.
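For illustration, a forward-pass SwiGLU kernel fuses the SiLU activation of the gate projection with the elementwise product against the up projection in a single pass over memory. This is a minimal sketch in the same spirit as Axolotl's kernels, not the actual implementation:

import torch
import triton
import triton.language as tl

@triton.jit
def swiglu_fwd_kernel(gate_ptr, up_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance processes one contiguous block of elements.
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    gate = tl.load(gate_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    result = gate * tl.sigmoid(gate) * up  # SwiGLU: SiLU(gate) * up
    tl.store(out_ptr + offsets, result.to(out_ptr.dtype.element_ty), mask=mask)

def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # Assumes contiguous, same-shape tensors already on the GPU.
    out = torch.empty_like(up)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    swiglu_fwd_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE=1024)
    return out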

Integration

The custom autograd functions and Triton kernels are designed to work together. The autograd function manages the high-level computation flow and gradient tracking, while calling the Triton kernels for the activation function computation. During the backward pass, the kernel computes both the activation output and the required gradients, which the autograd function then uses to compute the final gradients for the entire computation path.
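As a rough sketch of that flow (reusing the hypothetical swiglu_forward helper from the previous section; Axolotl's real classes also fuse the LoRA matrix multiplications and perform the backward math inside a Triton kernel):

import torch

class SwiGLUFunction(torch.autograd.Function):
    # Illustrative only: how a custom autograd function can wrap the Triton kernel.

    @staticmethod
    def forward(ctx, gate, up):
        ctx.save_for_backward(gate, up)
        return swiglu_forward(gate, up)  # Triton forward kernel (sketched above)

    @staticmethod
    def backward(ctx, grad_output):
        gate, up = ctx.saved_tensors
        # Written in plain PyTorch here for clarity; a fused backward kernel
        # recomputes the activation and these gradients in one pass.
        sig = torch.sigmoid(gate)
        grad_up = grad_output * gate * sig                            # d out / d up
        grad_gate = grad_output * up * sig * (1 + gate * (1 - sig))   # d out / d gate
        return grad_gate, grad_up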

Future Work

  • Support for additional model architectures
  • Support for the FSDP setting
  • Support for dropout and bias
  • Additional operator fusions