@@ -447,18 +447,18 @@ def test_data_collator_for_whole_word_mask(self):

     def test_data_collator_for_whole_word_mask_with_seed(self):
         tokenizer = BertTokenizer(self.vocab_file)
-        features = [{"input_ids": list(range(30))}, {"input_ids": list(range(30))}]
+        features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]

         # check if seed is respected between two different DataCollatorForWholeWordMask instances
         data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
         batch_1 = data_collator(features)
-        self.assertEqual(batch_1["input_ids"].shape, torch.Size((2, 30)))
-        self.assertEqual(batch_1["labels"].shape, torch.Size((2, 30)))
+        self.assertEqual(batch_1["input_ids"].shape, torch.Size((2, 1000)))
+        self.assertEqual(batch_1["labels"].shape, torch.Size((2, 1000)))

         data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
         batch_2 = data_collator(features)
-        self.assertEqual(batch_2["input_ids"].shape, torch.Size((2, 30)))
-        self.assertEqual(batch_2["labels"].shape, torch.Size((2, 30)))
+        self.assertEqual(batch_2["input_ids"].shape, torch.Size((2, 1000)))
+        self.assertEqual(batch_2["labels"].shape, torch.Size((2, 1000)))

         self.assertTrue(torch.all(batch_1["input_ids"] == batch_2["input_ids"]))
         self.assertTrue(torch.all(batch_1["labels"] == batch_2["labels"]))
@@ -1281,27 +1281,27 @@ def test_data_collator_for_whole_word_mask(self):

     def test_data_collator_for_whole_word_mask_with_seed(self):
         tokenizer = BertTokenizer(self.vocab_file)
-        features = [{"input_ids": list(range(50))}, {"input_ids": list(range(50))}]
+        features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]

         # check if seed is respected between two different DataCollatorForWholeWordMask instances
         data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf")
         batch_1 = data_collator(features)
-        self.assertEqual(batch_1["input_ids"].shape.as_list(), [2, 50])
-        self.assertEqual(batch_1["labels"].shape.as_list(), [2, 50])
+        self.assertEqual(batch_1["input_ids"].shape.as_list(), [2, 1000])
+        self.assertEqual(batch_1["labels"].shape.as_list(), [2, 1000])

         data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf")
         batch_2 = data_collator(features)
-        self.assertEqual(batch_2["input_ids"].shape.as_list(), [2, 50])
-        self.assertEqual(batch_2["labels"].shape.as_list(), [2, 50])
+        self.assertEqual(batch_2["input_ids"].shape.as_list(), [2, 1000])
+        self.assertEqual(batch_2["labels"].shape.as_list(), [2, 1000])

         self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
         self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))

         # try with different seed
         data_collator = DataCollatorForWholeWordMask(tokenizer, seed=43, return_tensors="tf")
         batch_3 = data_collator(features)
-        self.assertEqual(batch_3["input_ids"].shape.as_list(), [2, 50])
-        self.assertEqual(batch_3["labels"].shape.as_list(), [2, 50])
+        self.assertEqual(batch_3["input_ids"].shape.as_list(), [2, 1000])
+        self.assertEqual(batch_3["labels"].shape.as_list(), [2, 1000])

         self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"]))
         self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))
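
For context, a minimal standalone sketch of the behavior these tests assert: two independently constructed collators with the same seed should make identical masking decisions. This assumes a transformers version in which DataCollatorForWholeWordMask accepts the seed argument exercised in the diff above; the checkpoint name is only illustrative, not taken from this PR.

# Sketch (not part of the diff): seed reproducibility for whole-word masking.
import torch
from transformers import BertTokenizer, DataCollatorForWholeWordMask

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
ids = tokenizer("hello world, this is a test")["input_ids"]
features = [{"input_ids": list(ids)}, {"input_ids": list(ids)}]

# Two separate instances sharing a seed (mirrors the PyTorch branch of the test).
collator_a = DataCollatorForWholeWordMask(tokenizer, seed=42)
collator_b = DataCollatorForWholeWordMask(tokenizer, seed=42)

batch_a = collator_a(features)
batch_b = collator_b(features)

# Same seed -> identical masked inputs and labels across instances.
assert torch.all(batch_a["input_ids"] == batch_b["input_ids"])
assert torch.all(batch_a["labels"] == batch_b["labels"])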