monkeypatch.lora_kernels

Module for patching custom LoRA Triton kernels and torch.autograd functions.

Functions

| Name | Description |
|------|-------------|
| apply_lora_kernel_patches | Applies optimized Triton kernel patches to a PEFT model. |
| get_attention_cls_from_config | Get the appropriate attention class by inspecting the model config. |
| original_apply_o | Original implementation of output projection without optimizations. |
| original_apply_qkv | Original implementation of QKV projection without optimizations. |
| patch_self_attn_lora | Given an axolotl config, patches the inferred attention class forward pass with optimized LoRA implementations. |

apply_lora_kernel_patches

monkeypatch.lora_kernels.apply_lora_kernel_patches(model, cfg)

Applies optimized Triton kernel patches to a PEFT model.

Patches a PEFT model with optimized implementations for MLP and attention computations. The optimizations include custom Triton kernels for activation functions and specialized autograd functions for LoRA computations.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | PeftModelForCausalLM | A PEFT model to be patched with optimized kernels. | required |
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |

Returns

| Type | Description |
|------|-------------|
| PeftModelForCausalLM | The patched model with optimized kernels. |

Raises

| Type | Description |
|------|-------------|
| TypeError | If the provided model is not a PeftModelForCausalLM. |
| NotImplementedError | If the model type is not supported. |
| AssertionError | If multiple adapters are active (currently unsupported). |

Note

The optimizations require LoRA adapters with no dropout and no bias terms. The function will skip patching if these conditions aren’t met.
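
For orientation, here is a minimal usage sketch. Only `apply_lora_kernel_patches(model, cfg)` itself is documented on this page; the import paths and the specific config keys shown below (e.g. `lora_mlp_kernel`) are assumptions for illustration, not confirmed API.

```python
# Hypothetical usage sketch; axolotl import paths and config key names are assumed.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
from axolotl.utils.dict import DictDefault  # assumed location of DictDefault

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# The patches require LoRA adapters with no dropout and no bias (see Note above).
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    bias="none",
)
model = get_peft_model(base, lora_config)  # -> PeftModelForCausalLM

cfg = DictDefault(
    {
        "base_model": "meta-llama/Llama-3.2-1B",
        "lora_mlp_kernel": True,  # assumed key name
        "lora_qkv_kernel": True,  # assumed key name
        "lora_o_kernel": True,    # assumed key name
    }
)

model = apply_lora_kernel_patches(model, cfg)
```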

get_attention_cls_from_config

monkeypatch.lora_kernels.get_attention_cls_from_config(cfg)

Get the appropriate attention class by inspecting the model config. Uses dynamic import to support any model architecture that follows the standard transformers naming convention.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |

Returns

| Type | Description |
|------|-------------|
| Type[nn.Module] | The appropriate attention class for the model. |

Raises

| Type | Description |
|------|-------------|
| ValueError | If base_model is not specified or the attention class cannot be imported. |
| ImportError | If the model module or attention class doesn't exist. |
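
A minimal calling sketch, with the `DictDefault` import path assumed and the resolved class shown only as an example for a Llama-style base model:

```python
from axolotl.monkeypatch.lora_kernels import get_attention_cls_from_config
from axolotl.utils.dict import DictDefault  # assumed location of DictDefault

cfg = DictDefault({"base_model": "meta-llama/Llama-3.2-1B"})
attention_cls = get_attention_cls_from_config(cfg)

# For a Llama-style model this would resolve to something like
# transformers.models.llama.modeling_llama.LlamaAttention.
print(attention_cls)
```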

original_apply_o

monkeypatch.lora_kernels.original_apply_o(self, hidden_states)

Original implementation of output projection without optimizations.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| self | nn.Module | The attention module instance. | required |
| hidden_states | torch.Tensor | Input tensor of shape [batch_size, seq_len, hidden_dim]. | required |

Returns

| Type | Description |
|------|-------------|
| torch.Tensor | The output projection result. |
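
Conceptually, the unoptimized output projection simply applies the attention module's output linear layer. The sketch below is a rough reference, assuming a Llama-style attention module that exposes an `o_proj` linear layer; it is not the module's actual source.

```python
import torch
from torch import nn


def reference_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
    """Rough stand-in for the unoptimized output projection (illustrative only)."""
    # Project the attention output back to the model's hidden dimension.
    return self.o_proj(hidden_states)
```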

original_apply_qkv

monkeypatch.lora_kernels.original_apply_qkv(self, hidden_states)

Original implementation of QKV projection without optimizations.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| self | nn.Module | The attention module instance. | required |
| hidden_states | torch.Tensor | Input tensor of shape [batch_size, seq_len, hidden_dim]. | required |

Returns

| Type | Description |
|------|-------------|
| tuple[torch.Tensor, torch.Tensor, torch.Tensor] | A tuple (query_states, key_states, value_states) containing the projected states for query, key, and value. |
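
As with `original_apply_o`, the unoptimized QKV projection amounts to applying the module's three projection layers. A rough reference sketch, assuming Llama-style `q_proj`/`k_proj`/`v_proj` attributes (illustrative only, not the module's actual source):

```python
import torch
from torch import nn


def reference_apply_qkv(
    self: nn.Module, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Rough stand-in for the unoptimized QKV projection (illustrative only)."""
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)
    return query_states, key_states, value_states
```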

patch_self_attn_lora

monkeypatch.lora_kernels.patch_self_attn_lora(cfg)

Given an axolotl config, this method patches the inferred attention class forward pass with optimized LoRA implementations.

It modifies the attention class to use optimized QKV and output projections. The original implementation is preserved and can be restored if needed.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| cfg | DictDefault | Dictionary mapping axolotl config keys to values. | required |

Raises

| Type | Description |
|------|-------------|
| AssertionError | If the required code blocks are not found in the attention implementation. |
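
A minimal usage sketch; the import paths are assumptions for illustration, and only `patch_self_attn_lora(cfg)` itself is documented here.

```python
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
from axolotl.utils.dict import DictDefault  # assumed location of DictDefault

cfg = DictDefault({"base_model": "meta-llama/Llama-3.2-1B"})

# Patches the forward pass of the attention class inferred from cfg
# (e.g. LlamaAttention for a Llama-style base model) in place.
patch_self_attn_lora(cfg)
```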