monkeypatch.stablelm_attn_hijack_flash
PyTorch StableLM Epoch model.
Functions
Name | Description |
---|---|
repeat_kv | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). Expands the key/value heads to match the number of attention heads. |
rotate_half | Rotates half the hidden dims of the input. |
repeat_kv
monkeypatch.stablelm_attn_hijack_flash.repeat_kv(hidden_states, n_rep)
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim), where num_attention_heads = num_key_value_heads * n_rep.
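A minimal sketch of how this can be done, assuming the expand-then-reshape approach used by the equivalent helper in Hugging Face Llama-style code (this page does not show the module's own function body). The head counts in the usage lines are hypothetical. The result should match torch.repeat_interleave along dim 1, as the docstring states.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head n_rep times along dim 1 (illustrative sketch)."""
    batch, num_key_value_heads, seqlen, head_dim = hidden_states.shape
    if n_rep == 1:
        # One query head per key/value head: nothing to repeat.
        return hidden_states
    # Insert a new axis after the head axis and broadcast it to n_rep copies;
    # expand() returns a view, so no memory is copied at this step.
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, seqlen, head_dim
    )
    # Merge the head and repeat axes: copy r of head h lands at index h * n_rep + r,
    # which is the same ordering torch.repeat_interleave(..., dim=1) produces.
    return expanded.reshape(batch, num_key_value_heads * n_rep, seqlen, head_dim)


# Hypothetical sizes: 8 query heads sharing 2 key/value heads.
num_attention_heads, num_key_value_heads = 8, 2
n_rep = num_attention_heads // num_key_value_heads
k = torch.randn(1, num_key_value_heads, 5, 16)
k_full = repeat_kv(k, n_rep)
print(tuple(k_full.shape))  # (1, 8, 5, 16)
```

Because expand() only creates a broadcast view, the data is copied once, in reshape(), when the repeated heads are materialized.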
rotate_half
monkeypatch.stablelm_attn_hijack_flash.rotate_half(x)
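Rotates half the hidden dims of the input: the last dimension is split into two halves (x1, x2) and the result is their concatenation as (-x2, x1).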
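The sketch below assumes the common rotate-half convention used for rotary position embeddings (RoPE) in Hugging Face-style models; this page does not document how the module calls rotate_half. The apply_rotary helper, the base of 10000, and the tensor sizes are hypothetical and only illustrate how rotate_half combines with cos and sin so that each pair (x[i], x[i + head_dim/2]) is rotated by a position-dependent angle.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim into halves (x1, x2) and return cat(-x2, x1)."""
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: rotate each pair (x[i], x[i + d/2]) by its angle."""
    return x * cos + rotate_half(x) * sin


print(rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0])).tolist())  # [-3.0, -4.0, 1.0, 2.0]

# Illustrative RoPE angles for head_dim=8 and 5 positions.
head_dim, seqlen = 8, 5
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)  # (seqlen, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)  # same angle for x[i] and x[i + d/2]
q = torch.randn(1, 2, seqlen, head_dim)  # (batch, heads, seqlen, head_dim)
q_rot = apply_rotary(q, emb.cos(), emb.sin())
```

Applying rotate_half twice returns -x, which is why it behaves like a 90-degree rotation of each (x1, x2) pair; combined with cos and sin, the result is a pure rotation, so vector norms are preserved and position 0 (angle 0) is left unchanged.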