monkeypatch.transformers_fa_utils
see https://github.com/huggingface/transformers/pull/35834
Functions
| Name | Description |
|---|---|
| fixed_fa_peft_integration_check | Casts Flash Attention inputs back to half precision when PEFT has upcast them to float32. |
fixed_fa_peft_integration_check
monkeypatch.transformers_fa_utils.fixed_fa_peft_integration_check(
    query,
    key,
    value,
    target_dtype=None,
    preferred_dtype=None,
)
PEFT usually casts the layer norms to float32 for training stability reasons; as a result, the input hidden states get silently cast to float32. Hence, we need to cast them back to float16 / bfloat16 to be sure everything works as expected. This might slow down training and inference, so it is recommended not to cast the LayerNorms!
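As an illustration only, here is a minimal sketch of the casting behaviour described above; the helper name `_cast_if_needed` is hypothetical, and the real logic lives in `fixed_fa_peft_integration_check`:

```python
import torch


def _cast_if_needed(query, key, value, target_dtype=None, preferred_dtype=None):
    # Hypothetical helper mirroring the behaviour described above: if PEFT has
    # upcast the hidden states to float32, cast them back to the requested
    # half-precision dtype before handing them to the Flash Attention kernels.
    dtype = preferred_dtype if preferred_dtype is not None else target_dtype
    if dtype is not None and query.dtype == torch.float32:
        query, key, value = query.to(dtype), key.to(dtype), value.to(dtype)
    return query, key, value


# Simulate hidden states that were silently upcast to float32 by PEFT.
q = k = v = torch.randn(1, 8, 16, 64, dtype=torch.float32)
q, k, v = _cast_if_needed(q, k, v, target_dtype=torch.bfloat16)
print(q.dtype)  # torch.bfloat16
```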
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| query | torch.Tensor | Input query states to be passed to the Flash Attention API | required |
| key | torch.Tensor | Input key states to be passed to the Flash Attention API | required |
| value | torch.Tensor | Input value states to be passed to the Flash Attention API | required |
| target_dtype | torch.dtype, optional | The dtype to convert the attention tensors to. Conversion can be skipped by not providing the target dtype. | None |
| preferred_dtype | torch.dtype, optional | The preferred dtype to convert the attention tensors to, regardless of the target dtype. | None |
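A hedged usage sketch, assuming the module is importable under the documented path, that the function returns the (possibly re-cast) query/key/value tensors like transformers' own `fa_peft_integration_check`, and that `preferred_dtype` takes precedence over `target_dtype` as the table above suggests:

```python
import torch

# Adjust the import prefix to match your installation; only the module path
# documented above is shown here, and any package root prefix is omitted.
from monkeypatch.transformers_fa_utils import fixed_fa_peft_integration_check

query = key = value = torch.randn(1, 8, 16, 64, dtype=torch.float32)

# With both dtypes given, preferred_dtype is expected to win, so the tensors
# should come back in bfloat16 rather than float16.
query, key, value = fixed_fa_peft_integration_check(
    query, key, value, target_dtype=torch.float16, preferred_dtype=torch.bfloat16
)
print(query.dtype)
```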