Support for FlashAttention 2?
I notice that there is no FlashAttention class in the vision tower part (which is fine).
But it seems that even when I override the config for the LM part (with config = AutoConfig.from_pretrained(this_repo)
and config.text_config._attn_implementation = "flash_attention_2"
), the attention implementation still ends up as SDPA.
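For reference, here is roughly what I tried (the model id below is just a placeholder for this repo):

```python
import torch
from transformers import AutoConfig, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # placeholder, I actually use this repo's id

# Override the attention implementation only for the text (LM) sub-config
config = AutoConfig.from_pretrained(model_id)
config.text_config._attn_implementation = "flash_attention_2"

model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    config=config,
    torch_dtype=torch.float16,
)
# The LM part still reports sdpa for me
print(model.config.text_config._attn_implementation)
```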
(And because there is no flash attention for the vision tower, overriding the config for the whole model with config._attn_implementation = "flash_attention_2"
will of course throw an error.)
Is there any way to switch the LM to use flash attention 2?
Can you try specifying the attention implementation while loading, with model = LlavaForConditionalGeneration.from_pretrained(model_id, attn_implementation={"text_config": "flash_attention_2", "vision_config": "sdpa"}, torch_dtype=torch.float16)?
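Something like this, as a rough sketch (the model id is just a placeholder for this repo):

```python
import torch
from transformers import LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # placeholder, use this repo's id

# Pass a per-sub-config dict so the LM and the vision tower
# get different attention implementations
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation={"text_config": "flash_attention_2", "vision_config": "sdpa"},
    torch_dtype=torch.float16,
)
```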
LMK if that works; probably the attention implementation is not being picked up from the sub-config. I am planning to fix all attention-related issues in VLMs soon.
Thank you for the suggestion! Not sure if it is because of my transformers version (4.46.1), but the vision tower also doesn't have SDPA; changing it to eager works!
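For reference, this is roughly the call that worked for me on 4.46.1 (model id is a placeholder):

```python
import torch
from transformers import LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # placeholder, I actually use this repo's id

# Same call as suggested above, but with eager attention for the vision tower
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation={"text_config": "flash_attention_2", "vision_config": "eager"},
    torch_dtype=torch.float16,
)
```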