Merged
142 changes: 120 additions & 22 deletions src/transformers/data/data_collator.py
@@ -1193,6 +1193,11 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
</Tip>"""

def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
if self.seed and self.generator is None:
# If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
# If no seed supplied, we will use the global RNG
self.create_rng()

if isinstance(examples[0], Mapping):
input_ids = [e["input_ids"] for e in examples]
else:
@@ -1223,6 +1228,11 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
import tensorflow as tf

if self.seed and self.generator is None:
# If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
# If no seed supplied, we will use the global RNG
self.create_rng()

if isinstance(examples[0], Mapping):
input_ids = [e["input_ids"] for e in examples]
else:
@@ -1251,6 +1261,11 @@ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
return {"input_ids": inputs, "labels": labels}

def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
if self.seed and self.generator is None:
# If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
# If no seed supplied, we will use the global RNG
self.create_rng()

if isinstance(examples[0], Mapping):
input_ids = [e["input_ids"] for e in examples]
else:
@@ -1278,6 +1293,30 @@ def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
return {"input_ids": inputs, "labels": labels}
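
All three entry points lazily build a generator from self.seed on first use, so one collator instance keeps drawing from the same RNG stream across batches. create_rng itself lives on the parent collator and is not part of this hunk; below is a minimal, hypothetical sketch of what it presumably does, assuming one framework-specific generator per backend (the real method may additionally account for DataLoader worker ids):

    # Hypothetical sketch of create_rng -- not the actual implementation in this PR.
    def create_rng(self):
        if self.return_tensors == "pt":
            import torch

            self.generator = torch.Generator().manual_seed(self.seed)
        elif self.return_tensors == "tf":
            import tensorflow as tf

            self.generator = tf.random.Generator.from_seed(self.seed)
        elif self.return_tensors == "np":
            import numpy as np

            self.generator = np.random.default_rng(self.seed)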

def _shuffle(self, cand_indexes):
# if no seed, just use random's shuffle
if self.seed is None:
random.shuffle(cand_indexes)
return cand_indexes

# if seed is provided, use the generator to shuffle
if self.return_tensors == "pt":
import torch

indices = torch.randperm(len(cand_indexes), generator=self.generator)
return [cand_indexes[i] for i in indices]

elif self.return_tensors == "tf":
import tensorflow as tf

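            # make_seeds returns a [2, count] tensor of stateless-op seeds; row 0 of
            # make_seeds(2) is the shape-[2] seed that stateless_shuffle expects.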
seed = self.generator.make_seeds(2)[0]
indices = tf.random.experimental.stateless_shuffle(tf.range(len(cand_indexes)), seed=seed).numpy().tolist()
return [cand_indexes[i] for i in indices]

elif self.return_tensors == "np":
self.generator.shuffle(cand_indexes)
return cand_indexes
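
Routing the shuffle through the backend generator, instead of always calling random.shuffle, is what makes the seed effective here: that generator is the only RNG the seed controls, and Python's global random module would reintroduce nondeterminism even when a seed is set.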

def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
"""
Get 0/1 labels for masked tokens with whole word mask proxy
@@ -1298,7 +1337,7 @@ def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
else:
cand_indexes.append([i])

-        random.shuffle(cand_indexes)
cand_indexes = self._shuffle(cand_indexes)
num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
masked_lms = []
covered_indexes = set()
@@ -1346,16 +1385,32 @@ def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
masked_indices = probability_matrix.bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        # With probability mask_replace_prob, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = (
torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob), generator=self.generator).bool()
& masked_indices
)
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

-        # 10% of the time, we replace masked input tokens with random word
-        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
-        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
return inputs, labels

remaining_prob = 1 - self.mask_replace_prob
        # Scale random_replace_prob to the probability mass that remains after
        # [MASK] replacement. For example, if mask_replace_prob = 0.8 and
        # random_replace_prob = 0.1, then random_replace_prob_scaled = 0.1 / 0.2 = 0.5.
random_replace_prob_scaled = self.random_replace_prob / remaining_prob

        # With probability random_replace_prob, we replace masked input tokens with a random word
indices_random = (
torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled), generator=self.generator).bool()
& masked_indices
& ~indices_replaced
)
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, generator=self.generator)
inputs[indices_random] = random_words[indices_random]

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        # The rest of the time (with probability 1 - mask_replace_prob - random_replace_prob) we keep the masked input tokens unchanged
return inputs, labels
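
Why the division works: replacement is sampled in two stages, so drawing the random-word indicator with the scaled probability, conditioned on "masked but not [MASK]-replaced", recovers the advertised marginal rate. A quick arithmetic check (illustrative only, not part of the diff):

    mask_replace_prob = 0.8    # P(token -> [MASK] | token is masked)
    random_replace_prob = 0.1  # desired P(token -> random word | token is masked)

    random_replace_prob_scaled = random_replace_prob / (1 - mask_replace_prob)  # 0.5

    p_random = (1 - mask_replace_prob) * random_replace_prob_scaled  # 0.1, as advertised
    p_keep = 1 - mask_replace_prob - p_random                        # 0.1 stays unchanged
    assert abs(p_random - random_replace_prob) < 1e-12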

def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
@@ -1387,17 +1442,35 @@ def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
labels = tf.where(masked_indices, inputs, -100)

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
        # With probability mask_replace_prob, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob, self.generator) & masked_indices

inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)

-        # 10% of the time, we replace masked input tokens with random word
-        indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
-        random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
return inputs, labels

remaining_prob = 1 - self.mask_replace_prob
        # Scale random_replace_prob to the probability mass that remains after
        # [MASK] replacement. For example, if mask_replace_prob = 0.8 and
        # random_replace_prob = 0.1, then random_replace_prob_scaled = 0.1 / 0.2 = 0.5.
random_replace_prob_scaled = self.random_replace_prob / remaining_prob

        # With probability random_replace_prob, we replace masked input tokens with a random word
indices_random = (
self.tf_bernoulli(input_shape, random_replace_prob_scaled, self.generator)
& masked_indices
& ~indices_replaced
)

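        # A seeded tf.random.Generator carries its own state, so draw the random
        # words from it when present; otherwise fall back to the global stateful op.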
if self.generator:
random_words = self.generator.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
else:
random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)

inputs = tf.where(indices_random, random_words, inputs)

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        # The rest of the time (with probability 1 - mask_replace_prob - random_replace_prob) we keep the masked input tokens unchanged
return inputs, labels

def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
@@ -1425,19 +1498,44 @@ def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:

labels[~masked_indices] = -100 # We only compute loss on masked tokens

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
        # With probability mask_replace_prob, we replace masked input tokens with tokenizer.mask_token ([MASK])
if self.generator:
indices_replaced = (
self.generator.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
)
else:
indices_replaced = (
np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
)
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

-        # 10% of the time, we replace masked input tokens with random word
-        # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
-        indices_random = (
-            np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
-        )
-        random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
return inputs, labels

remaining_prob = 1 - self.mask_replace_prob
        # Scale random_replace_prob to the probability mass that remains after
        # [MASK] replacement. For example, if mask_replace_prob = 0.8 and
        # random_replace_prob = 0.1, then random_replace_prob_scaled = 0.1 / 0.2 = 0.5.
random_replace_prob_scaled = self.random_replace_prob / remaining_prob

if self.generator:
indices_random = (
self.generator.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
& masked_indices
& ~indices_replaced
)
random_words = self.generator.integers(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
else:
indices_random = (
np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
& masked_indices
& ~indices_replaced
)
random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)

inputs[indices_random] = random_words[indices_random]

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        # The rest of the time (with probability 1 - mask_replace_prob - random_replace_prob) we keep the masked input tokens unchanged
return inputs, labels


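
Taken together, whole-word masking is now reproducible end to end whenever a seed is supplied. A minimal usage sketch (the checkpoint name is illustrative):

    from transformers import AutoTokenizer, DataCollatorForWholeWordMask

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    features = [{"input_ids": tokenizer("whole word masking is reproducible")["input_ids"]}]

    # Two collators built with the same seed mask identically...
    batch_a = DataCollatorForWholeWordMask(tokenizer, seed=42)(features)
    batch_b = DataCollatorForWholeWordMask(tokenizer, seed=42)(features)
    assert bool((batch_a["input_ids"] == batch_b["input_ids"]).all())
    assert bool((batch_a["labels"] == batch_b["labels"]).all())

    # ...while omitting the seed keeps the old behaviour of masking via the global RNG.
    batch_c = DataCollatorForWholeWordMask(tokenizer)(features)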
133 changes: 133 additions & 0 deletions tests/trainer/test_data_collator.py
@@ -445,6 +445,86 @@ def test_data_collator_for_whole_word_mask(self):
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

def test_data_collator_for_whole_word_mask_with_seed(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]

# check if seed is respected between two different DataCollatorForWholeWordMask instances
data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
batch_1 = data_collator(features)
self.assertEqual(batch_1["input_ids"].shape, torch.Size((2, 1000)))
self.assertEqual(batch_1["labels"].shape, torch.Size((2, 1000)))

data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
batch_2 = data_collator(features)
self.assertEqual(batch_2["input_ids"].shape, torch.Size((2, 1000)))
self.assertEqual(batch_2["labels"].shape, torch.Size((2, 1000)))

self.assertTrue(torch.all(batch_1["input_ids"] == batch_2["input_ids"]))
self.assertTrue(torch.all(batch_1["labels"] == batch_2["labels"]))

# check if seed is respected in multiple workers situation
features = [{"input_ids": list(range(1000))} for _ in range(10)]
dataloader = torch.utils.data.DataLoader(
features,
batch_size=2,
num_workers=2,
generator=torch.Generator().manual_seed(42),
collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=42),
)

batch_3_input_ids = []
batch_3_labels = []
for batch in dataloader:
batch_3_input_ids.append(batch["input_ids"])
batch_3_labels.append(batch["labels"])

batch_3_input_ids = torch.stack(batch_3_input_ids)
batch_3_labels = torch.stack(batch_3_labels)
self.assertEqual(batch_3_input_ids.shape, torch.Size((5, 2, 1000)))
self.assertEqual(batch_3_labels.shape, torch.Size((5, 2, 1000)))

dataloader = torch.utils.data.DataLoader(
features,
batch_size=2,
num_workers=2,
collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=42),
)

batch_4_input_ids = []
batch_4_labels = []
for batch in dataloader:
batch_4_input_ids.append(batch["input_ids"])
batch_4_labels.append(batch["labels"])
batch_4_input_ids = torch.stack(batch_4_input_ids)
batch_4_labels = torch.stack(batch_4_labels)
self.assertEqual(batch_4_input_ids.shape, torch.Size((5, 2, 1000)))
self.assertEqual(batch_4_labels.shape, torch.Size((5, 2, 1000)))

self.assertTrue(torch.all(batch_3_input_ids == batch_4_input_ids))
self.assertTrue(torch.all(batch_3_labels == batch_4_labels))

# try with different seed
dataloader = torch.utils.data.DataLoader(
features,
batch_size=2,
num_workers=2,
collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=43),
)

batch_5_input_ids = []
batch_5_labels = []
for batch in dataloader:
batch_5_input_ids.append(batch["input_ids"])
batch_5_labels.append(batch["labels"])
batch_5_input_ids = torch.stack(batch_5_input_ids)
batch_5_labels = torch.stack(batch_5_labels)
self.assertEqual(batch_5_input_ids.shape, torch.Size((5, 2, 1000)))
self.assertEqual(batch_5_labels.shape, torch.Size((5, 2, 1000)))

self.assertFalse(torch.all(batch_3_input_ids == batch_5_input_ids))
self.assertFalse(torch.all(batch_3_labels == batch_5_labels))

def test_plm(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
@@ -1199,6 +1279,33 @@ def test_data_collator_for_whole_word_mask(self):
self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
self.assertEqual(batch["labels"].shape.as_list(), [2, 10])

def test_data_collator_for_whole_word_mask_with_seed(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]

# check if seed is respected between two different DataCollatorForWholeWordMask instances
data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf")
batch_1 = data_collator(features)
self.assertEqual(batch_1["input_ids"].shape.as_list(), [2, 1000])
self.assertEqual(batch_1["labels"].shape.as_list(), [2, 1000])

data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="tf")
batch_2 = data_collator(features)
self.assertEqual(batch_2["input_ids"].shape.as_list(), [2, 1000])
self.assertEqual(batch_2["labels"].shape.as_list(), [2, 1000])

self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))

# try with different seed
data_collator = DataCollatorForWholeWordMask(tokenizer, seed=43, return_tensors="tf")
batch_3 = data_collator(features)
self.assertEqual(batch_3["input_ids"].shape.as_list(), [2, 1000])
self.assertEqual(batch_3["labels"].shape.as_list(), [2, 1000])

self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"]))
self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))

def test_plm(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
@@ -1920,6 +2027,32 @@ def test_data_collator_for_whole_word_mask(self):
self.assertEqual(batch["input_ids"].shape, (2, 10))
self.assertEqual(batch["labels"].shape, (2, 10))

def test_data_collator_for_whole_word_mask_with_seed(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]

# check if seed is respected between two different DataCollatorForWholeWordMask instances
data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np")
batch_1 = data_collator(features)
self.assertEqual(batch_1["input_ids"].shape, (2, 1000))
self.assertEqual(batch_1["labels"].shape, (2, 1000))

data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np")
batch_2 = data_collator(features)
self.assertEqual(batch_2["input_ids"].shape, (2, 1000))
self.assertEqual(batch_2["labels"].shape, (2, 1000))

self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))

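        # try with a different seed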
data_collator = DataCollatorForWholeWordMask(tokenizer, seed=43, return_tensors="np")
batch_3 = data_collator(features)
self.assertEqual(batch_3["input_ids"].shape, (2, 1000))
self.assertEqual(batch_3["labels"].shape, (2, 1000))

self.assertFalse(np.all(batch_1["input_ids"] == batch_3["input_ids"]))
self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))

def test_plm(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]