LoRA Optimizations
Inspired by Unsloth, we’ve implemented two optimizations for LoRA and QLoRA fine-tuning, supporting both single-GPU and multi-GPU training (in the DDP and DeepSpeed settings). These include (1) SwiGLU and GeGLU activation function Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal is to leverage operator fusion and tensor re-use to improve speed and reduce memory usage during the forward and backward passes of these computations.
We currently support several common model architectures, including (but not limited to):
- llama
- mistral
- qwen2
- gemma
- gemma2
The set of models we support is currently limited by our attention patching strategy, which assumes (and replaces) specific code blocks for query / key / value and output projections:
= """
ORIGINAL_QKV_CODE query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
= """
ORIGINAL_O_CODE attn_output = self.o_proj(attn_output)
""".lstrip(
"\n"
)
These are replaced with:
= """
PATCHED_QKV_CODE query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
= """
PATCHED_O_CODE attn_output = self.apply_o(attn_output)
""".lstrip(
"\n"
)
Here, apply_qkv and apply_o are defined in the axolotl.kernels.lora module.
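For illustration, a rough sketch of what a fused query/key/value projection with LoRA deltas computes is shown below. This is not the actual axolotl.kernels.lora implementation (which also provides a hand-written backward pass and re-uses intermediates across the projections); the function and parameter names here are hypothetical:

```python
import torch

def lora_qkv_forward(
    x: torch.Tensor,
    w_q: torch.Tensor, a_q: torch.Tensor, b_q: torch.Tensor,
    w_k: torch.Tensor, a_k: torch.Tensor, b_k: torch.Tensor,
    w_v: torch.Tensor, a_v: torch.Tensor, b_v: torch.Tensor,
    scaling: float,
):
    """Hypothetical sketch: each projection is the frozen base output plus the
    low-rank LoRA update, scaling * (x @ A^T) @ B^T, computed in one place so
    that intermediates can be shared between the forward and backward passes."""
    q = x @ w_q.T + scaling * (x @ a_q.T) @ b_q.T
    k = x @ w_k.T + scaling * (x @ a_k.T) @ b_k.T
    v = x @ w_v.T + scaling * (x @ a_v.T) @ b_v.T
    return q, k, v
```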
We welcome testing of other model architectures and/or PRs that expand our patching logic to cover more of them.
Usage
These optimizations can be enabled in your Axolotl config YAML file. The lora_mlp_kernel option enables the optimized MLP path, while lora_qkv_kernel and lora_o_kernel enable the fused query-key-value projection and the optimized output projection, respectively.
```yaml
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```
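In context, these flags sit alongside your usual LoRA settings. An abridged, hypothetical example (with dropout set to 0 to satisfy the requirements below):

```yaml
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0   # dropout is not supported by the fused kernels
lora_target_linear: true

lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```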
Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
  - Note: Set TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 to enable memory-efficient attention on AMD GPUs
- Targeted LoRA adapters cannot use dropout
  - This may limit model expressivity or cause overfitting
- Targeted LoRA adapters cannot have bias terms
  - This may limit model expressivity
Models with pre-existing LoRA adapters that use dropout or have bias terms may need to be fine-tuned again without these features before they can benefit from these optimizations.
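If you're unsure whether an existing adapter is affected, you can inspect its saved PEFT configuration. A rough check (the adapter path is a placeholder):

```python
from peft import PeftConfig

cfg = PeftConfig.from_pretrained("path/to/existing/adapter")  # placeholder path

# The fused kernels require no dropout and no bias on the targeted LoRA layers.
if getattr(cfg, "lora_dropout", 0.0) != 0.0 or getattr(cfg, "bias", "none") != "none":
    print("Adapter uses dropout and/or bias; fine-tune again without them to use these kernels.")
```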
Implementation details
Custom autograd functions
The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the LoRA and base weight computations together and provides a single, efficient backward pass for the entire MLP block.
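To make the pattern concrete, here is a minimal, hypothetical autograd function for a single LoRA linear projection. The actual MLP function fuses the full gate/up/activation/down path, but the structure is the same: one forward that combines base and LoRA weights, and one hand-written backward that produces gradients only for the LoRA matrices:

```python
import torch

class LoRALinearFunction(torch.autograd.Function):
    """Illustrative sketch: y = x @ W.T + s * (x @ A.T) @ B.T with W frozen."""

    @staticmethod
    def forward(ctx, x, weight, lora_A, lora_B, scaling):
        xa = x @ lora_A.T  # low-rank intermediate, saved and re-used in backward
        out = x @ weight.T + scaling * (xa @ lora_B.T)
        ctx.save_for_backward(x, weight, lora_A, lora_B, xa)
        ctx.scaling = scaling
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, weight, lora_A, lora_B, xa = ctx.saved_tensors
        s = ctx.scaling
        grad_via_b = grad_out @ lora_B                    # shared by grad_x and grad_A
        grad_x = grad_out @ weight + s * (grad_via_b @ lora_A)
        grad_A = s * grad_via_b.T @ x                     # only the LoRA matrices get gradients;
        grad_B = s * grad_out.T @ xa                      # the frozen base weight returns None
        return grad_x, None, grad_A, grad_B, None
```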
For attention components, similar optimizations are provided through a function that handles the query, key, and value projections, and a function that handles the output projection. They are designed to work with the existing transformers attention implementation via some monkey-patching logic.
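The patching itself amounts to rewriting the attention module's forward source, swapping in the code blocks shown above. A self-contained toy demonstration of the general technique (not Axolotl's actual patcher, which also validates the surrounding transformers code) might look like:

```python
import inspect
import textwrap
import types

class TinyAttention:
    def forward(self, hidden_states):
        attn_output = hidden_states * 2          # stand-in for the attention math
        attn_output = self.o_proj(attn_output)   # block to be replaced
        return attn_output

    def o_proj(self, x):
        return x + 1                             # stand-in for the base output projection

    def apply_o(self, x):
        return x + 100                           # stand-in for the fused LoRA output projection

def patch_forward(module, original_code, patched_code):
    """Rewrite a bound method's source, swap one code block for another,
    then re-compile and re-bind the new forward onto the instance."""
    source = textwrap.dedent(inspect.getsource(module.forward))
    assert original_code in source, "unexpected forward implementation"
    namespace = {}
    exec(source.replace(original_code, patched_code), namespace)
    module.forward = types.MethodType(namespace["forward"], module)

attn = TinyAttention()
patch_forward(
    attn,
    "attn_output = self.o_proj(attn_output)",
    "attn_output = self.apply_o(attn_output)",
)
print(attn.forward(3))  # 106: the patched forward now routes through apply_o
```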
Triton kernels
Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for improved speed and memory performance. These kernels handle both the forward and backward passes.
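A simplified, forward-only sketch of a fused SwiGLU kernel is shown below to illustrate the approach; the actual kernels also implement the backward pass and the GeGLU variant:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def swiglu_fwd_kernel(gate_ptr, up_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    gate = tl.load(gate_ptr + offsets, mask=mask).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask).to(tl.float32)
    # SwiGLU: silu(gate) * up, fused into a single elementwise kernel
    out = gate * tl.sigmoid(gate) * up
    tl.store(out_ptr + offsets, out.to(out_ptr.dtype.element_ty), mask=mask)

def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # Assumes contiguous tensors of identical shape on the same GPU.
    out = torch.empty_like(gate)
    n_elements = gate.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    swiglu_fwd_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE=1024)
    return out
```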
Integration
The custom autograd functions and Triton kernels are designed to work together. The autograd function manages the high-level computation flow and gradient tracking, while calling the Triton kernels for the activation function computation. During the backward pass, the kernel computes both the activation output and the required gradients, which the autograd function then uses to compute the final gradients for the entire computation path.
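A simplified illustration of this division of labor is shown below; plain PyTorch functions stand in for the Triton kernels, and the recomputed activation output (which the full MLP backward feeds into the down-projection gradients) is returned but left unused here:

```python
import torch

def swiglu_fwd(gate, up):
    # Stand-in for the fused Triton forward kernel.
    return torch.nn.functional.silu(gate) * up

def swiglu_bwd(grad_out, gate, up):
    # Stand-in for the fused Triton backward kernel: recomputes the activation
    # output and the gradients w.r.t. gate and up in a single pass.
    sig = torch.sigmoid(gate)
    silu = gate * sig
    h = silu * up                                    # recomputed activation output
    grad_gate = grad_out * up * (sig + silu * (1 - sig))
    grad_up = grad_out * silu
    return h, grad_gate, grad_up

class SwiGLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, gate, up):
        ctx.save_for_backward(gate, up)              # high-level flow and gradient tracking
        return swiglu_fwd(gate, up)

    @staticmethod
    def backward(ctx, grad_out):
        gate, up = ctx.saved_tensors
        _h, grad_gate, grad_up = swiglu_bwd(grad_out, gate, up)
        return grad_gate, grad_up
```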
Future Work
- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions