fix: update vocab size of language model's config on resize - mllama #36591
Conversation
@zucchini-nlp the PR is ready for review; the changes have been made as discussed earlier (in the resolved conversation above).
zucchini-nlp
left a comment
Hey, sorry it took so long to come back!
After giving it more thought, I think the easiest solution is to stop passing labels to the language model and calculate the loss from within the ConditionalGeneration class. Something like:
transformers/src/transformers/models/gemma3/modeling_gemma3.py, lines 1366 to 1383 at a861db0:
```python
logits = outputs.logits
loss = None
if labels is not None:
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    if attention_mask is not None:
        # we use the input attention mask to shift the logits and labels, because it is 2D.
        # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
        shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
        shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
        shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
    else:
        shift_logits = shift_logits.contiguous()
        shift_labels = shift_labels.contiguous()
    # Flatten the tokens
    loss_fct = nn.CrossEntropyLoss()
```
`self.loss_fn`
I realize this is not a perfect solution and doesn't scale, but we are planning a big refactor of multimodal models, and the vocab size issue for labels will be resolved by then. Therefore I don't want us to add more abstraction which might be deleted later on.
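For illustration, a minimal sketch of what computing the loss inside `MllamaForConditionalGeneration.forward` could look like. This is an assumption-laden simplification, not the code that ended up in #36840, and the helper name `compute_shifted_ce_loss` is made up here:

```python
import torch
import torch.nn as nn


def compute_shifted_ce_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy: shift logits/labels by one position and flatten."""
    logits = logits.float()  # upcast to avoid precision issues, as in the gemma3 snippet
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = nn.CrossEntropyLoss()
    return loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1).to(shift_logits.device),
    )


# Inside MllamaForConditionalGeneration.forward (sketch):
#     outputs = self.language_model(input_ids=..., attention_mask=..., ...)  # labels NOT passed down
#     logits = outputs.logits
#     loss = compute_shifted_ce_loss(logits, labels) if labels is not None else None
```

The point of this shape is that the language model never sees the labels, so its possibly stale `config.vocab_size` no longer matters for the loss computation.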
```python
new_vocab_size = config.get_text_config().vocab_size + 10
model.set_vocab_size(new_vocab_size)
self.assertEqual(model.language_model.get_vocab_size(), new_vocab_size)
# model's get_vocab_size returns the language model's vocab size
self.assertEqual(model.language_model.get_vocab_size(), model.get_vocab_size())
```
To be aligned with the linked issue, let's also test that the model can take labels and return a loss after resizing.
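Such a test could look roughly like the sketch below. It is not the actual test in #36840; `self.model_tester`, `torch_device`, and the input preparation are assumed from the usual transformers test scaffolding:

```python
def test_resize_embeddings_then_forward_with_labels(self):
    # Sketch only; assumes the common transformers test scaffolding for mllama.
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = MllamaForConditionalGeneration(config).to(torch_device).eval()

    # Grow the vocabulary, then check the resized model still accepts labels.
    new_vocab_size = config.get_text_config().vocab_size + 10
    model.resize_token_embeddings(new_vocab_size)

    labels = inputs_dict["input_ids"].clone()
    outputs = model(**inputs_dict, labels=labels)

    # Loss should be computed without vocab-size mismatches after the resize.
    self.assertIsNotNone(outputs.loss)
    self.assertEqual(outputs.logits.shape[-1], new_vocab_size)
```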
Good call. I added the test in the follow-up PR #36840.
I also added you as a collaborator to my fork in case it's needed; the invite is pending.
(branch force-pushed from c4c64cc to 9be4728)
@zucchini-nlp sure, we can go with that approach instead. I created a new PR #36840 as I had to undo the changes here. If the solution looks OK, I can continue with the unit tests, thank you!
What does this PR do?
The PR updates the config of the mllama language model to keep its vocab size in sync. This is useful when the vocab size of the `MllamaForConditionalGeneration` class changes while resizing embeddings with `model.resize_token_embeddings(embedding_size)`. More details are included in the issue below. The PR includes one potential fix; there are several ways to do this, and suggestions are welcome!
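One way to express the idea is to propagate the new embedding size into the nested text config after a resize. The following is a sketch of that approach, not necessarily the exact code in this PR:

```python
# Inside MllamaForConditionalGeneration (sketch):
def resize_token_embeddings(self, new_num_tokens=None, pad_to_multiple_of=None):
    model_embeds = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
    # Keep the nested text config in sync, so that code reading
    # config.get_text_config().vocab_size after the resize (e.g. label handling
    # in the language model) sees the updated value.
    self.config.text_config.vocab_size = model_embeds.weight.shape[0]
    return model_embeds
```

As discussed in the review above, the direction ultimately preferred was to compute the loss in the ConditionalGeneration class instead (see #36840), which sidesteps the stale vocab size entirely.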
Fixes #36590
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
Models: