kernels.lora

Module defining Low-Rank Adaptation (LoRA) Triton kernels.

See “LoRA: Low-Rank Adaptation of Large Language Models” (https://arxiv.org/abs/2106.09685).

Credit to unsloth (https://unsloth.ai/), whose work inspired this implementation.
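
Throughout this page, each LoRA-adapted projection conceptually computes Y = X @ W.T + s * (X @ A.T) @ B.T, where W is the frozen base weight [out_features, in_features], A [rank, in_features] and B [out_features, rank] are the trainable low-rank factors, and s is the LoRA scaling factor (alpha / rank in the paper). The kernels below fuse these matrix multiplications and additionally accept quantization states so that the base weights can be stored in quantized form.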

Classes

Name Description
LoRA_MLP Optimized LoRA MLP implementation.
LoRA_O Optimized LoRA implementation for output projection.
LoRA_QKV Optimized LoRA QKV implementation with quantization support.

LoRA_MLP

kernels.lora.LoRA_MLP()

Optimized LoRA MLP implementation.

Methods

Name Description
backward Performs backward pass computation for LoRA MLP.
forward Forward pass for LoRA MLP.
backward
kernels.lora.LoRA_MLP.backward(ctx, grad_output)

Performs backward pass computation for LoRA MLP.

Parameters
Name Type Description Default
ctx torch.autograd.function.FunctionCtx Context object storing tensors saved during forward pass required
grad_output torch.Tensor Gradient of loss with respect to layer output required
Returns
Name Type Description
tuple[torch.Tensor | None, ...] Tuple containing gradients for all inputs from the forward pass:
torch.Tensor | None - gradient with respect to the input X (or None)
None - for base weights and quantization states
torch.Tensor | None - gradients for the LoRA A/B matrices (or None)
None - for the LoRA scaling factors
None - for the activation functions and the inplace flag
forward
kernels.lora.LoRA_MLP.forward(
    ctx,
    X,
    gate_weight,
    gate_quant,
    gate_A,
    gate_B,
    gate_scale,
    up_weight,
    up_quant,
    up_A,
    up_B,
    up_scale,
    down_weight,
    down_quant,
    down_A,
    down_B,
    down_scale,
    activation_fn,
    activation_fn_backward,
    inplace=True,
)

Forward pass for LoRA MLP.

Parameters
Name Type Description Default
ctx torch.autograd.function.FunctionCtx Autograd context required
X torch.Tensor Input features required
gate_weight torch.Tensor Gate projection weight required
gate_quant object | None Gate quantization state required
gate_A torch.Tensor | None Gate LoRA A matrix required
gate_B torch.Tensor | None Gate LoRA B matrix required
gate_scale float Gate LoRA scale required
up_weight torch.Tensor Up-projection weight required
up_quant object | None Up-projection quantization state required
up_A torch.Tensor | None Up-projection LoRA A matrix required
up_B torch.Tensor | None Up-projection LoRA B matrix required
up_scale float Up-projection LoRA scale required
down_weight torch.Tensor Down-projection weight required
down_quant object | None Down-projection quantization state required
down_A torch.Tensor | None Down-projection LoRA A matrix required
down_B torch.Tensor | None Down-projection LoRA B matrix required
down_scale float Down-projection LoRA scale required
activation_fn Callable Forward activation function required
activation_fn_backward Callable Backward activation function required
inplace bool | None Whether to perform operations in-place True
Returns
Name Type Description
torch.Tensor Output transformed by multi-layer perceptron and activation function
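
For orientation, a minimal non-fused sketch of what this forward pass computes, written in terms of the matmul_lora helper documented further down this page. This is illustrative only: the real kernel fuses the projections, and the exact calling convention of activation_fn may differ from the plain element-wise form assumed here.

```python
from kernels.lora import matmul_lora  # fused matmul + LoRA, documented below

def lora_mlp_reference(X, gate, up, down, activation_fn):
    # gate, up, down are (weight, quant_state, A, B, scale) tuples, mirroring
    # the forward() arguments above; quant_state, A and B may be None.
    g = matmul_lora(X, *gate)     # gate projection
    u = matmul_lora(X, *up)       # up projection
    h = activation_fn(g) * u      # gated activation (SwiGLU / GEGLU form)
    return matmul_lora(h, *down)  # down projection
```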

LoRA_O

kernels.lora.LoRA_O()

Optimized LoRA implementation for output projection.

Methods

Name Description
backward Backward pass computing gradients for LoRA output projection.
forward Forward pass for output projection with LoRA.
backward
kernels.lora.LoRA_O.backward(ctx, dY)

Backward pass computing gradients for LoRA output projection.

Parameters
Name Type Description Default
ctx torch.autograd.function.FunctionCtx Autograd context required
dY torch.Tensor Gradient of loss with respect to output required
Returns
Name Type Description
tuple[torch.Tensor, None, None, torch.Tensor | None, torch.Tensor | None, None] Tuple containing gradients for all forward inputs
forward
kernels.lora.LoRA_O.forward(ctx, X, W, W_quant, A, B, S)

Forward pass for output projection with LoRA.

Parameters
Name Type Description Default
ctx torch.autograd.function.FunctionCtx Autograd context required
X torch.Tensor Input tensor required
W torch.Tensor Output projection weight required
W_quant QuantState | None Weight quantization state required
A torch.Tensor | None LoRA A matrix required
B torch.Tensor | None LoRA B matrix required
S float LoRA scaling factor required
Returns
Name Type Description
torch.Tensor Output projection tensor
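
Since forward and backward take a ctx argument, LoRA_O behaves like a torch.autograd.Function and would normally be invoked through .apply (or via the apply_lora_o helper below) rather than by calling forward directly. A hedged usage sketch; X and o_proj are illustrative placeholders for an input tensor and a LoRA-wrapped output-projection module:

```python
from kernels.lora import LoRA_O, get_lora_parameters

# o_proj: an output-projection module carrying optional LoRA adapters.
W, W_quant, A, B, s = get_lora_parameters(o_proj)
Y = LoRA_O.apply(X, W, W_quant, A, B, s)  # differentiable output projection
```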

LoRA_QKV

kernels.lora.LoRA_QKV()

Optimized LoRA QKV implementation with quantization support.

Implements efficient computation of query, key, value projections with LoRA, supporting quantization and memory optimization.

Methods

Name Description
backward Backward pass computing gradients for LoRA QKV.
forward Forward pass computing Q, K, V projections with LoRA.
backward
kernels.lora.LoRA_QKV.backward(ctx, q_grad, k_grad, v_grad)

Backward pass computing gradients for LoRA QKV.

Parameters
Name Type Description Default
ctx torch.autograd.function.FunctionCtx Autograd context required
q_grad torch.Tensor Gradient for query projection required
k_grad torch.Tensor Gradient for key projection required
v_grad torch.Tensor Gradient for value projection required
Returns
Name Type Description
tuple[torch.Tensor, None, None, torch.Tensor | None, torch.Tensor | None, None, None, None, torch.Tensor | None, torch.Tensor | None, None, None, None, torch.Tensor | None, torch.Tensor | None, None, None] Tuple containing gradients for all forward inputs
forward
kernels.lora.LoRA_QKV.forward(
    ctx,
    X,
    q_weight,
    q_quant,
    q_A,
    q_B,
    q_scale,
    k_weight,
    k_quant,
    k_A,
    k_B,
    k_scale,
    v_weight,
    v_quant,
    v_A,
    v_B,
    v_scale,
    inplace=True,
)

Forward pass computing Q, K, V projections with LoRA.

Parameters
Name Type Description Default
ctx torch.autograd.function.FunctionCtx Autograd context required
X torch.Tensor Input tensor required
q_weight torch.Tensor Query projection weight required
q_quant QuantState | None Query quantization state required
q_A torch.Tensor | None Query LoRA A matrix required
q_B torch.Tensor | None Query LoRA B matrix required
q_scale float Query LoRA scale required
k_weight torch.Tensor Key projection weight required
k_quant QuantState | None Key quantization state required
k_A torch.Tensor | None Key LoRA A matrix required
k_B torch.Tensor | None Key LoRA B matrix required
k_scale float Key LoRA scale required
v_weight torch.Tensor Value projection weight required
v_quant QuantState | None Value quantization state required
v_A torch.Tensor | None Value LoRA A matrix required
v_B torch.Tensor | None Value LoRA B matrix required
v_scale float Value LoRA scale required
inplace bool Whether to perform operations in-place True
Returns
Name Type Description
tuple[torch.Tensor, torch.Tensor, torch.Tensor] Tuple of (Query, Key, Value) projection tensors
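
As with the other classes on this page, invocation would normally go through .apply (or the apply_lora_qkv helper below), passing the query, key and value parameter groups in the order shown above. A hedged sketch; X and the projection modules are illustrative placeholders:

```python
from kernels.lora import LoRA_QKV, get_lora_parameters

Q, K, V = LoRA_QKV.apply(
    X,
    *get_lora_parameters(q_proj),  # q_weight, q_quant, q_A, q_B, q_scale
    *get_lora_parameters(k_proj),  # k_weight, k_quant, k_A, k_B, k_scale
    *get_lora_parameters(v_proj),  # v_weight, v_quant, v_A, v_B, v_scale
    True,                          # inplace
)
```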

Functions

Name Description
apply_lora_mlp_geglu Applies LoRA to MLP layer with GEGLU activation.
apply_lora_mlp_swiglu Applies LoRA to MLP layer with SwiGLU activation.
apply_lora_o Applies LoRA to output projection layer.
apply_lora_qkv Applies LoRA to compute Query, Key, Value projections.
get_lora_parameters Gets LoRA parameters from a projection module.
matmul_lora Efficient fused matmul + LoRA computation.

apply_lora_mlp_geglu

kernels.lora.apply_lora_mlp_geglu(self, X, inplace=True)

Applies LoRA to MLP layer with GEGLU activation.

Parameters

Name Type Description Default
X torch.Tensor Input tensor for the MLP layer required
inplace bool Whether to perform operations in-place to save memory True

Returns

Name Type Description
torch.Tensor Output tensor after applying LoRA-adapted MLP with GEGLU activation
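
Note that this helper (like apply_lora_mlp_swiglu, apply_lora_o and apply_lora_qkv below) takes self as its first parameter, so it is meant to be called with, or bound to, a module that exposes the corresponding projections. A hedged sketch with illustrative attribute paths; model and X are placeholders:

```python
import types
from kernels.lora import apply_lora_mlp_geglu

mlp = model.model.layers[0].mlp  # a GEGLU gate/up/down MLP module (illustrative path)

# Call it directly with the module as `self` ...
out = apply_lora_mlp_geglu(mlp, X, inplace=True)

# ... or bind it as the module's forward so mlp(X) uses the fused LoRA path.
mlp.forward = types.MethodType(apply_lora_mlp_geglu, mlp)
out = mlp(X)
```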

apply_lora_mlp_swiglu

kernels.lora.apply_lora_mlp_swiglu(self, X, inplace=True)

Applies LoRA to MLP layer with SwiGLU activation.

Parameters

Name Type Description Default
X torch.Tensor Input tensor for the MLP layer required
inplace bool Whether to perform operations in-place to save memory True

Returns

Name Type Description
torch.Tensor Output tensor after applying LoRA-adapted MLP with SwiGLU activation

apply_lora_o

kernels.lora.apply_lora_o(self, X)

Applies LoRA to output projection layer.

Parameters

Name Type Description Default
X torch.Tensor Input tensor required

Returns

Name Type Description
torch.Tensor Transformed output tensor

apply_lora_qkv

kernels.lora.apply_lora_qkv(self, X, inplace=True)

Applies LoRA to compute Query, Key, Value projections.

Parameters

Name Type Description Default
X torch.Tensor Input tensor required
inplace bool Whether to perform operations in-place True

Returns

Name Type Description
tuple[torch.Tensor, torch.Tensor, torch.Tensor] Tuple of (Query, Key, Value) projection tensors

get_lora_parameters

kernels.lora.get_lora_parameters(proj)

Gets LoRA parameters from a projection module.

Parameters

Name Type Description Default
proj nn.Module The projection module to extract parameters from. required

Returns

Name Type Description
tuple[torch.Tensor, QuantState | None, torch.Tensor | None, torch.Tensor | None, float] A tuple containing the base weight matrix, quantization state, LoRA A matrix, LoRA B matrix, and scaling factor. States and matrices may be None if not available.
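
A hedged unpacking example; the projection module path is illustrative:

```python
from kernels.lora import get_lora_parameters

# Returned in this order: base weight, quantization state, LoRA A, LoRA B, scale.
W, W_quant, A, B, s = get_lora_parameters(layer.self_attn.q_proj)  # illustrative module
```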

matmul_lora

kernels.lora.matmul_lora(X, W, W_quant, A, B, s, out=None)

Efficient fused matmul + LoRA computation.

Parameters

Name Type Description Default
X torch.Tensor Input tensor [*, in_features] required
W torch.Tensor Base weight matrix [out_features, in_features] required
W_quant QuantState | None Quantization state for W required
A torch.Tensor | None LoRA A matrix [rank, in_features] required
B torch.Tensor | None LoRA B matrix [out_features, rank] required
s float LoRA scaling factor required
out torch.Tensor | None Optional output tensor for inplace operations None

Returns

Name Type Description
torch.Tensor Result of X @ W.T + s * (X @ A.T) @ B.T
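
A minimal pure-PyTorch sketch of the computation, assuming W is held as a plain tensor (the W_quant dequantization path, the optional out= tensor, and the Triton fusion are omitted):

```python
import torch

def matmul_lora_reference(X, W, A, B, s):
    # X: [*, in_features]; W: [out_features, in_features]
    Y = X @ W.T                        # base projection -> [*, out_features]
    if A is not None and B is not None:
        Y = Y + s * ((X @ A.T) @ B.T)  # scaled low-rank LoRA update
    return Y
```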