monkeypatch.mistral_attn_hijack_flash

Flash attention monkey patch for the Mistral model.
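A minimal sketch of how the patch might be applied. The entry-point name `replace_mistral_attn_with_flash_attn` is an assumption by analogy with the sibling llama patch; check the module for the actual function, and note that the installed package may prefix the import path (e.g. `axolotl.monkeypatch...`).

```python
from transformers import AutoModelForCausalLM

# NOTE: assumed entry-point name; verify against the module source.
from monkeypatch.mistral_attn_hijack_flash import (
    replace_mistral_attn_with_flash_attn,
)

# Apply the patch before instantiating the model, so that the patched
# MistralDecoderLayer.forward is the one bound to the model's layers.
replace_mistral_attn_with_flash_attn()

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
```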

Classes

| Name | Description |
|------|-------------|
| MistralDecoderLayer | Patched version of MistralDecoderLayer that passes through the precalculated cu_seqlens. |

MistralDecoderLayer

monkeypatch.mistral_attn_hijack_flash.MistralDecoderLayer()

Patched version of MistralDecoderLayer that passes through the precalculated cu_seqlens.

Methods

| Name | Description |
|------|-------------|
| forward | Forward pass of the patched decoder layer. |

forward

monkeypatch.mistral_attn_hijack_flash.MistralDecoderLayer.forward(
    hidden_states,
    attention_mask=None,
    position_ids=None,
    past_key_value=None,
    output_attentions=False,
    use_cache=False,
    cu_seqlens=None,
    max_seqlen=None,
)
Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| hidden_states | torch.FloatTensor | Input to the layer, of shape (batch, seq_len, embed_dim). | required |
| attention_mask | torch.FloatTensor, optional | Attention mask of size (batch, 1, tgt_len, src_len), where padding elements are indicated by very large negative values. | None |
| past_key_value | Tuple(torch.FloatTensor), optional | Cached past key and value projection states. | None |
| output_attentions | bool, optional | Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail. | False |
| use_cache | bool, optional | If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values). | False |
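The cu_seqlens and max_seqlen arguments are not described above. In flash-attn's variable-length kernels, cu_seqlens holds the cumulative boundaries of the sequences packed into one batch and max_seqlen is the longest individual sequence. A minimal sketch of producing them, assuming the standard flash-attn varlen convention:

```python
import torch
import torch.nn.functional as F

# Hypothetical packed batch: three sequences of lengths 5, 3, and 8
# concatenated along the sequence dimension (total length 16).
seqlens = torch.tensor([5, 3, 8], dtype=torch.int32)

# The varlen kernels expect int32 cumulative boundaries with a
# leading zero: [0, 5, 8, 16].
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())

# The patched layer passes these through to the attention call, e.g.:
# layer(hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
```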

Functions

| Name | Description |
|------|-------------|
| generate_qkv | Builds unpadded (and optionally packed) q, k, v inputs for variable-length flash attention. |

generate_qkv

monkeypatch.mistral_attn_hijack_flash.generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
)

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| q | (batch_size, seqlen_q, nheads, d) | Query tensor. | required |
| k | (batch_size, seqlen_k, nheads_k, d) | Key tensor. | required |
| v | (batch_size, seqlen_k, nheads_k, d) | Value tensor. | required |
| query_padding_mask | (batch_size, seqlen), bool | Boolean mask marking valid query positions. | None |
| key_padding_mask | (batch_size, seqlen), bool | Boolean mask marking valid key positions. | None |
| kvpacked | bool | If True, pack k and v into a single tensor. | False |
| qkvpacked | bool | If True, pack q, k, and v into a single tensor. | False |
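A minimal usage sketch, assuming the helper mirrors the flash-attn test utility of the same name: it unpads q/k/v according to the padding masks and returns them together with cu_seqlens/max_seqlen metadata. The return tuple layout is not documented here, so the sketch does not unpack it.

```python
import torch

# Import path mirrors this page's module name; the installed package may
# prefix it (e.g. axolotl.monkeypatch...).
from monkeypatch.mistral_attn_hijack_flash import generate_qkv

batch_size, seqlen, nheads, d = 2, 16, 8, 64
q = torch.randn(batch_size, seqlen, nheads, d, dtype=torch.float16, device="cuda")
k = torch.randn(batch_size, seqlen, nheads, d, dtype=torch.float16, device="cuda")
v = torch.randn(batch_size, seqlen, nheads, d, dtype=torch.float16, device="cuda")

# Boolean masks: True marks real tokens, False marks padding (assumed
# convention, matching flash-attn's test helpers).
query_padding_mask = torch.ones(batch_size, seqlen, dtype=torch.bool, device="cuda")
key_padding_mask = torch.ones(batch_size, seqlen, dtype=torch.bool, device="cuda")
query_padding_mask[1, -4:] = False  # second sequence ends with 4 padded positions

# Returns unpadded tensors plus varlen metadata; the exact layout depends
# on kvpacked/qkvpacked, so consult the function body before unpacking.
outputs = generate_qkv(q, k, v, query_padding_mask, key_padding_mask)
```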