
Commit 5c6d6be

[PEFT] Fix the general test for prefix tuning (#42185)
fix
1 parent 80134e6 commit 5c6d6be

File tree

1 file changed: +2 −0 lines changed


tests/utils/test_modeling_utils.py

Lines changed: 2 additions & 0 deletions
@@ -1818,6 +1818,8 @@ def test_cache_when_needed_at_train_time(self):
         self.assertTrue(model.training)

         # We can also disable the cache to skip a few operations, if the training loop doesn't need cache
+        # NOTE: after #41900, we need to pass the correct attention mask size
+        model_inputs["attention_mask"] = model_inputs["attention_mask"][:, :-num_virtual_tokens]
         model_outputs = model(**model_inputs, use_cache=False)
         self.assertIsNone(model_outputs.past_key_values)
         self.assertTrue(model.training)
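
For context, here is a minimal sketch (not part of the commit) of the mask-size arithmetic behind the fix, assuming prefix tuning prepends num_virtual_tokens learned key/value pairs to the cache: with the cache enabled, the attention mask must cover seq_len + num_virtual_tokens positions, but with use_cache=False no virtual key/value pairs exist, so the mask has to be trimmed back to the input length. The batch_size, seq_len, and num_virtual_tokens values below are illustrative.

import torch

batch_size, seq_len, num_virtual_tokens = 2, 8, 4

# Mask sized for the cached path: virtual tokens plus real input tokens.
attention_mask = torch.ones(batch_size, seq_len + num_virtual_tokens, dtype=torch.long)

# With use_cache=False there are no virtual key/value pairs, so the mask is
# trimmed back to the input length, the same slice the test now applies.
trimmed_mask = attention_mask[:, :-num_virtual_tokens]
assert trimmed_mask.shape == (batch_size, seq_len)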
