monkeypatch.utils

Shared utilities for the monkeypatches.

Functions

Name Description
get_cu_seqlens Generates a cumulative sequence length mask for flash attention from the attention mask.
get_cu_seqlens_from_pos_ids Generates a cumulative sequence length mask for flash attention from the position IDs.
mask_2d_to_4d Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].

get_cu_seqlens

monkeypatch.utils.get_cu_seqlens(attn_mask)

Generates a cumulative sequence length mask for flash attention from the attention mask.
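To illustrate the idea, here is a minimal, framework-free sketch of how cumulative sequence lengths could be derived from one row of a packed attention mask. It is not the library's implementation. It assumes that each packed sequence is marked with its own nonzero integer, that 0 marks padding, and that padding only trails the packed sequences. The helper name `cu_seqlens_from_mask` is hypothetical.

```python
from itertools import groupby


def cu_seqlens_from_mask(attn_mask_row):
    """Hypothetical helper: sequence boundaries for one packed mask row."""
    # Assumption: 0 marks padding, so those positions are dropped first.
    tokens = [value for value in attn_mask_row if value != 0]

    # Each run of identical values is one packed sequence; accumulate its length.
    cu_seqlens = [0]
    for _, run in groupby(tokens):
        cu_seqlens.append(cu_seqlens[-1] + len(list(run)))

    # Longest single sequence in the row (0 if the row is all padding).
    lengths = [end - start for start, end in zip(cu_seqlens, cu_seqlens[1:])]
    max_seqlen = max(lengths, default=0)
    return cu_seqlens, max_seqlen


# Three packed sequences of lengths 3, 2 and 4, followed by two padding tokens.
row = [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0]
cu_seqlens, max_seqlen = cu_seqlens_from_mask(row)
```

In this sketch, consecutive pairs of `cu_seqlens` give the start and end offset of each sequence, which is the layout variable-length attention kernels typically expect. How the real function treats padding and batching may differ.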

get_cu_seqlens_from_pos_ids

monkeypatch.utils.get_cu_seqlens_from_pos_ids(position_ids)

Generates a cumulative sequence length mask for flash attention from the position IDs.
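The same boundaries can be recovered from position IDs, because each packed sequence restarts its positions at 0. The sketch below is illustrative only and is not the library's implementation. It assumes a single row with no padding whose first position ID is 0, and the helper name `cu_seqlens_from_pos_ids` is hypothetical.

```python
def cu_seqlens_from_pos_ids(position_ids_row):
    """Hypothetical helper: sequence boundaries for one row of position IDs."""
    # Assumption: a position ID of 0 marks the first token of a packed sequence.
    starts = [i for i, pos in enumerate(position_ids_row) if pos == 0]

    # Close the final sequence with the total row length.
    cu_seqlens = starts + [len(position_ids_row)]

    # Longest single sequence in the row (0 for an empty row).
    lengths = [end - start for start, end in zip(cu_seqlens, cu_seqlens[1:])]
    max_seqlen = max(lengths, default=0)
    return cu_seqlens, max_seqlen


# Three packed sequences of lengths 3, 2 and 4.
pos_ids = [0, 1, 2, 0, 1, 0, 1, 2, 3]
cu_seqlens_pos, max_seqlen_pos = cu_seqlens_from_pos_ids(pos_ids)
```

Padded rows need extra care, since padding positions are often also filled with 0 and would otherwise look like sequence starts. This sketch deliberately leaves that case out.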

mask_2d_to_4d

monkeypatch.utils.mask_2d_to_4d(mask, dtype, tgt_len=None)

Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len]. The expansion handles packed sequences: tokens that belong to the same sequence share the same integer value in the attention mask, and a token can attend only to tokens that carry its own value. The result is also restricted to lower triangular form, so that no token can attend to later positions.
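The combined effect of the two rules (same sequence value, lower triangular) can be sketched for a single batch row without any tensor library. This is an illustration under stated assumptions, not the library's implementation: 0 is taken to mark padding, the output is a plain 2D list of 0/1 values rather than a tensor of the requested dtype, and the batch and head dimensions are omitted. The helper name `expand_packed_mask` is hypothetical.

```python
def expand_packed_mask(mask_row):
    """Hypothetical helper: [seq_len] packed mask -> [seq_len, seq_len] 0/1 grid."""
    n = len(mask_row)
    expanded = []
    for tgt in range(n):  # query (target) position
        row = []
        for src in range(n):  # key (source) position
            same_sequence = mask_row[tgt] == mask_row[src]
            not_padding = mask_row[src] != 0  # assumption: 0 marks padding
            not_future = src <= tgt  # lower triangular: no peeking ahead
            row.append(1 if same_sequence and not_padding and not_future else 0)
        expanded.append(row)
    return expanded


# Two packed sequences (lengths 2 and 3) followed by one padding token.
packed = [1, 1, 2, 2, 2, 0]
expanded = expand_packed_mask(packed)
```

The result is block diagonal, with one lower triangular block per packed sequence and all-zero rows and columns for padding. The real function would additionally add the batch and singleton head dimensions to reach [bsz, 1, tgt_seq_len, src_seq_len].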