
Conversation

@remi-or
Collaborator

@remi-or remi-or commented Oct 14, 2025

This PR fixes three things in gemma3:

  • a multiple-device error where torch.where took some of its coefficients from a full_like tensor that was not on the right device; since that tensor was just a full_like, we replace it with the scalar filling element
  • an error in flash_attn_inference_equivalence, caused by the model needing more parameters than are generated by default. To avoid this, we add a flag that specifies whether the forward pass should also be checked in training mode, and make this check the default for both the right- and left-padding variants (cc. @vasqu )
  • the test flash_attn_from_config was failing for the same reason (token_type_ids is a required model input when training), so I added a .eval() call to avoid this. The model does not seem to need to be in train mode for this test, but I can also add an option to only skip .eval() when a flag is passed
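The device fix in the first bullet can be sketched as follows. This is a minimal illustration, not the actual gemma3 code: building the fallback values with full_like can materialize a tensor on the wrong device in multi-device setups, whereas passing the scalar fill value directly to torch.where sidesteps the device question entirely.

```python
import torch

# Illustrative only, not the actual gemma3 code.
mask = torch.tensor([True, False, True])
scores = torch.tensor([0.1, 0.2, 0.3])

# Before: materialize a filled tensor; in multi-device setups this tensor
# can end up on a different device than `scores`, breaking torch.where.
out_before = torch.where(mask, scores, torch.full_like(scores, float("-inf")))

# After: torch.where accepts a plain Python scalar, which is device-agnostic.
out_after = torch.where(mask, scores, float("-inf"))

assert torch.equal(out_before, out_after)
```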

@remi-or remi-or requested a review from vasqu October 14, 2025 11:02
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@vasqu vasqu left a comment


Overall happy with the changes; imo we should move the training-mode check to the config test instead, and check with run-slow on our CIs just to be sure

Comment on lines 3124 to 3127
    # Check it can run in training mode
    if check_forward_in_train:
        model.train()
        _ = model(**second_inputs)
Contributor

Would it make more sense on the flash_attn_from_config test (with the new kwarg, defaulting to checking training)? On second thought, it's still a bit odd to have this in this test --> here we only want to check inference equivalence tbh; the config test is more general and checks whether things break.

Collaborator Author

Seems like a good idea -- that test is already doing a fwd in train mode. Changing it.

    config, attn_implementation=attn_implementation, dtype=torch.bfloat16
).to(torch_device)
if test_fwd_in_train:
    fa_model = fa_model.train()
Contributor

Can we add a small comment here to clarify that it is indeed different, e.g. dropout? Otherwise, lgtm

Member

Actually, can we add another comment explaining that we set train mode because it's strictly harder than eval? I.e., if it works in train, it works in eval, but not necessarily the other way around. It's just not obvious otherwise why we would set train mode here by default.
Sorry for being annoying, but I didn't get it at first glance.
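The "train is strictly harder than eval" point can be seen with a dropout layer (a toy illustration, unrelated to the actual test code): dropout is active under train() and a no-op under eval(), so a forward pass that runs cleanly in train mode exercises strictly more behavior than one in eval mode.

```python
import torch
import torch.nn as nn

# Toy illustration: dropout is active in train mode, identity in eval mode.
torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train()
out_train = dropout(x)   # roughly half the elements zeroed, survivors scaled by 1/(1-p)

dropout.eval()
out_eval = dropout(x)    # eval-mode dropout changes nothing

assert torch.equal(out_eval, x)
assert (out_train == 0).any()  # train mode actually dropped something
```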

@vasqu
Contributor

vasqu commented Oct 14, 2025

run-slow: gemma3

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/gemma3']
quantizations: [] ...

@vasqu
Contributor

vasqu commented Oct 14, 2025

Even better than main CI ❤️ feel free to merge after adding a small comment on why train vs eval

Comment on lines -3117 to -3119
    # Check it can run in training mode
    model.train()
    _ = model(**second_inputs)
Member

Indeed, from the name of the test it does not seem necessary


@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma3

@Cyrilvallez Cyrilvallez merged commit 9e4199e into huggingface:main Oct 14, 2025
14 of 25 checks passed
i3hz pushed a commit to i3hz/transformers that referenced this pull request Oct 15, 2025
* Multiple device error fix

* FA2 equivalence fix

* Move the train fwd in cfg test

* Style

* Added comment

* Made the comment more clear
ngazagna-qc pushed a commit to ngazagna-qc/transformers that referenced this pull request Oct 23, 2025
* Multiple device error fix

* FA2 equivalence fix

* Move the train fwd in cfg test

* Style

* Added comment

* Made the comment more clear
