-
Notifications
You must be signed in to change notification settings - Fork 13
/
utils_data.py
94 lines (79 loc) · 4.22 KB
/
utils_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers.data.data_collator import *
from transformers import BatchEncoding
from transformers import PreTrainedTokenizer
@dataclass
class DataCollatorForData2TextLanguageModeling:
    """
    Data collator for data-to-text language modeling.

    Each example is an ``(input_ids, labels)`` pair (lists of ints or tensors).
    ``__call__`` pads both halves into rectangular batches, honoring the
    tokenizer's pad_token, and derives an attention mask from it.
    ``mask_tokens`` separately implements BERT-style masked-LM corruption.

    NOTE(review): ``mlm`` and ``format_mode`` are not read anywhere in this
    class, and ``__call__`` never calls ``mask_tokens``; ``mlm_probability``
    is only used by ``mask_tokens``.
    """

    tokenizer: "PreTrainedTokenizer"
    mlm: bool = True
    format_mode: str = 'cat'
    mlm_probability: float = 0.15

    def _default_pad_value(self) -> int:
        """Return the tokenizer's pad id, or -100 when it has no pad token."""
        pad_token = self.tokenizer.pad_token_id
        return -100 if pad_token is None else pad_token

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """
        Collate ``(input_ids, labels)`` pairs into a padded batch.

        Args:
            examples: sequence of ``(input_ids, labels)`` pairs. If the
                elements are mappings (``dict`` / ``BatchEncoding``), the pair
                is read from each element's ``"input_ids"`` entry.

        Returns:
            dict with
            - ``input_ids``: long tensor; every position equal to the pad
              value is replaced by 0,
            - ``labels``: long tensor; positions added by padding hold -100,
            - ``attention_mask``: float tensor, 1.0 where ``input_ids`` was
              not the pad value and 0.0 elsewhere.

        NOTE: masking is value-based, so a real token whose id equals
        ``pad_token_id`` (e.g. when the pad token is aliased to EOS) is also
        masked out and zeroed.
        """
        # BatchEncoding is a UserDict, so it is a Mapping just like dict.
        if isinstance(examples[0], Mapping):
            examples = [e["input_ids"] for e in examples]
        input_ids, labels = zip(*examples)
        pad_token = self._default_pad_value()
        input_ids = self._tensorize_batch(input_ids)
        # Pad labels with -100 (the default ignore_index of PyTorch's
        # cross-entropy) so padded positions never contribute to the loss,
        # even when the tokenizer has a real pad token.
        labels = self._tensorize_batch(labels, padding_value=-100)
        attention_mask = torch.where((input_ids != pad_token), 1., 0.)
        # The pad value may be -100 (no pad token), which is not a valid
        # embedding index; pad positions are masked anyway, so use id 0.
        input_ids[input_ids == pad_token] = 0
        return {"input_ids": input_ids, "labels": labels, 'attention_mask': attention_mask}

    def _tensorize_batch(
        self,
        examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]],
        padding_value: Union[int, None] = None,
    ) -> torch.Tensor:
        """
        Stack a batch of sequences into one tensor, right-padding if ragged.

        Args:
            examples: lists/tuples of ints (converted to long tensors) or
                tensors (presumably 1-D — padding is along dim 0).
            padding_value: fill value for ragged batches. Defaults to the
                tokenizer's pad id, or -100 when it has none.

        Returns:
            ``torch.stack`` of the examples when all lengths match, otherwise
            ``pad_sequence(..., batch_first=True)`` with ``padding_value``.
        """
        # Accept both lists of lists and lists of tensors.
        if isinstance(examples[0], (list, tuple)):
            examples = [torch.tensor(e, dtype=torch.long) for e in examples]
        length_of_first = examples[0].size(0)
        if all(x.size(0) == length_of_first for x in examples):
            return torch.stack(examples, dim=0)
        if padding_value is None:
            padding_value = self._default_pad_value()
        return pad_sequence(examples, batch_first=True, padding_value=padding_value)

    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Each non-special, non-pad position is selected with probability
        ``mlm_probability``. Labels are -100 everywhere except selected
        positions.

        Note: ``inputs`` is modified in place and returned as the first element.

        Raises:
            ValueError: if the tokenizer has no mask token.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
            )
        labels = inputs.clone()
        # Sample a few tokens in each sequence for masked-LM training.
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        # Public API instead of the private ``_pad_token`` attribute.
        if self.tokenizer.pad_token_id is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens
        # 80% of the time, replace masked input tokens with tokenizer.mask_token ([MASK]).
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        # 10% of the time (half of the remaining 20%), replace with a random word.
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]
        # The rest of the time (10% of the time) keep the masked input tokens unchanged.
        return inputs, labels