monkeypatch.mistral_attn_hijack_flash
Flash attention monkey patch for the Mistral model
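The patch is applied at runtime by swapping the stock transformers Mistral attention implementation for a flash-attention-aware one. Below is a minimal usage sketch, assuming the module exposes a `replace_mistral_attn_with_flash_attn()` entry point; that name is an assumption (it is not listed on this page), so verify it against the module source:

```python
# Hypothetical usage sketch: the entry-point name is an assumption,
# not confirmed by this page.
from transformers import AutoModelForCausalLM

from monkeypatch.mistral_attn_hijack_flash import (
    replace_mistral_attn_with_flash_attn,  # assumed entry point
)

# Apply the monkey patch before instantiating the model, so the patched
# attention / decoder-layer classes are picked up at build time.
replace_mistral_attn_with_flash_attn()

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype="auto",
)
```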
Classes
| Name | Description |
| --- | --- |
| MistralDecoderLayer | Patched version of MistralDecoderLayer that passes through the precalculated cu_seqlens |
MistralDecoderLayer
monkeypatch.mistral_attn_hijack_flash.MistralDecoderLayer()
Patched version of MistralDecoderLayer that passes through the precalculated cu_seqlens.
Methods
forward
monkeypatch.mistral_attn_hijack_flash.MistralDecoderLayer.forward(
hidden_states,
attention_mask=None,
position_ids=None,
past_key_value=None,
output_attentions=False,
use_cache=False,
cu_seqlens=None,
max_seqlen=None,
)
Parameters
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_states | torch.FloatTensor | Input to the layer, of shape (batch, seq_len, embed_dim). | required |
| attention_mask | torch.FloatTensor, optional | Attention mask of size (batch, 1, tgt_len, src_len), where padding elements are indicated by very large negative values. | None |
| output_attentions | bool, optional | Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail. | False |
| use_cache | bool, optional | If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values). | False |
| past_key_value | Tuple(torch.FloatTensor), optional | Cached past key and value projection states. | None |
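The extra cu_seqlens and max_seqlen arguments follow the flash-attn varlen convention for packed batches: cu_seqlens holds the cumulative sequence lengths marking where each packed sequence starts and ends, and max_seqlen is the longest individual sequence in the pack. A minimal sketch of how they can be derived (the sequence lengths are made-up example data):

```python
import torch
import torch.nn.functional as F

# Assumed example: three sequences of lengths 5, 3, and 8 packed into one batch.
seqlens = torch.tensor([5, 3, 8], dtype=torch.int32)

# cu_seqlens is the zero-prefixed cumulative sum of the sequence lengths,
# i.e. [0, 5, 8, 16]; flash-attn's varlen kernels use it to locate sequence
# boundaries inside the packed token stream.
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

# max_seqlen is the length of the longest individual sequence in the pack.
max_seqlen = int(seqlens.max())

print(cu_seqlens)  # tensor([ 0,  5,  8, 16], dtype=torch.int32)
print(max_seqlen)  # 8
```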
Functions
generate_qkv
monkeypatch.mistral_attn_hijack_flash.generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
)
Parameters
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | | (batch_size, seqlen_q, nheads, d) | required |
| k | | (batch_size, seqlen_k, nheads_k, d) | required |
| v | | (batch_size, seqlen_k, nheads_k, d) | required |
| query_padding_mask | | (batch_size, seqlen), bool | None |
| key_padding_mask | | (batch_size, seqlen), bool | None |
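A usage sketch for generate_qkv, assuming it mirrors the flash-attn test helper of the same name: it takes padded (batch_size, seqlen, nheads, d) tensors plus boolean padding masks (True for real tokens) and unpads them into the varlen layout with the matching cu_seqlens/max_seqlen metadata. The mask data below is illustrative, and the exact shape of the returned tuple should be checked against the source:

```python
import torch

from monkeypatch.mistral_attn_hijack_flash import generate_qkv

batch_size, seqlen, nheads, d = 2, 16, 8, 64

# Padded attention inputs in the (batch_size, seqlen, nheads, d) layout.
q = torch.randn(batch_size, seqlen, nheads, d, dtype=torch.float16)
k = torch.randn(batch_size, seqlen, nheads, d, dtype=torch.float16)
v = torch.randn(batch_size, seqlen, nheads, d, dtype=torch.float16)

# Boolean padding masks: True marks real tokens, False marks padding.
# Here the second sequence is padded to half length (made-up example data).
padding_mask = torch.ones(batch_size, seqlen, dtype=torch.bool)
padding_mask[1, seqlen // 2 :] = False

# With kvpacked/qkvpacked left at False, q, k, and v are unpadded separately.
outputs = generate_qkv(
    q,
    k,
    v,
    query_padding_mask=padding_mask,
    key_padding_mask=padding_mask,
)
```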