-
Notifications
You must be signed in to change notification settings - Fork 6
/
dataset.py
32 lines (27 loc) · 1.15 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import random
class ImdbDataset:
    """Tokenize the train/test splits of a dataset and carve out an eval subset.

    Args:
        data: Mapping with ``"train"`` and ``"test"`` splits. Each split must
            support ``.map(fn, batched=True)``, ``len()`` and ``.select(indices)``
            (presumably a Hugging Face ``datasets.DatasetDict`` — confirm with
            callers).
        tokenizer: Callable invoked as ``tokenizer(list_of_texts, truncation=True)``
            that returns a mutable mapping of encoded features.
    """

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def return_train_dataset(self):
        """Return the ``"train"`` split with ``preprocess_function`` applied in batches."""
        # Perform text encoding
        return self.data["train"].map(self.preprocess_function, batched=True)

    def return_test_dataset(self, eval_ratio=0.1, seed=42):
        """Tokenize the ``"test"`` split and draw a random evaluation subset from it.

        Args:
            eval_ratio: Fraction of the test split (0.0 to 1.0) to sample for
                evaluation during training. The subset size is
                ``int(eval_ratio * len(test))``, i.e. rounded down.
            seed: Seed for the sampling RNG, so the same subset is chosen on
                every call. Defaults to 42, the value previously hard-coded.

        Returns:
            ``(test_dataset, eval_dataset)``: the full tokenized test split and
            the sampled subset of it.

        Raises:
            ValueError: If ``eval_ratio`` is outside ``[0, 1]``.
        """
        # Validate before the (potentially expensive) map over the whole split.
        if not 0.0 <= eval_ratio <= 1.0:
            raise ValueError(f"eval_ratio must be between 0 and 1, got {eval_ratio!r}")
        # Perform text encoding
        test_dataset = self.data["test"].map(self.preprocess_function, batched=True)
        # The test split is large, so evaluation during training uses only a
        # random ``eval_ratio`` fraction of it (10% by default).
        total_samples = len(test_dataset)
        eval_samples = int(eval_ratio * total_samples)
        # A private RNG yields the same indices as ``random.seed(seed)`` followed
        # by ``random.sample(...)``, without clobbering the caller's global
        # ``random`` state.
        rng = random.Random(seed)
        eval_indices = rng.sample(range(total_samples), eval_samples)
        eval_dataset = test_dataset.select(eval_indices)
        return test_dataset, eval_dataset

    def preprocess_function(self, examples):
        """Tokenize a batch of examples and drop ``attention_mask`` if present.

        Args:
            examples: Batch mapping whose ``"text"`` entry holds the raw strings.

        Returns:
            The tokenizer's output without an ``attention_mask`` key.
        """
        samples = self.tokenizer(examples["text"], truncation=True)
        # Presumably the downstream model/collator does not use the mask
        # (TODO confirm). The ``None`` default avoids a KeyError for tokenizers
        # that do not return one.
        samples.pop("attention_mask", None)
        return samples