kernels.swiglu
Module defining the SwiGLU Triton kernels.
See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to unsloth (https://unsloth.ai/) for inspiring this implementation.
Functions
| Name | Description |
| --- | --- |
| swiglu_backward | SwiGLU backward pass using in-place operations. |
| swiglu_forward | SwiGLU forward pass. Computes the SwiGLU activation `x * sigmoid(x) * up`, where `x` is the gate tensor. |
swiglu_backward
kernels.swiglu.swiglu_backward(grad_output, gate, up)
SwiGLU backward pass using in-place operations.
Parameters
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| grad_output | torch.Tensor | Gradient of the loss with respect to the output, shape `[batch, seq_len, hidden_dim]`. | required |
| gate | torch.Tensor | Gate tensor from the forward pass, shape `[batch, seq_len, hidden_dim]`. | required |
| up | torch.Tensor | Up-projection tensor from the forward pass, shape `[batch, seq_len, hidden_dim]`. | required |
Returns
| Name | Type | Description |
| --- | --- | --- |
| | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | Tuple containing: the forward-pass output (`h`), the gradient with respect to the gate (`df`), and the gradient with respect to the up-projection (`de`). |
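To make the returned triple concrete, here is a minimal NumPy sketch of the math this kernel computes (a reference for the gradients, not the Triton implementation itself; the function name `swiglu_backward_ref` is hypothetical). With `h = silu(gate) * up`, the chain rule gives `de = grad_output * silu(gate)` and `df = grad_output * up * silu'(gate)`, where `silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))`:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swiglu_backward_ref(grad_output, gate, up):
    """NumPy reference for the SwiGLU backward math (not the Triton kernel)."""
    s = sigmoid(gate)
    f = gate * s                                   # silu(gate)
    h = f * up                                     # forward-pass output
    df = grad_output * up * (s + gate * s * (1 - s))  # grad w.r.t. gate
    de = grad_output * f                           # grad w.r.t. up-projection
    return h, df, de
```

The derivative term `s + gate * s * (1 - s)` is just `silu'(gate)` expanded; at `gate = 0` it equals `0.5`, matching `sigmoid(0)`.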
swiglu_forward
kernels.swiglu.swiglu_forward(gate, up)
SwiGLU forward pass. Computes the SwiGLU activation `x * sigmoid(x) * up`, where `x` is the gate tensor.
Parameters
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| gate | torch.Tensor | Input gate tensor of shape `[batch, seq_len, hidden_dim]`. | required |
| up | torch.Tensor | Up-projection tensor of shape `[batch, seq_len, hidden_dim]`. | required |
Returns
| Name | Type | Description |
| --- | --- | --- |
| | torch.Tensor | Output tensor of shape `[batch, seq_len, hidden_dim]`. |
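For reference, the elementwise computation this kernel fuses can be sketched in plain NumPy (a sketch of the math only; the name `swiglu_forward_ref` is hypothetical and this is not the Triton kernel):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swiglu_forward_ref(gate, up):
    """NumPy reference for the SwiGLU forward math: silu(gate) * up."""
    return gate * sigmoid(gate) * up  # x * sigmoid(x) * up, x = gate
```

The Triton kernel computes the same values but fuses the sigmoid, multiply, and gating into a single pass over the `[batch, seq_len, hidden_dim]` tensors, avoiding intermediate materialization.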