fix: loss computation after embeddings resize - mllama #36840
Conversation
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
```python
        return output

    return CausalLMOutputWithPast(
        loss=loss,
```
I left the rest of the return types in the class for minimal changes, since, as you mentioned, there would otherwise be a massive refactor @zucchini-nlp. I just moved the loss computation to the ConditionalGeneration class. With that change, there may be no need for the MllamaForCausalLM class at all: it basically just adds logits, which could also be moved to the ConditionalGeneration class, so the MllamaTextModel class could be used directly from there.
Will leave that to you as you think through the refactor. Or if you want me to clean it up, I can do that.
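The restructuring described above can be sketched with toy classes (all names here are illustrative stand-ins, not the actual Mllama APIs): the inner language-model class returns logits only, while the outer ConditionalGeneration-style wrapper owns both the vocab size touched by an embedding resize and the loss computation, so the two can never disagree.

```python
import math

class ToyLanguageModel:
    """Stand-in for the inner CausalLM: returns logits only, no loss."""
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def forward(self, input_ids):
        # Dummy uniform logits, one row per input token.
        return [[0.0] * self.vocab_size for _ in input_ids]

class ToyConditionalGeneration:
    """Stand-in for the outer ConditionalGeneration class: owns the loss."""
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
        self.language_model = ToyLanguageModel(vocab_size)

    def resize_embeddings(self, new_size):
        # The resize updates the vocab size at this level, which is why
        # the loss has to live here as well.
        self.vocab_size = new_size
        self.language_model.vocab_size = new_size

    def forward(self, input_ids, labels=None):
        logits = self.language_model.forward(input_ids)
        loss = None
        if labels is not None:
            # Cross-entropy against the *current* vocab size; with uniform
            # zero logits this is log(vocab_size) per token.
            per_token = []
            for row, label in zip(logits, labels):
                denom = sum(math.exp(x) for x in row)
                per_token.append(math.log(denom) - row[label])
            loss = sum(per_token) / len(per_token)
        return {"loss": loss, "logits": logits}
```

With this split, labels introduced by a resize (ids at or above the old vocab size) are handled in one place, and the inner class stays usable on its own for logits-only inference.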
Same here: leave it as is for users who load only the LLM part.
I will add it back for users that use this class directly, but will not pass labels from the GenerationClass, to avoid the exception.
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
```python
        out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
        torch.testing.assert_close(out_embeds, out_ids)

    def test_resize_embeddings_results_in_successful_loss(self):
```
Test for the reported bug; it would fail before this change and passes now.
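A self-contained sketch of such a regression test, using a hypothetical toy model rather than the real Mllama classes: the embedding table is grown, then a forward pass is run with a label that only exists in the post-resize vocab. Before the fix, a loss computed against the stale vocab size would reject that label; after it, the loss comes back without an exception.

```python
class ToyModel:
    """Hypothetical stand-in for the resized model; names are illustrative."""
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def resize_embeddings(self, new_size):
        # Mirrors resize_embeddings growing the vocab.
        self.vocab_size = new_size

    def forward(self, input_ids, labels):
        # A label outside the current vocab is the pre-fix failure mode.
        for label in labels:
            if not 0 <= label < self.vocab_size:
                raise IndexError(f"label {label} out of vocab range")
        return {"loss": 0.0}

def test_resize_embeddings_results_in_successful_loss():
    model = ToyModel(vocab_size=8)
    model.resize_embeddings(10)                  # grow vocab, as in the report
    out = model.forward([1, 2], labels=[9, 3])   # id 9 only exists post-resize
    assert out["loss"] is not None
```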
zucchini-nlp left a comment:
Great, thanks a lot! This looks better to me as a temporary workaround until the refactor. I left some comments about removing the loss from the CausalLM class completely, but otherwise LGTM.
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
@zucchini-nlp all tests have passed and the comments are addressed. Thank you for the review.
zucchini-nlp left a comment:
Perfect, thanks!
Thank you @zucchini-nlp. Can the PR be merged soon as well? It is blocking a use case we have.
Yep, merging, sorry.
…6840)

* move loss to generation class
* code cleanup
* test for resize and loss computation
* fix tests
* fix: test for resize and loss
* fix resize embedding mllama test
* review changes

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
What does this PR do?
background: #36591
Fixes loss computation when the vocab is resized by resize_embeddings. Only the vocab size of the parent ConditionalGeneration class is modified by the resize, so the loss has to be calculated there. This is the chosen approach until a bigger refactor; the solution was discussed with @zucchini-nlp in the PR above.
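The failure mode this fixes can be sketched with two toy classes (illustrative names, not the real Mllama classes): the resize only updates the vocab size at the outer level, so a loss computed inside the inner class against its stale vocab size rejects the newly added token ids, while a loss computed at the outer level succeeds.

```python
class InnerCausalLM:
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size          # stale after a parent-level resize

    def loss(self, labels):
        # Pre-fix path: validated against the stale vocab size.
        if any(l >= self.vocab_size for l in labels):
            raise IndexError("label outside (stale) vocab")
        return 0.0

class OuterConditionalGeneration:
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
        self.inner = InnerCausalLM(vocab_size)

    def resize_embeddings(self, new_size):
        self.vocab_size = new_size            # only this level is updated (the bug)

    def loss_after_fix(self, labels):
        # Post-fix path: the loss lives here, against the resized vocab size.
        if any(l >= self.vocab_size for l in labels):
            raise IndexError("label outside vocab")
        return 0.0

model = OuterConditionalGeneration(8)
model.resize_embeddings(10)
try:
    model.inner.loss([9])                     # pre-fix path raises
    pre_fix_failed = False
except IndexError:
    pre_fix_failed = True
post_fix_loss = model.loss_after_fix([9])     # post-fix path succeeds
```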
Fixes #36590
Who can review?
@zucchini-nlp, with whom the solution was discussed.