@@ -1193,6 +1193,11 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
     </Tip>"""

     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+        if self.seed and self.generator is None:
+            # If we have a seed, we need to create a generator object; subsequent calls to this function will reuse it.
+            # If no seed is supplied, we will use the global RNG.
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
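
Note on the pattern above (repeated for `tf_call` and `numpy_call` in the next two hunks): the generator is created lazily on the first batch and then reused. The body of `create_rng` is not part of this diff, so the following is only a plausible sketch, assuming it dispatches on `return_tensors` the same way `_shuffle` does further down:

# A minimal sketch of what create_rng() might do; its body is outside this diff,
# so the per-framework dispatch below is an assumption.
def create_rng(self):
    if self.return_tensors == "pt":
        import torch

        self.generator = torch.Generator()
        self.generator.manual_seed(self.seed)
    elif self.return_tensors == "tf":
        import tensorflow as tf

        self.generator = tf.random.Generator.from_seed(self.seed)
    elif self.return_tensors == "np":
        import numpy as np

        self.generator = np.random.default_rng(self.seed)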
@@ -1223,6 +1228,11 @@ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         import tensorflow as tf

+        if self.seed and self.generator is None:
+            # If we have a seed, we need to create a generator object; subsequent calls to this function will reuse it.
+            # If no seed is supplied, we will use the global RNG.
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
@@ -1251,6 +1261,11 @@ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         return {"input_ids": inputs, "labels": labels}

     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+        if self.seed and self.generator is None:
+            # If we have a seed, we need to create a generator object; subsequent calls to this function will reuse it.
+            # If no seed is supplied, we will use the global RNG.
+            self.create_rng()
+
         if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
@@ -1278,6 +1293,30 @@ def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
         return {"input_ids": inputs, "labels": labels}

+    def _shuffle(self, cand_indexes):
+        # If no seed was set, just use random's shuffle
+        if self.seed is None:
+            random.shuffle(cand_indexes)
+            return cand_indexes
+
+        # If a seed was provided, use the generator to shuffle
+        if self.return_tensors == "pt":
+            import torch
+
+            indices = torch.randperm(len(cand_indexes), generator=self.generator)
+            return [cand_indexes[i] for i in indices]
+
+        elif self.return_tensors == "tf":
+            import tensorflow as tf
+
+            seed = self.generator.make_seeds(2)[0]
+            indices = tf.random.experimental.stateless_shuffle(tf.range(len(cand_indexes)), seed=seed).numpy().tolist()
+            return [cand_indexes[i] for i in indices]
+
+        elif self.return_tensors == "np":
+            self.generator.shuffle(cand_indexes)
+            return cand_indexes
+
     def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
         """
         Get 0/1 labels for masked tokens with whole word mask proxy
@@ -1298,7 +1337,7 @@ def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
             else:
                 cand_indexes.append([i])

-        random.shuffle(cand_indexes)
+        cand_indexes = self._shuffle(cand_indexes)
         num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
         masked_lms = []
         covered_indexes = set()
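
To see why the `pt` branch of `_shuffle` is reproducible, note that `torch.randperm` draws from the generator passed to it, so reseeding replays the identical permutation. A standalone sketch (the candidate list is illustrative):

import torch

cand_indexes = [[1], [2, 3], [4], [5, 6, 7]]  # word-piece groups, as built by _whole_word_mask

generator = torch.Generator()
generator.manual_seed(42)
first = [cand_indexes[i] for i in torch.randperm(len(cand_indexes), generator=generator)]

generator.manual_seed(42)  # reseeding replays the same permutation
second = [cand_indexes[i] for i in torch.randperm(len(cand_indexes), generator=generator)]

assert first == second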
@@ -1346,16 +1385,32 @@ def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
         masked_indices = probability_matrix.bool()
         labels[~masked_indices] = -100  # We only compute loss on masked tokens

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
+        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        indices_replaced = (
+            torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob), generator=self.generator).bool()
+            & masked_indices
+        )
         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

-        # 10% of the time, we replace masked input tokens with random word
-        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
-        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
+        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
+            return inputs, labels
+
+        remaining_prob = 1 - self.mask_replace_prob
+        # Scale random_replace_prob to the remaining probability. For example, if
+        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
+        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
+        random_replace_prob_scaled = self.random_replace_prob / remaining_prob
+
+        # random_replace_prob% of the time, we replace masked input tokens with a random word
+        indices_random = (
+            torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled), generator=self.generator).bool()
+            & masked_indices
+            & ~indices_replaced
+        )
+        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, generator=self.generator)
         inputs[indices_random] = random_words[indices_random]

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+        # The rest of the time ((1 - mask_replace_prob - random_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels

     def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
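
The rescaling above is worth spelling out: `indices_random` is only drawn at masked positions that were not already replaced by `[MASK]`, so the desired marginal probability must be divided by the share that survives that first draw. With the BERT-style values from the comment:

mask_replace_prob = 0.8    # fraction of masked positions that become [MASK]
random_replace_prob = 0.1  # desired marginal fraction that become a random token

remaining_prob = 1 - mask_replace_prob                             # 0.2 survives the [MASK] draw
random_replace_prob_scaled = random_replace_prob / remaining_prob  # 0.1 / 0.2 = 0.5

# Marginal split over masked positions:
#   [MASK]:    0.8
#   random:    0.2 * 0.5 = 0.1
#   unchanged: 0.2 * 0.5 = 0.1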
@@ -1387,17 +1442,35 @@ def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
         # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
         labels = tf.where(masked_indices, inputs, -100)

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
+        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob, self.generator) & masked_indices

         inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)

-        # 10% of the time, we replace masked input tokens with random word
-        indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
-        random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
+        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
+            return inputs, labels
+
+        remaining_prob = 1 - self.mask_replace_prob
+        # Scale random_replace_prob to the remaining probability. For example, if
+        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
+        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
+        random_replace_prob_scaled = self.random_replace_prob / remaining_prob
+
+        # random_replace_prob% of the time, we replace masked input tokens with a random word
+        indices_random = (
+            self.tf_bernoulli(input_shape, random_replace_prob_scaled, self.generator)
+            & masked_indices
+            & ~indices_replaced
+        )
+
+        if self.generator:
+            random_words = self.generator.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
+        else:
+            random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
+
         inputs = tf.where(indices_random, random_words, inputs)

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+        # The rest of the time ((1 - mask_replace_prob - random_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels

     def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
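
`tf_bernoulli` is now called with the generator as a third argument, but the helper itself sits outside this diff; the following is a sketch under that assumption, drawing uniforms from the per-collator generator when one exists and from TensorFlow's global RNG otherwise:

import tensorflow as tf

def tf_bernoulli(shape, probability, generator=None):
    # Draw uniforms from the seeded generator when available, else from the
    # global RNG, then threshold to get Bernoulli(probability) samples.
    prob_matrix = tf.fill(shape, probability)
    uniforms = generator.uniform(shape) if generator else tf.random.uniform(shape)
    return uniforms <= prob_matrix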
@@ -1425,19 +1498,44 @@ def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:

         labels[~masked_indices] = -100  # We only compute loss on masked tokens

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
+        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        if self.generator:
+            indices_replaced = (
+                self.generator.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+            )
+        else:
+            indices_replaced = (
+                np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+            )
         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

-        # 10% of the time, we replace masked input tokens with random word
-        # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
-        indices_random = (
-            np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
-        )
-        random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
+        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
+            return inputs, labels
+
+        remaining_prob = 1 - self.mask_replace_prob
+        # Scale random_replace_prob to the remaining probability. For example, if
+        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
+        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
+        random_replace_prob_scaled = self.random_replace_prob / remaining_prob
+
+        # random_replace_prob% of the time, we replace masked input tokens with a random word
+        if self.generator:
+            indices_random = (
+                self.generator.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+                & masked_indices
+                & ~indices_replaced
+            )
+            random_words = self.generator.integers(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
+        else:
+            indices_random = (
+                np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+                & masked_indices
+                & ~indices_replaced
+            )
+            random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
+
         inputs[indices_random] = random_words[indices_random]

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+        # The rest of the time ((1 - mask_replace_prob - random_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels

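From the user's side, the new behavior is driven entirely by constructor arguments. A hypothetical usage sketch (the checkpoint name and probability values are illustrative, and `seed`, `mask_replace_prob`, and `random_replace_prob` are assumed to be exposed on the collator's constructor, which this diff does not show):

from transformers import AutoTokenizer, DataCollatorForWholeWordMask

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForWholeWordMask(
    tokenizer=tokenizer,
    mlm_probability=0.15,
    mask_replace_prob=0.7,    # 70% of masked positions -> [MASK]
    random_replace_prob=0.2,  # 20% -> random token; the remaining 10% stay unchanged
    seed=42,                  # makes whole-word masking reproducible across runs
    return_tensors="np",
)

batch = collator([tokenizer("whole word masking keeps the pieces of a word together")])
print(batch["input_ids"].shape, batch["labels"].shape)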