
Conversation

@Ssukriti (Contributor)

What does this PR do?

Background: #36591

Fixes loss computation when the vocabulary is resized via resize_embeddings. Only the vocab size of the parent ConditionalGeneration class is updated, so the loss has to be calculated there. This is the approach chosen until a bigger refactor; the solution was discussed with @zucchini-nlp in the PR linked above.
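For context, a minimal sketch of the idea with illustrative names only (not the actual transformers code): the shifted cross-entropy loss has to use the vocab size that resize_token_embeddings keeps up to date, which lives on the ConditionalGeneration config rather than on the inner causal LM.

```python
import torch
import torch.nn as nn

# Hypothetical helper sketching where the loss now lives: it is called from the
# ConditionalGeneration class, whose config vocab size is the one that
# resize_token_embeddings actually updates.
def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # standard next-token objective: predict token t+1 from tokens <= t
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # flattening with a stale (pre-resize) vocab_size is what made the loss fail before
    return nn.CrossEntropyLoss()(
        shift_logits.view(-1, vocab_size),
        shift_labels.view(-1),
    )
```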

Fixes #36590

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp, as per the solution discussed above

            return output

        return CausalLMOutputWithPast(
            loss=loss,
@Ssukriti (Contributor Author), Mar 19, 2025:

I still left the rest of the return types in the class for minimal changes, since you mentioned there would be a massive refactor, @zucchini-nlp. So I just moved the loss computation to the ConditionalGeneration class. With that change there may not be a need for the MllamaForCausalLM class at all: it basically just adds the logits, which could also be moved to the ConditionalGeneration class, and the MllamaTextModel class could then be used from ConditionalGeneration directly.

I will leave that to you as you think through the refactor, or I can clean it up if you want.

@zucchini-nlp (Member):

Same here, leave it as is for users who load only the LLM part.

@Ssukriti (Contributor Author):

I will add it back for users that use this class directly, but will not pass labels from the generation class, to avoid the exception.
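Roughly, the split looks like this; a sketch only, with assumed names (self.language_model, self._compute_loss), not the real Mllama forward:

```python
# Illustrative sketch: the ConditionalGeneration forward delegates to the
# causal LM *without* labels, so the child class never computes a loss with a
# stale vocab size; the parent computes it once with its own, resized config.
def forward(self, input_ids=None, labels=None, **kwargs):
    outputs = self.language_model(input_ids=input_ids, **kwargs)  # labels deliberately not passed
    loss = None
    if labels is not None:
        # hypothetical helper using the parent's (up-to-date) vocab size
        loss = self._compute_loss(outputs.logits, labels)
    return loss, outputs
```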

@Ssukriti changed the title from "fix loss computation after embeddings resize" to "fix: loss computation after embeddings resize - mllama" on Mar 19, 2025.
@Ssukriti marked this pull request as ready for review on March 19, 2025 23:47.
        out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
        torch.testing.assert_close(out_embeds, out_ids)

    def test_resize_embeddings_results_in_successful_loss(self):
@Ssukriti (Contributor Author):

This is a test for the reported bug: it would fail before this change and now passes.
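For reference, the test has roughly this shape (a sketch only: fixture names such as model_tester and torch_device follow the usual transformers test conventions and may not match the actual test, and the usual test-file imports are assumed):

```python
def test_resize_embeddings_results_in_successful_loss(self):
    config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
    model = MllamaForConditionalGeneration(config).to(torch_device).eval()

    # grow the vocabulary; before this fix the loss computation below failed
    model.resize_token_embeddings(model.config.text_config.vocab_size + 10)

    # train-style forward pass: passing labels triggers the loss computation
    inputs["labels"] = inputs["input_ids"].clone()
    with torch.no_grad():
        loss = model(**inputs).loss

    self.assertIsNotNone(loss)
    self.assertTrue(torch.isfinite(loss).all())
```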

@zucchini-nlp (Member) left a comment:

Great, thanks a lot! This looks better to me as a temporary workaround until the refactoring. I left some comments about removing the loss from CausalLM completely, but otherwise LGTM.

            return output

        return CausalLMOutputWithPast(
            loss=loss,
@zucchini-nlp (Member):

Same here, leave it as is for users who load only the LLM part.

@Ssukriti (Contributor Author):

@zucchini-nlp, all tests have passed and the comments have been addressed. Thank you for the review.

@zucchini-nlp (Member) left a comment:

Perfect, thanks!

@Ssukriti (Contributor Author):

Thank you @zucchini-nlp. Can the PR be merged soon as well? It is blocking a use case we have.

@zucchini-nlp (Member):

Yep, merging, sorry

@zucchini-nlp merged commit 90e2df5 into huggingface:main on Mar 21, 2025.
12 checks passed
@Ssukriti deleted the test_fix_loss_computation branch on April 10, 2025 23:43.
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
…6840)

* move loss to generation class

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* code cleanup

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* test for resize and loss computation

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix tests

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix:test for resize and loss

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix resize embedding mllama test

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* review changes

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

---------

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
