monkeypatch.utils
Shared utils for the monkeypatches
Functions
Name | Description |
---|---|
get_cu_seqlens | Generate a cumulative sequence length mask for flash attention from the attention mask. |
get_cu_seqlens_from_pos_ids | Generate a cumulative sequence length mask for flash attention from the position ids. |
mask_2d_to_4d | Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len]. |
get_cu_seqlens
monkeypatch.utils.get_cu_seqlens(attn_mask)
Generate a cumulative sequence length mask for flash attention from the attention mask.
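To make the idea of cumulative sequence lengths concrete, here is a minimal pure-Python sketch for a single packed row. It is not the library's implementation: the helper name `cu_seqlens_from_mask_row` is hypothetical, and it assumes that packed sequences are labelled 1, 2, 3, ... in the mask and that padding is marked with 0 and appears only at the end of the row.

```python
from itertools import groupby


def cu_seqlens_from_mask_row(mask_row):
    """Hypothetical helper: cumulative sequence boundaries for one packed row.

    Assumption: each packed sequence has its own non-zero integer id and
    padding (id 0) only appears at the end of the row.
    """
    cu_seqlens = [0]
    for seq_id, run in groupby(mask_row):
        if seq_id == 0:
            break  # trailing padding: stop counting
        cu_seqlens.append(cu_seqlens[-1] + len(list(run)))
    # Longest individual sequence, which flash attention kernels typically need too.
    max_seqlen = max(
        (end - start for start, end in zip(cu_seqlens, cu_seqlens[1:])),
        default=0,
    )
    return cu_seqlens, max_seqlen


# Three packed sequences of lengths 3, 2 and 4, followed by two padding tokens.
row = [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0]
cu_seqlens, max_seqlen = cu_seqlens_from_mask_row(row)
print(cu_seqlens, max_seqlen)  # [0, 3, 5, 9] 4
```

Each adjacent pair in the result marks the start and end offset of one sequence inside the packed row.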
get_cu_seqlens_from_pos_ids
monkeypatch.utils.get_cu_seqlens_from_pos_ids(position_ids)
Generate a cumulative sequence length mask for flash attention from the position ids.
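The same boundaries can be recovered from position ids, because each packed sequence restarts its positions at 0. The sketch below is illustrative only: `cu_seqlens_from_position_ids_row` is a hypothetical helper, and it assumes a single row that starts at position 0 and contains no padding.

```python
def cu_seqlens_from_position_ids_row(position_ids_row):
    """Hypothetical helper: cumulative sequence boundaries from position ids.

    Assumption: the row has no padding and every packed sequence restarts
    its position ids at 0, so each 0 marks the start of a new sequence.
    """
    starts = [i for i, pos in enumerate(position_ids_row) if pos == 0]
    # The final boundary is the total number of tokens in the row.
    return starts + [len(position_ids_row)]


# Three packed sequences of lengths 3, 2 and 4.
position_ids = [0, 1, 2, 0, 1, 0, 1, 2, 3]
print(cu_seqlens_from_position_ids_row(position_ids))  # [0, 3, 5, 9]
```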
mask_2d_to_4d
monkeypatch.utils.mask_2d_to_4d(mask, dtype, tgt_len=None)
Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].
This expansion handles packed sequences: tokens that belong to the same sequence share one integer value in the attention mask, and a token attends only to tokens carrying the same value.
The expansion also transforms the mask to lower triangular form to prevent future peeking.
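The two rules described above (same sequence id, no attending to later positions) can be sketched with plain nested lists. This is an illustration under stated assumptions, not the library's code: `packed_mask_2d_to_4d` is a hypothetical helper, it ignores the `dtype` and `tgt_len` parameters, it assumes the target length equals the source length, and it assumes 0 marks padding that attends to nothing.

```python
def packed_mask_2d_to_4d(mask):
    """Hypothetical helper: expand [bsz, seq_len] to [bsz, 1, seq_len, seq_len].

    Entry [b][0][i][j] is 1 when token i may attend to token j, i.e. when
    both carry the same non-zero sequence id and j is not after i.
    Assumption: id 0 is padding.
    """
    expanded = []
    for row in mask:
        n = len(row)
        grid = [
            [
                1 if (j <= i and row[i] == row[j] and row[i] != 0) else 0
                for j in range(n)
            ]
            for i in range(n)
        ]
        expanded.append([grid])  # singleton dimension in place of heads
    return expanded


# One row with two packed sequences of length 2 and one padding token.
mask_4d = packed_mask_2d_to_4d([[1, 1, 2, 2, 0]])
for line in mask_4d[0][0]:
    print(line)
# [1, 0, 0, 0, 0]
# [1, 1, 0, 0, 0]
# [0, 0, 1, 0, 0]
# [0, 0, 1, 1, 0]
# [0, 0, 0, 0, 0]
```

The result is block diagonal with one lower triangular block per packed sequence, so tokens never attend across sequence boundaries.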