kernels.swiglu

Module for definition of SwiGLU Triton kernels.

See “GLU Variants Improve Transformer” (https://arxiv.org/abs/2002.05202).

Credit to unsloth (https://unsloth.ai/) for inspiring this implementation.

Functions

| Name | Description |
| --- | --- |
| swiglu_backward | SwiGLU backward pass using in-place operations. |
| swiglu_forward | SwiGLU forward pass. Computes SwiGLU activation: x * sigmoid(x) * up, where x is the gate tensor. |

swiglu_backward

kernels.swiglu.swiglu_backward(grad_output, gate, up)

SwiGLU backward pass using in-place operations.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| grad_output | torch.Tensor | Gradient of loss with respect to output, shape [batch, seq_len, hidden_dim]. | required |
| gate | torch.Tensor | Gate tensor from forward pass, shape [batch, seq_len, hidden_dim]. | required |
| up | torch.Tensor | Up-projection tensor from forward pass, shape [batch, seq_len, hidden_dim]. | required |

Returns

| Name | Type | Description |
| --- | --- | --- |
|  | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | Tuple containing the forward pass output (h), the gradient with respect to gate (df), and the gradient with respect to up-projection (de). |
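For reference, the gradients of h = silu(gate) * up can be sketched in eager-mode PyTorch. This is a minimal illustration of the math, not the in-place Triton kernel; the name `swiglu_backward_reference` is hypothetical, and recomputing h here simply mirrors the documented return tuple:

```python
import torch


def swiglu_backward_reference(grad_output, gate, up):
    """Reference gradients for h = silu(gate) * up, returning (h, df, de)."""
    sig = torch.sigmoid(gate)
    f = gate * sig  # silu(gate) = gate * sigmoid(gate)
    h = f * up      # forward pass output
    # d silu(x)/dx = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    df = grad_output * up * sig * (1 + gate * (1 - sig))  # grad w.r.t. gate
    de = grad_output * f                                  # grad w.r.t. up
    return h, df, de


gate = torch.randn(2, 8, 16, requires_grad=True)
up = torch.randn(2, 8, 16, requires_grad=True)
grad_output = torch.randn(2, 8, 16)
h, df, de = swiglu_backward_reference(grad_output, gate.detach(), up.detach())
```

Checking `df` and `de` against autograd on `torch.nn.functional.silu(gate) * up` is a quick way to validate a fused kernel against this reference.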

swiglu_forward

kernels.swiglu.swiglu_forward(gate, up)

SwiGLU forward pass. Computes SwiGLU activation: x * sigmoid(x) * up, where x is the gate tensor.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| gate | torch.Tensor | Input gate tensor of shape [batch, seq_len, hidden_dim]. | required |
| up | torch.Tensor | Up-projection tensor of shape [batch, seq_len, hidden_dim]. | required |

Returns

| Name | Type | Description |
| --- | --- | --- |
|  | torch.Tensor | Output tensor of shape [batch, seq_len, hidden_dim]. |
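The forward computation has a simple eager-mode equivalent in plain PyTorch. This is a sketch of the math the kernel fuses, not the Triton implementation; `swiglu_forward_reference` is a hypothetical name used for illustration:

```python
import torch
import torch.nn.functional as F


def swiglu_forward_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """Eager-mode equivalent of the fused kernel: silu(gate) * up."""
    # silu(x) = x * sigmoid(x), applied elementwise to the gate tensor,
    # then multiplied elementwise by the up-projection
    return F.silu(gate) * up


gate = torch.randn(2, 8, 16)
up = torch.randn(2, 8, 16)
out = swiglu_forward_reference(gate, up)
```

The advantage of a fused kernel over this reference is avoiding the intermediate silu(gate) tensor and the extra memory round-trips it implies.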