-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
68 lines (54 loc) · 2.21 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# define dataset class here which is used to feed in the model for training and eval purposes , use TF.data.Dataset api.
import transformers
# import tensorflow as tf
import numpy as np
import pandas as pd
import torch
from configs import config
from nltk.tokenize.treebank import TreebankWordDetokenizer
class dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing input text with target phrases.

    Loads a CSV (expected columns: 'text' and 'phrases' — TODO confirm
    against the data files) and tokenizes both sides with the tokenizer
    from ``configs.config`` into fixed-length padded tensors.
    """

    def __init__(self, data_path, model):
        # model flag: 1 selects RoBERTa-style separators (see get_target).
        self.tokenizer = config.TOKENIZER
        self.max_len = config.MAX_LEN
        self.data_path = data_path
        self.data = pd.read_csv(data_path)
        self.model = model

    def _encode(self, text):
        """Tokenize `text` into a padded/truncated BatchEncoding of pt tensors."""
        return self.tokenizer.encode_plus(
            text,
            max_length=self.max_len,
            add_special_tokens=True,
            return_attention_mask=True,
            return_token_type_ids=True,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
        )

    def get_target(self, data):
        """Build one example dict from a row with 'text' and 'phrases' fields.

        Returns the raw strings plus input_ids / token_type_ids /
        attention_mask tensors for both the source text ("...") and the
        target phrases ("p_...").
        """
        text = ' ' + data["text"].strip() + ' '
        phrases = ' ' + data['phrases'].strip() + ' '
        if self.model == 1:
            # RoBERTa-family models use '</s></s>' as the segment separator.
            phrases = phrases.replace('[SEP]', '</s></s>')
        encoded_text = self._encode(text)
        encoded_phrases = self._encode(phrases)
        return {
            "orig": text,
            "input_ids": encoded_text.input_ids[0],
            "token_type_ids": encoded_text.token_type_ids[0],
            "attention_mask": encoded_text.attention_mask[0],
            # BUG FIX: previously returned `text` under this key; the raw
            # phrases string (which p_input_ids encodes) was unreachable.
            "phrases": phrases,
            "p_input_ids": encoded_phrases.input_ids[0],
            "p_token_type_ids": encoded_phrases.token_type_ids[0],
            "p_attention_mask": encoded_phrases.attention_mask[0],
        }

    def __len__(self):
        """Denotes the total number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the encoded example for DataFrame row `index`."""
        return self.get_target(self.data.iloc[index])