Gemma3 fixes #41572
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
vasqu left a comment
Overall happy with the changes. IMO we should move the training-mode check to the config test instead, and check with run-slow on our CIs just to be sure.
tests/test_modeling_common.py
Outdated
```python
# Check it can run in training mode
if check_forward_in_train:
    model.train()
    _ = model(**second_inputs)
```
Would it make more sense on the flash_attn_from_config test (with the new kwarg, defaulting to checking training)? On second thought, it's still a bit weird for this to live in this test: we only want to check inference equivalence here, tbh. The config test is more general and checks whether things break.
Seems like a good idea -- that test is already doing a fwd in train mode. Changing it.
```python
    config, attn_implementation=attn_implementation, dtype=torch.bfloat16
).to(torch_device)
if test_fwd_in_train:
    fa_model = fa_model.train()
```
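For context, a rough sketch of how the discussed kwarg could plug into the test. The function name, the `_from_config` loader, and the `torch_device` helper are assumed from the usual transformers test scaffolding and may not match the actual diff:

```python
import torch
from transformers.testing_utils import torch_device

def check_flash_attn_from_config(model_class, config, attn_implementation, test_fwd_in_train=True):
    # Build the model from config with the requested attention backend, in bf16.
    fa_model = model_class._from_config(
        config, attn_implementation=attn_implementation, dtype=torch.bfloat16
    ).to(torch_device)
    # Default to running the forward in train mode (see the discussion below on why).
    if test_fwd_in_train:
        fa_model = fa_model.train()
    return fa_model
```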
Can we add a small comment here to clarify that it is indeed different, e.g. dropout? Otherwise, lgtm
Actually, can we add another comment explaining that we set train mode because it's strictly harder than eval? I.e. if it works in train, it works in eval, but not necessarily the other way around. Otherwise it's not obvious why we would set train mode here by default.
Sorry for being annoying, but I did not get it at first glance.
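Concretely, the kind of comment being requested might read something like this (a sketch, not the exact wording that landed in the PR):

```python
# Run the forward pass in train mode by default: train is strictly harder than
# eval (e.g. dropout is active), so a model that works in train mode also works
# in eval mode, but not necessarily the other way around.
if test_fwd_in_train:
    fa_model = fa_model.train()
```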
run-slow: gemma3
This comment contains run-slow, running the specified jobs: models: ['models/gemma3']
Even better than main CI ❤️ Feel free to merge after adding a small comment on why train vs eval.
```python
# Check it can run in training mode
model.train()
_ = model(**second_inputs)
```
Indeed, from the name of the test it does not seem necessary.
[For maintainers] Suggested jobs to run (before merge): run-slow: gemma3
* Multiple device error fix
* FA2 equivalence fix
* Move the train fwd in cfg test
* Style
* Added comment
* Made the comment more clear
This PR fixes three things in `gemma3`:
* `torch.where` takes one of its operands from a tensor that is not on the right device; since that tensor is a `full_like`, we just replace it with the filling element itself (see the sketch after this list).
* `flash_attn_inference_equivalence` was failing because the model needs more parameters than are generated by default. To avoid this, we add a flag that specifies whether to check the forward pass in training mode, and make this check the default for both this test and the left padding variant (cc @vasqu).
* `flash_attn_from_config` was failing for the same reason (`token_type_ids` is required as a model input when training), so I added a `.eval()` to avoid this. The model does not seem to need train mode for this test, but I can also add an option to only call `.eval()` when a flag is passed.
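To illustrate the first fix, here is a minimal sketch of the pattern; the tensor names are made up, not taken from the gemma3 source. `torch.where` accepts a plain Python scalar for either branch, so passing the fill value directly avoids materializing a `full_like` tensor that could sit on the wrong device:

```python
import torch

mask = torch.tensor([True, False, True])
x = torch.randn(3)

# Before (illustrative): the full_like tensor inherits its device from
# `other`, which may differ from the device of `mask` and `x`.
# out = torch.where(mask, x, torch.full_like(other, 0.0))

# After: pass the filling element directly; no intermediate tensor,
# no cross-device mismatch.
out = torch.where(mask, x, 0.0)
print(out)
```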