
Conversation

Contributor

@i3hz i3hz commented Nov 11, 2025

What does this PR do?

Fixes the issue where models use an outdated if self.config._attn_implementation != "flash_attention_2": check.

Models changed: SmolVLM, idefics3, idefics2

Fixes #42121
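
In rough terms, the fix swaps a backend-specific branch for the shared masking helper. Below is a minimal before/after sketch of the pattern inside the vision attention forward; the variable names are illustrative and the keyword arguments of `create_bidirectional_mask` are assumed to mirror the other `masking_utils` helpers, so see the actual diff for the exact call:

```python
from transformers.masking_utils import create_bidirectional_mask
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

# Before: the 4D mask was built only when the backend was not flash_attention_2,
# so any other implementation (e.g. the kernels-community flash-attn3 kernels)
# landed on a code path with a mask format it cannot consume.
if self.config._attn_implementation != "flash_attention_2":
    patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)

# After (sketch): build the mask through the shared helper, which dispatches
# per attention implementation instead of hard-coding "flash_attention_2".
patch_attention_mask = create_bidirectional_mask(
    config=self.config,
    input_embeds=hidden_states,
    attention_mask=patch_attention_mask,
)
```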

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp @vasqu

Contributor

@vasqu vasqu left a comment


I'd like to fix idefics2 properly or leave it out for now; the others LGTM. Will check with run-slow a bit later.

Comment on lines 503 to 505
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
Contributor


Let's remove these comments too

from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_bidirectional_mask
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
Contributor

@vasqu vasqu Nov 11, 2025


Seems like `_prepare_4d_attention_mask` is still used and will likely cause similar issues at other points; best to completely remove this usage elsewhere too!

See

attention_mask = (
    _prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=self.n_latents)
    if self.config._attn_implementation != "flash_attention_2"
    else attention_mask
)

(looks like the flag for that model there also needs to be updated to _supports_flash_attn)
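
For reference, a minimal sketch of that flag change on the pretrained-model class (the old flag name and the surrounding attributes are assumptions about the prior state; check the current idefics2 modeling file):

```python
from transformers import Idefics2Config, PreTrainedModel

# Sketch only: the point is the capability flag, not a full model definition.
class Idefics2PreTrainedModel(PreTrainedModel):
    config_class = Idefics2Config
    # Old, FA2-specific flag (assumed prior state):
    # _supports_flash_attn_2 = True
    # New, backend-agnostic flag referred to in the comment above:
    _supports_flash_attn = True
```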

Contributor Author


So do I just reset all the changes from idefics2?

Contributor


Yea, either that, or use the create_bidirectional_mask fn here as well if it works; if it doesn't, I'd also appreciate knowing that. Means I need to take a look here :D
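
A minimal sketch of that suggestion for the perceiver resampler, assuming `create_bidirectional_mask` accepts the same `config`/`input_embeds`/`attention_mask` keywords as the other `masking_utils` helpers (untested; as noted further down in the thread, the latents vs. sequence-length mismatch may need extra handling):

```python
from transformers.masking_utils import create_bidirectional_mask

# Sketch only: replace the FA2-specific branch around _prepare_4d_attention_mask.
# The perceiver cross-attends n_latents query vectors to the full sequence, so the
# mask's key length has to follow the sequence rather than the latents; whether
# the helper handles that as-is is exactly what needs to be verified.
attention_mask = create_bidirectional_mask(
    config=self.config,
    input_embeds=latents,
    attention_mask=attention_mask,
)
```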

Contributor Author


Alright, I can't run the testing script as I don't have enough VRAM, but is this the correct approach?

Contributor Author

i3hz commented Nov 12, 2025

Tests seem to be failing, so I don't think it's the correct one.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: idefics3, smolvlm

Member

@zucchini-nlp zucchini-nlp left a comment


LGTM! I think there are more models that still check for flash_attention_2; it would be nice to batch-update them all. It can totally go in a separate PR later :)

@zucchini-nlp
Member

run-slow: idefics3, smolvlm

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/idefics3", "models/smolvlm"]
quantizations: []

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions
Contributor

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

Contributor

@vasqu vasqu left a comment


Thank you, let's keep it simple for now. We can address more models in separate PRs.

Contributor

vasqu commented Nov 12, 2025

> Tests seem to be failing, so I don't think it's the correct one.

Hmm it might need a dummy embedding of the correct size along the latents. Would leave this for a different PR

@vasqu vasqu merged commit fcea1e1 into huggingface:main Nov 12, 2025
18 checks passed
Contributor

vasqu commented Nov 12, 2025

Also thx for all the PR 🤗

Contributor Author

i3hz commented Nov 12, 2025

> LGTM! I think there are more models that still check for flash_attention_2; it would be nice to batch-update them all. It can totally go in a separate PR later :)

Yeah, I didn't want to make a whole lot of changes in a single PR; I find it very confusing. Sorry if that wasn't the ideal choice.

@i3hz i3hz deleted the flash3 branch November 13, 2025 03:05


Development

Successfully merging this pull request may close these issues.

kernels-community/flash-attn3 does not work with SmolVLM2
