monkeypatch.stablelm_attn_hijack_flash

PyTorch StableLM Epoch model.

Functions

Name         Description
repeat_kv    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep); expands the key/value heads to the number of attention heads.
rotate_half  Rotates half the hidden dims of the input.

repeat_kv

monkeypatch.stablelm_attn_hijack_flash.repeat_kv(hidden_states, n_rep)

This is equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim), where num_attention_heads = num_key_value_heads * n_rep.
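A minimal sketch of what repeat_kv computes, assuming it follows the expand-then-reshape pattern used in Hugging Face's Llama-style attention code; the monkeypatch's own body may differ in details. Output head j is a copy of key/value head j // n_rep, which is exactly what torch.repeat_interleave along dim 1 produces.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seqlen, head_dim) to
    (batch, num_kv_heads * n_rep, seqlen, head_dim)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat: each attention head already has its own KV head.
        return hidden_states
    # Add a repeat axis after the head axis and broadcast it n_rep times
    # (expand returns a view without copying data).
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    # Merge the KV-head and repeat axes; reshape copies here because the
    # expanded view is not contiguous.
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# Usage: 2 key/value heads shared by 8 attention heads (n_rep = 4).
kv = torch.randn(1, 2, 5, 16)
out = repeat_kv(kv, n_rep=4)
print(out.shape)  # torch.Size([1, 8, 5, 16])
```

Using expand before reshape avoids materializing intermediate copies until the final reshape, which is why this pattern is common in grouped-query attention code.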

rotate_half

monkeypatch.stablelm_attn_hijack_flash.rotate_half(x)

Rotates half the hidden dims of the input.
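A minimal sketch, assuming rotate_half follows the common rotary-embedding convention: split the last dimension into halves (x1, x2) and return (-x2, x1). Combined as x * cos + rotate_half(x) * sin, this rotates each pair (x[i], x[i + d/2]) by an angle, which is how rotary position embeddings are typically applied. The names theta, cos and sin below are illustrative, not taken from the module.

```python
import math

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim into halves (x1, x2) and return cat(-x2, x1)."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x).tolist())  # [-3.0, -4.0, 1.0, 2.0]

# Illustrative rotary step on a 2-dim vector: rotate [1, 0] by 90 degrees.
theta = math.pi / 2
cos = torch.full((2,), math.cos(theta))
sin = torch.full((2,), math.sin(theta))
v = torch.tensor([1.0, 0.0])
rotated = v * cos + rotate_half(v) * sin  # approximately [0.0, 1.0]
```

Pairing dimension i with i + d/2 (rather than adjacent dimensions) is a layout convention; it gives the same rotation as long as cos and sin are laid out to match.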