monkeypatch.llama_attn_hijack_flash

Flash Attention monkey patch for the Llama model

Classes

| Name | Description |
| --- | --- |
| FusedAttention | Fused QKV attention layer for incrementally improved training efficiency |
| LlamaDecoderLayer | Patched version of LlamaDecoderLayer that passes through the precalculated cu_seqlens |

FusedAttention

monkeypatch.llama_attn_hijack_flash.FusedAttention(self, config, q, k, v, o)

Fused QKV Attention layer for incrementally improved training efficiency
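The layer fuses the separate query, key and value projections into a single matrix multiply. Below is a minimal, hypothetical sketch of that idea; the class name, construction and splitting logic are illustrative and not the exact code behind FusedAttention(self, config, q, k, v, o).

```python
# Minimal sketch of the fused-QKV idea, assuming q/k/v are the original
# nn.Linear projections. Names and details here are illustrative only.
import torch
import torch.nn as nn

class FusedQKVSketch(nn.Module):
    def __init__(self, q: nn.Linear, k: nn.Linear, v: nn.Linear):
        super().__init__()
        out_features = q.out_features + k.out_features + v.out_features
        # One weight matrix holds all three projections, so a single matmul
        # produces Q, K and V for the whole hidden state.
        self.qkv = nn.Linear(q.in_features, out_features, bias=False)
        with torch.no_grad():
            self.qkv.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
        self.split_sizes = (q.out_features, k.out_features, v.out_features)

    def forward(self, hidden_states: torch.Tensor):
        # Split the fused output back into the tensors downstream attention expects.
        return self.qkv(hidden_states).split(self.split_sizes, dim=-1)
```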

LlamaDecoderLayer

monkeypatch.llama_attn_hijack_flash.LlamaDecoderLayer()

Patched version of LlamaDecoderLayer that passes through the precalculated cu_seqlens

Methods

| Name | Description |
| --- | --- |
| forward | |
forward
monkeypatch.llama_attn_hijack_flash.LlamaDecoderLayer.forward(
    hidden_states,
    attention_mask=None,
    position_ids=None,
    past_key_value=None,
    output_attentions=False,
    use_cache=False,
    padding_mask=None,
    cu_seqlens=None,
    max_seqlen=None,
)
Parameters
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_states | torch.FloatTensor | Input to the layer of shape (batch, seq_len, embed_dim) | required |
| attention_mask | torch.FloatTensor, optional | Attention mask of size (batch, 1, tgt_len, src_len) where padding elements are indicated by very large negative values. | None |
| output_attentions | bool, optional | Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail. | False |
| use_cache | bool, optional | If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values). | False |
| past_key_value | Tuple(torch.FloatTensor), optional | Cached past key and value projection states | None |
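
The cu_seqlens and max_seqlen arguments describe packed-sample boundaries in the flash-attention varlen format. A small sketch of how such values are typically derived from per-sample lengths; the variable names are illustrative, and whether this module computes them exactly this way is not shown here.

```python
# Sketch: cumulative sequence lengths for a packed batch of three samples,
# in the layout flash-attn's varlen kernels expect.
import torch
import torch.nn.functional as F

sample_lengths = torch.tensor([5, 3, 8], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(sample_lengths, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(sample_lengths.max())

print(cu_seqlens)   # tensor([ 0,  5,  8, 16], dtype=torch.int32)
print(max_seqlen)   # 8
```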

Functions

| Name | Description |
| --- | --- |
| flashattn_forward | Input shape: Batch x Time x Channel |
| flashattn_forward_with_s2attn | Input shape: Batch x Time x Channel |
| generate_qkv | |

flashattn_forward

monkeypatch.llama_attn_hijack_flash.flashattn_forward(
    self,
    hidden_states,
    attention_mask=None,
    position_ids=None,
    past_key_value=None,
    output_attentions=False,
    use_cache=False,
    padding_mask=None,
    cu_seqlens=None,
    max_seqlen=None,
)

Input shape: Batch x Time x Channel

attention_mask: [bsz, q_len]
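
A replacement forward like this is typically installed by assigning it over the Hugging Face attention class. The sketch below shows that general pattern only; the repo may expose its own helper that performs the assignment, so treat this wiring as an assumption.

```python
# Sketch: monkey patching LlamaAttention.forward with the flash-attention
# forward. The exact hijack entry point in this repo may differ.
import transformers
from monkeypatch.llama_attn_hijack_flash import flashattn_forward

transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
```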

flashattn_forward_with_s2attn

monkeypatch.llama_attn_hijack_flash.flashattn_forward_with_s2attn(
    self,
    hidden_states,
    attention_mask=None,
    position_ids=None,
    past_key_value=None,
    output_attentions=False,
    use_cache=False,
    padding_mask=None,
    cu_seqlens=None,
    max_seqlen=None,
)

Input shape: Batch x Time x Channel

From: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py

attention_mask: [bsz, q_len]

cu_seqlens will be ignored if provided; max_seqlen will be ignored if provided.
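
S2-Attn (shifted sparse attention, from the LongLoRA work linked above) restricts attention to local groups and shifts half of the heads by half the group size so information still flows across group boundaries. A minimal sketch of the shifting step, with illustrative names and a made-up group size:

```python
# Sketch of the S2-Attn head shift: roll the second half of the heads along
# the sequence dimension by half a group. Shapes and group_size are
# illustrative, not the exact values used by this function.
import torch

def shift_half_heads(x: torch.Tensor, group_size: int) -> torch.Tensor:
    # x: (batch, seqlen, num_heads, head_dim)
    num_heads = x.shape[2]
    shifted = x.clone()
    shifted[:, :, num_heads // 2:] = torch.roll(
        x[:, :, num_heads // 2:], shifts=-group_size // 2, dims=1
    )
    return shifted

q = torch.randn(1, 64, 8, 16)
q_shifted = shift_half_heads(q, group_size=16)
```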

generate_qkv

monkeypatch.llama_attn_hijack_flash.generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
)

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | | (batch_size, seqlen_q, nheads, d) | required |
| k | | (batch_size, seqlen_k, nheads_k, d) | required |
| v | | (batch_size, seqlen_k, nheads_k, d) | required |
| query_padding_mask | | (batch_size, seqlen), bool | None |
| key_padding_mask | | (batch_size, seqlen), bool | None |
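
A hedged usage sketch with inputs of the documented shapes. The boolean masks mark real tokens; the helper's return values depend on the kvpacked/qkvpacked flags and are not spelled out here, so the call is left commented.

```python
# Sketch: padded q/k/v plus boolean padding masks in the documented shapes.
import torch

batch_size, seqlen_q, seqlen_k = 2, 16, 16
nheads, nheads_k, d = 8, 8, 64

q = torch.randn(batch_size, seqlen_q, nheads, d, dtype=torch.float16)
k = torch.randn(batch_size, seqlen_k, nheads_k, d, dtype=torch.float16)
v = torch.randn(batch_size, seqlen_k, nheads_k, d, dtype=torch.float16)

# True marks real tokens, False marks padding (second sample is shorter).
query_padding_mask = torch.arange(seqlen_q)[None, :] < torch.tensor([[16], [10]])
key_padding_mask = torch.arange(seqlen_k)[None, :] < torch.tensor([[16], [12]])

# outputs = generate_qkv(q, k, v, query_padding_mask, key_padding_mask)
```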