#!/usr/bin/env python
# coding=utf-8
"""Dataloader"""
import os
import json

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from transformers import BertTokenizer

from dataloader_utils import read_examples, convert_examples_to_features, read_raw_examples


class FeatureDataset(Dataset):
    """Pytorch Dataset for InputFeatures"""

    def __init__(self, features):
        self.features = features

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, index):
        return self.features[index]
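
# NOTE: FeatureDataset.__getitem__ returns raw InputFeatures objects, which the
# default PyTorch collate function cannot batch; the collate_fn_* staticmethods
# on CustomDataLoader below perform the tensor conversion instead.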


class CustomDataLoader(object):
    def __init__(self, params):
        self.params = params

        self.train_batch_size = params.train_batch_size
        self.val_batch_size = params.val_batch_size
        self.test_batch_size = params.test_batch_size

        self.data_dir = params.data_dir
        self.max_seq_length = params.max_seq_length
        self.tokenizer = BertTokenizer(vocab_file=os.path.join(params.bert_model_dir, 'vocab.txt'),
                                       do_lower_case=False)
        self.data_cache = params.data_cache

    @staticmethod
    def collate_fn_train(features):
        """Convert InputFeatures to Tensors for training

        Args:
            features (List[InputFeatures])

        Returns:
            tensors (List[Tensor])
        """
        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        seq_tags = torch.tensor([f.seq_tag for f in features], dtype=torch.long)
        poten_relations = torch.tensor([f.relation for f in features], dtype=torch.long)
        corres_tags = torch.tensor([f.corres_tag for f in features], dtype=torch.long)
        rel_tags = torch.tensor([f.rel_tag for f in features], dtype=torch.long)
        tensors = [input_ids, attention_mask, seq_tags, poten_relations, corres_tags, rel_tags]
        return tensors
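
    # Example (hypothetical usage): a training loop unpacks each batch in the
    # order collate_fn_train builds it:
    #   input_ids, attention_mask, seq_tags, poten_relations, corres_tags, rel_tags = batch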

    @staticmethod
    def collate_fn_test(features):
        """Convert InputFeatures to model inputs for evaluation

        Args:
            features (List[InputFeatures])

        Returns:
            tensors: batched input ids and attention masks, plus the gold
                triples and input tokens kept as Python lists
        """
        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        triples = [f.triples for f in features]
        input_tokens = [f.input_tokens for f in features]
        tensors = [input_ids, attention_mask, triples, input_tokens]
        return tensors

    @staticmethod
    def collate_fn_infer(features):
        """Convert InputFeatures to model inputs for inference

        Args:
            features (List[InputFeatures])

        Returns:
            tensors: batched input ids and attention masks, plus the raw
                input tokens kept as a Python list
        """
        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        input_tokens = [f.input_tokens for f in features]
        tensors = [input_ids, attention_mask, input_tokens]
        return tensors
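
    # NOTE: torch.tensor() in the collate functions requires every
    # f.input_ids / f.attention_mask to have the same length; this assumes
    # convert_examples_to_features pads all sequences to max_seq_length
    # (torch.tensor would raise on ragged lists otherwise).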

    def get_features(self, data_sign, ex_params):
        """Convert InputExamples to InputFeatures

        :param data_sign: 'train', 'val', 'test', 'inference' or one of the
            special evaluation splits handled below
        """
        print("=*=" * 10)
        print("Loading {} data...".format(data_sign))
        # get features
        cache_path = os.path.join(self.data_dir, "{}.cache.{}".format(data_sign, str(self.max_seq_length)))
        if os.path.exists(cache_path) and self.data_cache:
            features = torch.load(cache_path)
        else:
            # get relation to idx
            with open(os.path.join(self.data_dir, 'rel2id.json'), 'r', encoding='utf-8') as f_re:
                rel2idx = json.load(f_re)[-1]
            # get examples
            if data_sign in ("train", "val", "temp", "test", "pseudo",
                             'EPO', 'SEO', 'SOO', 'Normal', '1', '2', '3', '4', '5'):
                examples = read_examples(self.data_dir, data_sign=data_sign, rel2idx=rel2idx)
            elif data_sign == "inference":
                examples = read_raw_examples(self.data_dir, data_sign=data_sign, rel2idx=rel2idx)
            else:
                raise ValueError("data_sign must be train/val/test/inference "
                                 "or a supported evaluation split, got: {}".format(data_sign))
            features = convert_examples_to_features(self.params, examples, self.tokenizer, rel2idx, data_sign,
                                                    ex_params)
            # save data
            if self.data_cache:
                torch.save(features, cache_path)
        return features
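
    # NOTE: the cache key encodes only data_sign and max_seq_length, so cached
    # features are not invalidated when ex_params or the vocabulary change;
    # delete the *.cache.* files (or disable data_cache) after such changes.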

    def get_dataloader(self, data_sign="train", ex_params=None):
        """Construct dataloader

        :param data_sign: 'train', 'val', 'test', 'inference' or one of the
            special evaluation splits handled below
        """
        # InputExamples to InputFeatures
        features = self.get_features(data_sign=data_sign, ex_params=ex_params)
        dataset = FeatureDataset(features)
        print(f"{len(features)} {data_sign} data loaded!")
        print("=*=" * 10)
        # construct dataloader: RandomSampler for training-style splits,
        # SequentialSampler for evaluation and inference
        if data_sign == "train":
            datasampler = RandomSampler(dataset)
            dataloader = DataLoader(dataset, sampler=datasampler, batch_size=self.train_batch_size,
                                    collate_fn=self.collate_fn_train)
        elif data_sign == "temp":
            datasampler = RandomSampler(dataset)
            dataloader = DataLoader(dataset, sampler=datasampler, batch_size=self.val_batch_size,
                                    collate_fn=self.collate_fn_train)
        elif data_sign == "val":
            datasampler = SequentialSampler(dataset)
            dataloader = DataLoader(dataset, sampler=datasampler, batch_size=self.val_batch_size,
                                    collate_fn=self.collate_fn_test)
        elif data_sign in ("test", "pseudo", 'EPO', 'SEO', 'SOO', 'Normal', '1', '2', '3', '4', '5'):
            datasampler = SequentialSampler(dataset)
            dataloader = DataLoader(dataset, sampler=datasampler, batch_size=self.test_batch_size,
                                    collate_fn=self.collate_fn_test)
        elif data_sign == "inference":
            datasampler = SequentialSampler(dataset)
            dataloader = DataLoader(dataset, sampler=datasampler, batch_size=1,
                                    collate_fn=self.collate_fn_infer)
        else:
            raise ValueError("data_sign must be train/val/test/inference "
                             "or a supported evaluation split, got: {}".format(data_sign))
        return dataloader


if __name__ == '__main__':
    from utils import Params

    params = Params(corpus_type='Job')
    ex_params = {
        'ensure_relpre': True
    }
    dataloader = CustomDataLoader(params)
    feats = dataloader.get_features(ex_params=ex_params, data_sign='test')
    print(feats[7].input_tokens)
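
    # Example (hypothetical usage): building and iterating a test loader.
    # Field order follows collate_fn_test above; uncomment to try it.
    # test_loader = dataloader.get_dataloader(data_sign='test', ex_params=ex_params)
    # for input_ids, attention_mask, triples, input_tokens in test_loader:
    #     print(input_ids.shape, attention_mask.shape)
    #     break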