monkeypatch.transformers_fa_utils

See https://github.com/huggingface/transformers/pull/35834.

Functions

| Name | Description |
| --- | --- |
| fixed_fa_peft_integration_check | PEFT usually casts the layer norms to float32 for training stability reasons. |

fixed_fa_peft_integration_check

```python
monkeypatch.transformers_fa_utils.fixed_fa_peft_integration_check(
    query,
    key,
    value,
    target_dtype=None,
    preferred_dtype=None,
)
```

PEFT usually casts the layer norms to float32 for training stability reasons; as a result, the input hidden states get silently cast to float32. Hence, we need to cast them back to float16 / bfloat16 to be sure everything works as expected. This might slow down training and inference, so it is recommended not to cast the LayerNorms in the first place!
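A minimal sketch of the casting behavior described above, for illustration only. The real implementation lives in transformers' flash-attention utilities; the `preferred_dtype` handling below is an assumption inferred from this function's signature, not verified upstream code:

```python
import torch

def sketch_fa_peft_integration_check(query, key, value,
                                     target_dtype=None, preferred_dtype=None):
    # Hypothetical re-implementation for illustration.
    # Assumed semantics: preferred_dtype always wins, while target_dtype
    # only applies when PEFT has silently upcast the inputs to float32.
    if preferred_dtype is not None and query.dtype != preferred_dtype:
        return (query.to(preferred_dtype),
                key.to(preferred_dtype),
                value.to(preferred_dtype))
    if target_dtype is not None and query.dtype == torch.float32:
        # Cast back to fp16/bf16 so Flash Attention can run.
        return (query.to(target_dtype),
                key.to(target_dtype),
                value.to(target_dtype))
    return query, key, value
```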

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| query | torch.Tensor | Input query states to be passed to the Flash Attention API | required |
| key | torch.Tensor | Input key states to be passed to the Flash Attention API | required |
| value | torch.Tensor | Input value states to be passed to the Flash Attention API | required |
| target_dtype | torch.dtype, optional | The dtype to convert the attention tensors to. Conversion is skipped if no target dtype is provided. | None |
| preferred_dtype | torch.dtype, optional | The preferred dtype to convert the attention tensors to, regardless of the target dtype. | None |
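
A usage sketch, assuming the casting behavior documented above (the tensor shapes are arbitrary):

```python
import torch
from monkeypatch.transformers_fa_utils import fixed_fa_peft_integration_check

# Simulate hidden states that PEFT silently upcast to float32.
q = torch.randn(1, 8, 16, 64, dtype=torch.float32)
k = torch.randn(1, 8, 16, 64, dtype=torch.float32)
v = torch.randn(1, 8, 16, 64, dtype=torch.float32)

# With a target dtype set, the float32 inputs are cast back.
q, k, v = fixed_fa_peft_integration_check(q, k, v, target_dtype=torch.bfloat16)
print(q.dtype)  # expected: torch.bfloat16
```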