# dataset.py
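# Dataset and tokenization helpers for a word-sense (synset) classification task.
# As the code below implies, each input sentence marks its query word with a
# trailing "#", and the model predicts that word's synset index.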
from tqdm.auto import tqdm
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Pre-tokenizes every row of a DataFrame once, so __getitem__ is a plain lookup."""

    def __init__(self, df, tokenizer, max_len=150, add_special_tokens=True, class_token_pos=None):
        self.data = []
        for row in tqdm(df.itertuples(), total=df.shape[0]):
            encoding = tokenize_mlub(
                row.text_truncated,
                tokenizer,
                max_len=max_len,
                add_special_tokens=add_special_tokens,
                class_token_pos=class_token_pos,
            )
            # encoding = {k: v[0] for k, v in encoding.items()}
            encoding["labels"] = torch.tensor(row.synset_index, dtype=torch.long)
            # del encoding["start_end_mask"]
            self.data.append(encoding)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]
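
# Minimal usage sketch (not part of the original file): it assumes a pandas
# DataFrame with "text_truncated" and "synset_index" columns (see truncate_text
# below) and a Hugging Face tokenizer; the model name and batch size are
# illustrative only.
#
#   from transformers import AutoTokenizer
#   from torch.utils.data import DataLoader
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
#   dataset = CustomDataset(df, tokenizer, max_len=150, class_token_pos="start")
#   loader = DataLoader(dataset, batch_size=32, shuffle=True)
#   batch = next(iter(loader))  # dict of input_ids, attention_mask, start_end_mask, labels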

def tokenize_mlub(text, tokenizer, max_len, add_special_tokens, class_token_pos=None):
    """Tokenize a whitespace-split sentence whose query word is marked with "#".

    Returns fixed-length (max_len) tensors: input_ids (right-padded),
    attention_mask (1 for real tokens, 0 for padding) and start_end_mask
    (1 over the subword span of the query word, plus the CLS token inserted
    in "synset" mode, 0 elsewhere).
    """
    input_ids = []
    start = end = None
    if class_token_pos == "start":
        input_ids += [tokenizer.cls_token_id]
    for tok in text.split():
        if "#" in tok:
            # The query word carries a "#" suffix; record its subword span.
            tok = tok.split("#")[0]
            start = len(input_ids)
            end = len(input_ids) + len(tokenizer.encode(tok, add_special_tokens=add_special_tokens))
            if class_token_pos == "synset":
                # Place the CLS token directly in front of the query word.
                input_ids += [tokenizer.cls_token_id]
                end += 1
        input_ids += tokenizer.encode(tok + " ", add_special_tokens=add_special_tokens)
    attention_mask = [1] * len(input_ids)
    if len(input_ids) > max_len:
        raise ValueError("input_ids is longer than max_len, please increase max_len")
    attention_mask += [0] * (max_len - len(input_ids))
    # Pick a padding id; some of the tokenizers used here have no dedicated pad token.
    if tokenizer.is_fast and tokenizer.pad_token_id < tokenizer.vocab_size - 1:
        padding_id = tokenizer.pad_token_id
    else:
        padding_id = tokenizer.eos_token_id
    if "gpt" in tokenizer.name_or_path:
        padding_id = 1  # which is <pad>
    if "tugstugi" in tokenizer.name_or_path or "mlub" in tokenizer.name_or_path:
        padding_id = 3
    input_ids.extend([padding_id] * (max_len - len(input_ids)))
    # Mark the query word's subword positions; truncate_text guarantees exactly
    # one "#"-marked token, so start and end are set by the loop above.
    start_end_mask = [0] * max_len
    for i in range(start, end):
        start_end_mask[i] = 1
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        "start_end_mask": torch.tensor(start_end_mask, dtype=torch.long),
    }
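
# Sketch of the returned contract (illustrative values, not from the original file):
#   enc = tokenize_mlub("word1 word2# word3", tokenizer, max_len=150,
#                       add_special_tokens=False, class_token_pos=None)
#   enc["input_ids"].shape       # torch.Size([150]), right-padded
#   enc["start_end_mask"].sum()  # number of subword pieces covering the marked word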

def truncate_text(df, effective_len=30):
    """Keep a window of roughly effective_len tokens centred on the "#"-marked query word."""
    list_truncated_text = []
    for row in df.itertuples():
        # Locate the query token (exactly one token must carry the "#" marker).
        tokens = row.text.split()
        synset_query = [(tok, index) for index, tok in enumerate(tokens) if "#" in tok]
        assert len(synset_query) == 1
        query_word, query_location = synset_query[0]
        # Truncate symmetrically around the query word.
        begin = max(query_location - effective_len // 2, 0)
        end = min(query_location + effective_len // 2 + 1, len(tokens))
        tokens = tokens[begin:end]
        # Store the window.
        list_truncated_text.append(" ".join(tokens))
    df["text_truncated"] = list_truncated_text
    return df
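
# Usage sketch (illustrative; assumes df has a "text" column in which the query
# word is suffixed with "#", as the code above expects):
#   df = truncate_text(df, effective_len=30)
#   df["text_truncated"]  # ~30-token windows centred on the marked word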

def get_index2synsetid_dicts(synset_word, dict_synset_meaning):
    """index2id, id2index = get_index2synsetid_dicts("ам", dict_synset_meaning)"""
    # Map class indices to synset ids (and back), in sorted-id order.
    ids = sorted(dict_synset_meaning[synset_word].keys())
    return {i: synset_id for i, synset_id in enumerate(ids)}, {synset_id: i for i, synset_id in enumerate(ids)}
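
# Usage sketch (illustrative; dict_synset_meaning is assumed to map each query
# word to a {synset_id: meaning} dict, as the lookup above implies):
#   index2id, id2index = get_index2synsetid_dicts("ам", dict_synset_meaning)
#   index2id[0]            # synset id behind model class index 0
#   id2index[index2id[0]]  # == 0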