kernels.lora
Module for definition of Low-Rank Adaptation (LoRA) Triton kernels.
See “LoRA: Low-Rank Adaptation of Large Language Models”
(https://arxiv.org/abs/2106.09685).
Credit to unsloth (https://unsloth.ai/) for the inspiration behind this implementation.
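For orientation, LoRA augments a frozen projection weight W with a learned low-rank update so that an input X is mapped to X @ W.T + s * (X @ A.T) @ B.T, where A and B are the low-rank factors and s is the scaling factor. The kernels below fuse and optimize variants of exactly this pattern. A minimal, unfused PyTorch sketch (purely illustrative, not part of this module's API):

    import torch

    def lora_linear_reference(X, W, A, B, s):
        # Frozen base projection plus scaled low-rank correction:
        #   X @ W.T + s * (X @ A.T) @ B.T
        out = X @ W.t()
        if A is not None and B is not None:
            out = out + (X @ A.t()) @ B.t() * s
        return out

Later sketches on this page reuse this hypothetical lora_linear_reference helper for readability.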
Classes
Name | Description
LoRA_MLP | Optimized LoRA MLP implementation.
LoRA_O | Optimized LoRA implementation for output projection.
LoRA_QKV | Optimized LoRA QKV implementation with quantization support.
LoRA_MLP
Optimized LoRA MLP implementation.
Methods
Name | Description
backward | Performs backward pass computation for LoRA MLP.
forward | Forward pass for LoRA MLP.
backward
kernels.lora.LoRA_MLP.backward(ctx, grad_output)
Performs backward pass computation for LoRA MLP.
Parameters
Name | Type | Description | Default
ctx | torch.autograd.function.FunctionCtx | Context object storing tensors saved during forward pass | required
grad_output | torch.Tensor | Gradient of loss with respect to layer output | required
Returns
Tuple containing gradients for all inputs from the forward pass:
- Input gradient tensor (torch.Tensor or None)
- None for weights/quantization states
- LoRA A/B matrix gradients (torch.Tensor or None)
- None for scaling factors
- None for activation functions and flags
forward
kernels.lora.LoRA_MLP.forward(
ctx,
X,
gate_weight,
gate_quant,
gate_A,
gate_B,
gate_scale,
up_weight,
up_quant,
up_A,
up_B,
up_scale,
down_weight,
down_quant,
down_A,
down_B,
down_scale,
activation_fn,
activation_fn_backward,
inplace=True,
)
Forward pass for LoRA MLP.
Parameters
Name | Type | Description | Default
ctx | torch.autograd.function.FunctionCtx | Autograd context | required
X | torch.Tensor | Input features | required
gate_weight | torch.Tensor | Gate projection weight | required
gate_quant | object | None | Gate quantization state | required
gate_A | torch.Tensor | None | Gate LoRA A matrix | required
gate_B | torch.Tensor | None | Gate LoRA B matrix | required
gate_scale | float | Gate LoRA scale | required
up_weight | torch.Tensor | Up-projection weight | required
up_quant | object | None | Up-projection quantization state | required
up_A | torch.Tensor | None | Up-projection LoRA A matrix | required
up_B | torch.Tensor | None | Up-projection LoRA B matrix | required
up_scale | float | Up-projection LoRA scale | required
down_weight | torch.Tensor | Down-projection weight | required
down_quant | object | None | Down-projection quantization state | required
down_A | torch.Tensor | None | Down-projection LoRA A matrix | required
down_B | torch.Tensor | None | Down-projection LoRA B matrix | required
down_scale | float | Down-projection LoRA scale | required
activation_fn | Callable | Forward activation function | required
activation_fn_backward | Callable | Backward activation function | required
inplace | bool | None | Whether to perform operations in-place | True
Returns
Type | Description
torch.Tensor | Output transformed by multi-layer perceptron and activation function
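For reference, the fused computation above is equivalent to the following unfused sketch (using the hypothetical lora_linear_reference helper from the module introduction); quantization handling and in-place buffer reuse are omitted:

    def lora_mlp_reference(X,
                           gate_weight, gate_A, gate_B, gate_scale,
                           up_weight, up_A, up_B, up_scale,
                           down_weight, down_A, down_B, down_scale,
                           activation_fn):
        # Gated MLP: h = activation(gate(X)) * up(X); output = down(h)
        gate = lora_linear_reference(X, gate_weight, gate_A, gate_B, gate_scale)
        up = lora_linear_reference(X, up_weight, up_A, up_B, up_scale)
        h = activation_fn(gate) * up
        return lora_linear_reference(h, down_weight, down_A, down_B, down_scale)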
LoRA_O
Optimized LoRA implementation for output projection.
Methods
Name | Description
backward | Backward pass computing gradients for LoRA output projection.
forward | Forward pass for output projection with LoRA.
backward
kernels.lora.LoRA_O.backward(ctx, dY)
Backward pass computing gradients for LoRA output projection.
Parameters
Name | Type | Description | Default
ctx | torch.autograd.function.FunctionCtx | Autograd context | required
dY | torch.Tensor | Gradient of loss with respect to output | required
Returns
Type | Description
tuple[torch.Tensor, None, None, torch.Tensor | None, torch.Tensor | None, None] | Tuple containing gradients for all forward inputs
forward
kernels.lora.LoRA_O.forward(ctx, X, W, W_quant, A, B, S)
Forward pass for output projection with LoRA.
Parameters
Name | Type | Description | Default
ctx | torch.autograd.function.FunctionCtx | Autograd context | required
X | torch.Tensor | Input tensor | required
W | torch.Tensor | Output projection weight | required
W_quant | QuantState | None | Weight quantization state | required
A | torch.Tensor | None | LoRA A matrix | required
B | torch.Tensor | None | LoRA B matrix | required
S | float | LoRA scaling factor | required
Returns
Type | Description
torch.Tensor | Output projection tensor
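A hedged usage sketch: assuming LoRA_O follows the standard torch.autograd.Function pattern, it is invoked through .apply, with ctx supplied automatically; the tensor names below are illustrative.

    # X: [*, hidden], W: [out_features, in_features] output-projection weight (frozen)
    # A: [rank, in_features], B: [out_features, rank], S: LoRA scaling factor
    Y = LoRA_O.apply(X, W, W_quant, A, B, S)
    # Numerically equivalent (unfused, unquantized) to: X @ W.t() + (X @ A.t()) @ B.t() * S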
LoRA_QKV
Optimized LoRA QKV implementation with quantization support.
Implements efficient computation of query, key, value projections with LoRA,
supporting quantization and memory optimization.
Methods
Name | Description
backward | Backward pass computing gradients for LoRA QKV.
forward | Forward pass computing Q, K, V projections with LoRA.
backward
kernels.lora.LoRA_QKV.backward(ctx, q_grad, k_grad, v_grad)
Backward pass computing gradients for LoRA QKV.
Parameters
Name | Type | Description | Default
ctx | torch.autograd.function.FunctionCtx | Autograd context | required
q_grad | torch.Tensor | Gradient for query projection | required
k_grad | torch.Tensor | Gradient for key projection | required
v_grad | torch.Tensor | Gradient for value projection | required
Returns
Type | Description
tuple[torch.Tensor, None, None, torch.Tensor | None, torch.Tensor | None, None, None, None, torch.Tensor | None, torch.Tensor | None, None, None, None, torch.Tensor | None, torch.Tensor | None, None, None] | Tuple containing gradients for all forward inputs
forward
kernels.lora.LoRA_QKV.forward(
ctx,
X,
q_weight,
q_quant,
q_A,
q_B,
q_scale,
k_weight,
k_quant,
k_A,
k_B,
k_scale,
v_weight,
v_quant,
v_A,
v_B,
v_scale,
inplace=True,
)
Forward pass computing Q, K, V projections with LoRA.
Parameters
Name | Type | Description | Default
ctx | torch.autograd.function.FunctionCtx | Autograd context | required
X | torch.Tensor | Input tensor | required
q_weight | torch.Tensor | Query projection weight | required
q_quant | QuantState | None | Query quantization state | required
q_A | torch.Tensor | None | Query LoRA A matrix | required
q_B | torch.Tensor | None | Query LoRA B matrix | required
q_scale | float | Query LoRA scale | required
k_weight | torch.Tensor | Key projection weight | required
k_quant | QuantState | None | Key quantization state | required
k_A | torch.Tensor | None | Key LoRA A matrix | required
k_B | torch.Tensor | None | Key LoRA B matrix | required
k_scale | float | Key LoRA scale | required
v_weight | torch.Tensor | Value projection weight | required
v_quant | QuantState | None | Value quantization state | required
v_A | torch.Tensor | None | Value LoRA A matrix | required
v_B | torch.Tensor | None | Value LoRA B matrix | required
v_scale | float | Value LoRA scale | required
inplace | bool | Whether to perform operations in-place | True
Returns
Type | Description
tuple[torch.Tensor, torch.Tensor, torch.Tensor] | Tuple of (Query, Key, Value) projection tensors
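Conceptually, the kernel reads X once and produces all three projections; an unfused equivalent (again using the hypothetical lora_linear_reference helper, with quantization omitted) is:

    def lora_qkv_reference(X,
                           q_weight, q_A, q_B, q_scale,
                           k_weight, k_A, k_B, k_scale,
                           v_weight, v_A, v_B, v_scale):
        # Each projection carries its own optional LoRA A/B pair and scale.
        Q = lora_linear_reference(X, q_weight, q_A, q_B, q_scale)
        K = lora_linear_reference(X, k_weight, k_A, k_B, k_scale)
        V = lora_linear_reference(X, v_weight, v_A, v_B, v_scale)
        return Q, K, V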
Functions
apply_lora_mlp_geglu
kernels.lora.apply_lora_mlp_geglu(self, X, inplace=True)
Applies LoRA to MLP layer with GEGLU activation.
Parameters
Name | Type | Description | Default
X | torch.Tensor | Input tensor for the MLP layer | required
inplace | bool | Whether to perform operations in-place to save memory | True
Returns
Type | Description
torch.Tensor | Output tensor after applying LoRA-adapted MLP with GEGLU activation
apply_lora_mlp_swiglu
kernels.lora.apply_lora_mlp_swiglu(self, X, inplace=True)
Applies LoRA to MLP layer with SwiGLU activation.
Parameters
Name | Type | Description | Default
X | torch.Tensor | Input tensor for the MLP layer | required
inplace | bool | Whether to perform operations in-place to save memory | True
Returns
Type | Description
torch.Tensor | Output tensor after applying LoRA-adapted MLP with SwiGLU activation
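The two MLP wrappers differ only in the gating nonlinearity applied inside LoRA_MLP: GEGLU gates with GELU, SwiGLU with SiLU (swish). A sketch of the two gating functions (illustrative only; the Triton kernels may use a specific GELU approximation):

    import torch.nn.functional as F

    def geglu_gate(gate_out, up_out):
        # GEGLU: GELU on the gate projection, elementwise product with the up projection
        return F.gelu(gate_out) * up_out

    def swiglu_gate(gate_out, up_out):
        # SwiGLU: SiLU (swish) on the gate projection
        return F.silu(gate_out) * up_out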
apply_lora_o
kernels.lora.apply_lora_o(self, X)
Applies LoRA to output projection layer.
Parameters
Name | Type | Description | Default
X | torch.Tensor | Input tensor | required
Returns
Type | Description
torch.Tensor | Transformed output tensor
apply_lora_qkv
kernels.lora.apply_lora_qkv(self, X, inplace=True)
Applies LoRA to compute Query, Key, Value projections.
Parameters
Name | Type | Description | Default
X | torch.Tensor | Input tensor | required
inplace | bool | Whether to perform operations in-place | True
Returns
Type | Description
tuple[torch.Tensor, torch.Tensor, torch.Tensor] | Tuple of (Query, Key, Value) projection tensors
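A hedged usage sketch: the apply_* helpers take self, so they are intended to be called with (or bound to) a module that holds the projection layers; the surrounding attention code and names below are assumptions for illustration only.

    # Inside a hypothetical attention forward:
    Q, K, V = apply_lora_qkv(self, hidden_states, inplace=True)
    # ... attention computation producing attn_out ...
    O = apply_lora_o(self, attn_out)  # output projection with its own LoRA pair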
get_lora_parameters
kernels.lora.get_lora_parameters(proj)
Gets LoRA parameters from a projection module.
Parameters
Name | Type | Description | Default
proj | nn.Module | The projection module to extract parameters from. | required
Returns
A tuple containing:
- Base weight matrix (torch.Tensor)
- Quantization state (QuantState or None)
- LoRA A matrix (torch.Tensor or None)
- LoRA B matrix (torch.Tensor or None)
- Scaling factor (float)
States and matrices may be None if not available.
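A hedged sketch of how the returned tuple is typically consumed, passing the pieces straight to matmul_lora (documented below); the module attribute q_proj is illustrative, not prescribed by this API.

    W, W_quant, A, B, s = get_lora_parameters(module.q_proj)  # hypothetical module layout
    Q = matmul_lora(X, W, W_quant, A, B, s)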
matmul_lora
kernels.lora.matmul_lora(X, W, W_quant, A, B, s, out=None)
Efficient fused matmul + LoRA computation.
Parameters
Name | Type | Description | Default
X | torch.Tensor | Input tensor [*, in_features] | required
W | torch.Tensor | Base weight matrix [out_features, in_features] | required
W_quant | QuantState | Quantization state for W | required
A | torch.Tensor | LoRA A matrix [rank, in_features] | required
B | torch.Tensor | LoRA B matrix [out_features, rank] | required
s | float | LoRA scaling factor | required
out | torch.Tensor | None | Optional output tensor for in-place operations | None
Returns
Type | Description
torch.Tensor | Result of X @ W.T + (X @ A.T) @ B.T * s
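A small sanity-check sketch under the documented shapes. It assumes an unquantized weight can be passed with W_quant=None (an assumption, not stated on this page) and compares against the unfused equivalent; tolerances may need loosening for low-precision kernels.

    import torch
    from kernels.lora import matmul_lora

    batch, seq, d_in, d_out, rank = 2, 16, 512, 1024, 8
    X = torch.randn(batch, seq, d_in)
    W = torch.randn(d_out, d_in)
    A = torch.randn(rank, d_in)
    B = torch.randn(d_out, rank)
    s = 0.5

    Y = matmul_lora(X, W, None, A, B, s)  # W_quant=None assumed to mean "no quantization"
    expected = X @ W.t() + (X @ A.t()) @ B.t() * s  # unfused reference
    assert Y.shape == (batch, seq, d_out)
    assert torch.allclose(Y, expected, atol=1e-4)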