Commit 90e2df5

fix: loss computation after embeddings resize - mllama (#36840)
* move loss to generation class
* code cleanup
* test for resize and loss computation
* fix tests
* fix:test for resize and loss
* fix resize embedding mllama test
* review changes

---------

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
1 parent 4542b8f commit 90e2df5
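
The change moves loss computation out of the inner language model and into MllamaForConditionalGeneration.forward, where the vocab size is read from self.config.get_text_config() at call time, so the loss follows the current embedding size. Below is a minimal sketch of the user-facing scenario this fixes; the checkpoint path and the added token are placeholders, not part of the commit.

    # Sketch only: checkpoint path and the new token are placeholders.
    from transformers import AutoProcessor, MllamaForConditionalGeneration

    model = MllamaForConditionalGeneration.from_pretrained("path/to/mllama-checkpoint")
    processor = AutoProcessor.from_pretrained("path/to/mllama-checkpoint")

    # Add a token and grow the embedding matrix accordingly.
    processor.tokenizer.add_tokens(["<placeholder_token>"])
    model.resize_token_embeddings(len(processor.tokenizer))

    # Text-only forward with labels; after this commit the loss is computed in the
    # conditional-generation class against the resized vocab size.
    inputs = processor(text="Hello world", return_tensors="pt")
    outputs = model(**inputs, labels=inputs["input_ids"])
    print(outputs.loss)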

File tree

2 files changed: 37 additions, 2 deletions

src/transformers/models/mllama/modeling_mllama.py

Lines changed: 19 additions & 2 deletions
@@ -2056,6 +2056,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -2157,15 +2158,31 @@ def forward(
             past_key_values=past_key_values,
             use_cache=use_cache,
             inputs_embeds=inputs_embeds,
-            labels=labels,
             output_hidden_states=output_hidden_states,
             output_attentions=output_attentions,
             return_dict=return_dict,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            **loss_kwargs,
         )

-        return outputs
+        # Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
+        loss = None
+        logits = outputs[0]
+
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs)
+
+        if not return_dict:
+            return (loss,) + outputs if loss is not None else outputs
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=outputs.logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )

     def prepare_inputs_for_generation(
         self,
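
For context, the new block delegates to self.loss_function with the vocab size taken from the text config at call time. For a causal LM in Transformers this resolves to a shift-by-one cross-entropy; the helper below is an illustrative standalone approximation (the function name and the exact ignore_index handling are assumptions, not the library's implementation).

    import torch
    import torch.nn.functional as F

    def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
        # Predict token t+1 from position t: drop the last logit and the first label.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten to (batch * (seq_len - 1), vocab_size); -100 marks positions to ignore.
        return F.cross_entropy(
            shift_logits.view(-1, vocab_size),
            shift_labels.view(-1),
            ignore_index=-100,
        )

Because vocab_size is looked up when forward runs, a resize_token_embeddings call made after loading is reflected in the loss rather than the stale pre-resize size.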

tests/models/mllama/test_modeling_mllama.py

Lines changed: 18 additions & 0 deletions
@@ -321,6 +321,24 @@ def test_inputs_embeds_matches_input_ids(self):
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         torch.testing.assert_close(out_embeds, out_ids)

+    def test_resize_embeddings_results_in_successful_loss(self):
+        # resizing embeddings should result in successful loss computation
+        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model_vocab_size = config.get_text_config().vocab_size
+            inputs = self._prepare_for_class(inputs, model_class, return_labels=True)
+            # Resize embeddings and call forward
+            model.resize_token_embeddings(model_vocab_size + 10)
+            output = model(
+                input_ids=inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
+                labels=inputs["labels"],
+                return_dict=True,
+            )
+            self.assertTrue("loss" in output)
+
     def _check_attentions_for_generate(
         self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
     ):
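
The new test can be run on its own with a standard pytest selection by name, for example:

    python -m pytest tests/models/mllama/test_modeling_mllama.py -k test_resize_embeddings_results_in_successful_loss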
