Skip to content

Commit 3078ef3

Browse files
capemoxRocketknight1
authored andcommitted
Used better way to generate seed in TF. Made tests more consistent.
1 parent a5526d0 commit 3078ef3

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

src/transformers/data/data_collator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1309,7 +1309,7 @@ def _shuffle(self, cand_indexes):
13091309
elif self.return_tensors == "tf":
13101310
import tensorflow as tf
13111311

1312-
seed = self.generator.uniform(shape=(2,), minval=0, maxval=2**31 - 1, dtype=tf.int32)
1312+
seed = self.generator.make_seeds(2)[0]
13131313
indices = tf.random.experimental.stateless_shuffle(tf.range(len(cand_indexes)), seed=seed).numpy().tolist()
13141314
return [cand_indexes[i] for i in indices]
13151315

tests/trainer/test_data_collator.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -447,18 +447,18 @@ def test_data_collator_for_whole_word_mask(self):
447447

448448
def test_data_collator_for_whole_word_mask_with_seed(self):
449449
tokenizer = BertTokenizer(self.vocab_file)
450-
features = [{"input_ids": list(range(30))}, {"input_ids": list(range(30))}]
450+
features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
451451

452452
# check if seed is respected between two different DataCollatorForWholeWordMask instances
453453
data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
454454
batch_1 = data_collator(features)
455-
self.assertEqual(batch_1["input_ids"].shape, torch.Size((2, 30)))
456-
self.assertEqual(batch_1["labels"].shape, torch.Size((2, 30)))
455+
self.assertEqual(batch_1["input_ids"].shape, torch.Size((2, 1000)))
456+
self.assertEqual(batch_1["labels"].shape, torch.Size((2, 1000)))
457457

458458
data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
459459
batch_2 = data_collator(features)
460-
self.assertEqual(batch_2["input_ids"].shape, torch.Size((2, 30)))
461-
self.assertEqual(batch_2["labels"].shape, torch.Size((2, 30)))
460+
self.assertEqual(batch_2["input_ids"].shape, torch.Size((2, 1000)))
461+
self.assertEqual(batch_2["labels"].shape, torch.Size((2, 1000)))
462462

463463
self.assertTrue(torch.all(batch_1["input_ids"] == batch_2["input_ids"]))
464464
self.assertTrue(torch.all(batch_1["labels"] == batch_2["labels"]))
@@ -1281,27 +1281,27 @@ def test_data_collator_for_whole_word_mask(self):
12811281

12821282
def test_data_collator_for_whole_word_mask_with_seed(self):
12831283
tokenizer = BertTokenizer(self.vocab_file)
1284-
features = [{"input_ids": list(range(50))}, {"input_ids": list(range(50))}]
1284+
features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
12851285

12861286
# check if seed is respected between two different DataCollatorForWholeWordMask instances
12871287
data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf")
12881288
batch_1 = data_collator(features)
1289-
self.assertEqual(batch_1["input_ids"].shape.as_list(), [2, 50])
1290-
self.assertEqual(batch_1["labels"].shape.as_list(), [2, 50])
1289+
self.assertEqual(batch_1["input_ids"].shape.as_list(), [2, 1000])
1290+
self.assertEqual(batch_1["labels"].shape.as_list(), [2, 1000])
12911291

12921292
data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf")
12931293
batch_2 = data_collator(features)
1294-
self.assertEqual(batch_2["input_ids"].shape.as_list(), [2, 50])
1295-
self.assertEqual(batch_2["labels"].shape.as_list(), [2, 50])
1294+
self.assertEqual(batch_2["input_ids"].shape.as_list(), [2, 1000])
1295+
self.assertEqual(batch_2["labels"].shape.as_list(), [2, 1000])
12961296

12971297
self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
12981298
self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))
12991299

13001300
# try with different seed
13011301
data_collator = DataCollatorForWholeWordMask(tokenizer, seed=43, return_tensors="tf")
13021302
batch_3 = data_collator(features)
1303-
self.assertEqual(batch_3["input_ids"].shape.as_list(), [2, 50])
1304-
self.assertEqual(batch_3["labels"].shape.as_list(), [2, 50])
1303+
self.assertEqual(batch_3["input_ids"].shape.as_list(), [2, 1000])
1304+
self.assertEqual(batch_3["labels"].shape.as_list(), [2, 1000])
13051305

13061306
self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"]))
13071307
self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))

0 commit comments

Comments
 (0)