
Commit e082712

capemox authored and zucchini-nlp committed
Added support for seed in DataCollatorForWholeWordMask (huggingface#36903)
* Added support for seed in `DataCollatorForWholeWordMask`, and also wrote tests. Also fixed bugs where the code hardcoded values for mask replacement probability and random replacement probability, instead of using the values passed by the user.

* formatting issues

* Used better way to generate seed in TF. Made tests more consistent.
1 parent c8d545a commit e082712
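In use, the new `seed` argument is passed at construction time alongside the existing masking probabilities. A minimal usage sketch (the checkpoint name and probability values here are illustrative, not part of this commit):

from transformers import BertTokenizer, DataCollatorForWholeWordMask

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
collator = DataCollatorForWholeWordMask(
    tokenizer,
    mlm_probability=0.15,     # fraction of whole words selected for masking
    mask_replace_prob=0.8,    # of those, fraction replaced with [MASK]
    random_replace_prob=0.1,  # of those, fraction replaced with a random token
    seed=42,                  # new: fixes the masking pattern across runs
)
batch = collator([{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}])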

File tree

2 files changed: 253 additions, 22 deletions

  src/transformers/data/data_collator.py
  tests/trainer/test_data_collator.py

src/transformers/data/data_collator.py

Lines changed: 120 additions & 22 deletions
@@ -1193,6 +1193,11 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
     </Tip>"""

     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+        if self.seed and self.generator is None:
+            # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
+            # If no seed supplied, we will use the global RNG
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
@@ -1223,6 +1228,11 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         import tensorflow as tf

+        if self.seed and self.generator is None:
+            # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
+            # If no seed supplied, we will use the global RNG
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
@@ -1251,6 +1261,11 @@ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         return {"input_ids": inputs, "labels": labels}

     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+        if self.seed and self.generator is None:
+            # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
+            # If no seed supplied, we will use the global RNG
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
@@ -1278,6 +1293,30 @@ def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
         return {"input_ids": inputs, "labels": labels}

+    def _shuffle(self, cand_indexes):
+        # if no seed, just use random's shuffle
+        if self.seed is None:
+            random.shuffle(cand_indexes)
+            return cand_indexes
+
+        # if seed is provided, use the generator to shuffle
+        if self.return_tensors == "pt":
+            import torch
+
+            indices = torch.randperm(len(cand_indexes), generator=self.generator)
+            return [cand_indexes[i] for i in indices]
+
+        elif self.return_tensors == "tf":
+            import tensorflow as tf
+
+            seed = self.generator.make_seeds(2)[0]
+            indices = tf.random.experimental.stateless_shuffle(tf.range(len(cand_indexes)), seed=seed).numpy().tolist()
+            return [cand_indexes[i] for i in indices]
+
+        elif self.return_tensors == "np":
+            self.generator.shuffle(cand_indexes)
+            return cand_indexes
+
     def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
         """
         Get 0/1 labels for masked tokens with whole word mask proxy
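The `_shuffle` helper branches on `return_tensors` because each framework carries its own seeded generator type. A small self-contained sketch of the property the pt branch relies on, namely that a seeded torch.Generator makes torch.randperm reproducible:

import torch

g1 = torch.Generator().manual_seed(42)
g2 = torch.Generator().manual_seed(42)
perm1 = torch.randperm(10, generator=g1)
perm2 = torch.randperm(10, generator=g2)
assert torch.equal(perm1, perm2)  # same seed, same permutation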
@@ -1298,7 +1337,7 @@ def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
             else:
                 cand_indexes.append([i])

-        random.shuffle(cand_indexes)
+        cand_indexes = self._shuffle(cand_indexes)
         num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
         masked_lms = []
         covered_indexes = set()
@@ -1346,16 +1385,32 @@ def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
         masked_indices = probability_matrix.bool()
         labels[~masked_indices] = -100  # We only compute loss on masked tokens

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
+        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        indices_replaced = (
+            torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob), generator=self.generator).bool()
+            & masked_indices
+        )
         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

-        # 10% of the time, we replace masked input tokens with random word
-        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
-        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
+        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
+            return inputs, labels
+
+        remaining_prob = 1 - self.mask_replace_prob
+        # scaling the random_replace_prob to the remaining probability for example if
+        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
+        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
+        random_replace_prob_scaled = self.random_replace_prob / remaining_prob
+
+        # random_replacement_prob% of the time, we replace masked input tokens with random word
+        indices_random = (
+            torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled), generator=self.generator).bool()
+            & masked_indices
+            & ~indices_replaced
+        )
+        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, generator=self.generator)
         inputs[indices_random] = random_words[indices_random]

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+        # The rest of the time ((1-random_replacement_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels

     def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
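The scaling above is plain conditional probability: random replacement is only drawn on masked positions that were not already replaced with [MASK], so sampling those at rate random_replace_prob / (1 - mask_replace_prob) recovers the requested unconditional rate. A quick sanity check of the arithmetic:

mask_replace_prob = 0.8
random_replace_prob = 0.1

random_replace_prob_scaled = random_replace_prob / (1 - mask_replace_prob)  # 0.1 / 0.2 = 0.5
# unconditional probability that a masked token gets a random replacement:
p_random = (1 - mask_replace_prob) * random_replace_prob_scaled
assert abs(p_random - random_replace_prob) < 1e-12  # 0.2 * 0.5 == 0.1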
@@ -1387,17 +1442,35 @@ def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
         # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
         labels = tf.where(masked_indices, inputs, -100)

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
+        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob, self.generator) & masked_indices

         inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)

-        # 10% of the time, we replace masked input tokens with random word
-        indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
-        random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
+        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
+            return inputs, labels
+
+        remaining_prob = 1 - self.mask_replace_prob
+        # scaling the random_replace_prob to the remaining probability for example if
+        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
+        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
+        random_replace_prob_scaled = self.random_replace_prob / remaining_prob
+
+        # random_replace_prob% of the time, we replace masked input tokens with random word
+        indices_random = (
+            self.tf_bernoulli(input_shape, random_replace_prob_scaled, self.generator)
+            & masked_indices
+            & ~indices_replaced
+        )
+
+        if self.generator:
+            random_words = self.generator.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
+        else:
+            random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
+
         inputs = tf.where(indices_random, random_words, inputs)

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+        # The rest of the time ((1-mask_replace_prob-random_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels

     def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
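Note that `tf_bernoulli` is now called with `self.generator` as a third argument, so the helper (updated outside the hunks shown here) presumably draws from the seeded generator when one exists and from the global TF RNG otherwise. A hedged sketch of that shape, an assumption rather than the library's exact code:

import tensorflow as tf

def tf_bernoulli(shape, probability, generator=None):
    # Sample U ~ Uniform[0, 1) from the seeded generator if provided,
    # otherwise from TensorFlow's global RNG; U < p is Bernoulli(p).
    uniform = generator.uniform(shape) if generator is not None else tf.random.uniform(shape)
    return uniform < probability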
@@ -1425,19 +1498,44 @@ def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:

         labels[~masked_indices] = -100  # We only compute loss on masked tokens

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
+        # mask_replacement_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        if self.generator:
+            indices_replaced = (
+                self.generator.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+            )
+        else:
+            indices_replaced = (
+                np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+            )
         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

-        # 10% of the time, we replace masked input tokens with random word
-        # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
-        indices_random = (
-            np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
-        )
-        random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
+        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
+            return inputs, labels
+
+        remaining_prob = 1 - self.mask_replace_prob
+        # scaling the random_replace_prob to the remaining probability for example if
+        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
+        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
+        random_replace_prob_scaled = self.random_replace_prob / remaining_prob
+
+        if self.generator:
+            indices_random = (
+                self.generator.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+                & masked_indices
+                & ~indices_replaced
+            )
+            random_words = self.generator.integers(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
+        else:
+            indices_random = (
+                np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+                & masked_indices
+                & ~indices_replaced
+            )
+            random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)

         inputs[indices_random] = random_words[indices_random]

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+        # The rest of the time ((1-mask_replace_prob-random_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels

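The diff leans on `create_rng`, which lives on the parent DataCollatorForLanguageModeling and is not shown here. Judging from how `self.generator` is consumed above (torch.randperm(..., generator=...), generator.make_seeds, generator.binomial/integers), it plausibly builds one seeded generator per backend. The sketch below is an assumption about its shape, not the actual implementation; in particular, any per-worker seed adjustment is omitted:

import numpy as np

def create_rng(self):
    # One seeded generator per backend, matching how _shuffle and the
    # *_mask_tokens methods consume self.generator (assumed shape).
    if self.return_tensors == "pt":
        import torch

        self.generator = torch.Generator().manual_seed(self.seed)
    elif self.return_tensors == "tf":
        import tensorflow as tf

        self.generator = tf.random.Generator.from_seed(self.seed)
    elif self.return_tensors == "np":
        self.generator = np.random.default_rng(self.seed)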

tests/trainer/test_data_collator.py

Lines changed: 133 additions & 0 deletions
@@ -445,6 +445,86 @@ def test_data_collator_for_whole_word_mask(self):
         self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
         self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

+    def test_data_collator_for_whole_word_mask_with_seed(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
+
+        # check if seed is respected between two different DataCollatorForWholeWordMask instances
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
+        batch_1 = data_collator(features)
+        self.assertEqual(batch_1["input_ids"].shape, torch.Size((2, 1000)))
+        self.assertEqual(batch_1["labels"].shape, torch.Size((2, 1000)))
+
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
+        batch_2 = data_collator(features)
+        self.assertEqual(batch_2["input_ids"].shape, torch.Size((2, 1000)))
+        self.assertEqual(batch_2["labels"].shape, torch.Size((2, 1000)))
+
+        self.assertTrue(torch.all(batch_1["input_ids"] == batch_2["input_ids"]))
+        self.assertTrue(torch.all(batch_1["labels"] == batch_2["labels"]))
+
+        # check if seed is respected in multiple workers situation
+        features = [{"input_ids": list(range(1000))} for _ in range(10)]
+        dataloader = torch.utils.data.DataLoader(
+            features,
+            batch_size=2,
+            num_workers=2,
+            generator=torch.Generator().manual_seed(42),
+            collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=42),
+        )
+
+        batch_3_input_ids = []
+        batch_3_labels = []
+        for batch in dataloader:
+            batch_3_input_ids.append(batch["input_ids"])
+            batch_3_labels.append(batch["labels"])
+
+        batch_3_input_ids = torch.stack(batch_3_input_ids)
+        batch_3_labels = torch.stack(batch_3_labels)
+        self.assertEqual(batch_3_input_ids.shape, torch.Size((5, 2, 1000)))
+        self.assertEqual(batch_3_labels.shape, torch.Size((5, 2, 1000)))
+
+        dataloader = torch.utils.data.DataLoader(
+            features,
+            batch_size=2,
+            num_workers=2,
+            collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=42),
+        )
+
+        batch_4_input_ids = []
+        batch_4_labels = []
+        for batch in dataloader:
+            batch_4_input_ids.append(batch["input_ids"])
+            batch_4_labels.append(batch["labels"])
+        batch_4_input_ids = torch.stack(batch_4_input_ids)
+        batch_4_labels = torch.stack(batch_4_labels)
+        self.assertEqual(batch_4_input_ids.shape, torch.Size((5, 2, 1000)))
+        self.assertEqual(batch_4_labels.shape, torch.Size((5, 2, 1000)))
+
+        self.assertTrue(torch.all(batch_3_input_ids == batch_4_input_ids))
+        self.assertTrue(torch.all(batch_3_labels == batch_4_labels))
+
+        # try with different seed
+        dataloader = torch.utils.data.DataLoader(
+            features,
+            batch_size=2,
+            num_workers=2,
+            collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=43),
+        )
+
+        batch_5_input_ids = []
+        batch_5_labels = []
+        for batch in dataloader:
+            batch_5_input_ids.append(batch["input_ids"])
+            batch_5_labels.append(batch["labels"])
+        batch_5_input_ids = torch.stack(batch_5_input_ids)
+        batch_5_labels = torch.stack(batch_5_labels)
+        self.assertEqual(batch_5_input_ids.shape, torch.Size((5, 2, 1000)))
+        self.assertEqual(batch_5_labels.shape, torch.Size((5, 2, 1000)))
+
+        self.assertFalse(torch.all(batch_3_input_ids == batch_5_input_ids))
+        self.assertFalse(torch.all(batch_3_labels == batch_5_labels))
+
     def test_plm(self):
         tokenizer = BertTokenizer(self.vocab_file)
         no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
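The multi-worker assertions above pin down a useful property: the masking pattern depends only on the collator's seed, not on the DataLoader's own generator. A condensed sketch of that check outside the unittest harness (assumes a `tokenizer` as in the test):

import torch
from transformers import DataCollatorForWholeWordMask

features = [{"input_ids": list(range(1000))} for _ in range(10)]

def collect(seed, dl_generator=None):
    # Build a fresh loader; only the collator seed should matter for masking
    loader = torch.utils.data.DataLoader(
        features,
        batch_size=2,
        num_workers=2,
        generator=dl_generator,
        collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=seed),
    )
    return torch.stack([batch["input_ids"] for batch in loader])

a = collect(42, torch.Generator().manual_seed(42))
b = collect(42)  # no DataLoader generator at all
assert torch.equal(a, b)  # masking is pinned by the collator seed alone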
@@ -1199,6 +1279,33 @@ def test_data_collator_for_whole_word_mask(self):
         self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
         self.assertEqual(batch["labels"].shape.as_list(), [2, 10])

+    def test_data_collator_for_whole_word_mask_with_seed(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
+
+        # check if seed is respected between two different DataCollatorForWholeWordMask instances
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf")
+        batch_1 = data_collator(features)
+        self.assertEqual(batch_1["input_ids"].shape.as_list(), [2, 1000])
+        self.assertEqual(batch_1["labels"].shape.as_list(), [2, 1000])
+
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf")
+        batch_2 = data_collator(features)
+        self.assertEqual(batch_2["input_ids"].shape.as_list(), [2, 1000])
+        self.assertEqual(batch_2["labels"].shape.as_list(), [2, 1000])
+
+        self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
+        self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))
+
+        # try with different seed
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=43, return_tensors="tf")
+        batch_3 = data_collator(features)
+        self.assertEqual(batch_3["input_ids"].shape.as_list(), [2, 1000])
+        self.assertEqual(batch_3["labels"].shape.as_list(), [2, 1000])
+
+        self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"]))
+        self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))
+
     def test_plm(self):
         tokenizer = BertTokenizer(self.vocab_file)
         no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
@@ -1920,6 +2027,32 @@ def test_data_collator_for_whole_word_mask(self):
         self.assertEqual(batch["input_ids"].shape, (2, 10))
         self.assertEqual(batch["labels"].shape, (2, 10))

+    def test_data_collator_for_whole_word_mask_with_seed(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
+
+        # check if seed is respected between two different DataCollatorForWholeWordMask instances
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np")
+        batch_1 = data_collator(features)
+        self.assertEqual(batch_1["input_ids"].shape, (2, 1000))
+        self.assertEqual(batch_1["labels"].shape, (2, 1000))
+
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np")
+        batch_2 = data_collator(features)
+        self.assertEqual(batch_2["input_ids"].shape, (2, 1000))
+        self.assertEqual(batch_2["labels"].shape, (2, 1000))
+
+        self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
+        self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))
+
+        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=43, return_tensors="np")
+        batch_3 = data_collator(features)
+        self.assertEqual(batch_3["input_ids"].shape, (2, 1000))
+        self.assertEqual(batch_3["labels"].shape, (2, 1000))
+
+        self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"]))
+        self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))
+
     def test_plm(self):
         tokenizer = BertTokenizer(self.vocab_file)
         no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
