load_dataset.py
import glob
import numpy as np
import os
import tensorflow.compat.v1 as tf
import tqdm


def load_dataset(enc, path, combine, encoding=None):
    """Load a file, directory, or glob of text/.npz files as a list of 1-D token arrays.

    Plain-text files are concatenated (separated by '<|endoftext|>') until at
    least `combine` characters have accumulated, then encoded as one chunk;
    .npz files are assumed to already contain encoded token arrays."""
    paths = []
    if os.path.isfile(path):
        # Simple file
        paths.append(path)
    elif os.path.isdir(path):
        # Directory
        for (dirpath, _, fnames) in os.walk(path):
            for fname in fnames:
                paths.append(os.path.join(dirpath, fname))
    else:
        # Assume glob
        paths = glob.glob(path)

    token_chunks = []
    raw_text = ''
    for path in tqdm.tqdm(paths):
        if path.endswith('.npz'):
            # Pre-encoded
            with np.load(path) as npz:
                for item in npz.files:
                    token_chunks.append(npz[item])
        else:
            # Plain text
            with open(path, 'r', encoding=encoding) as fp:
                raw_text += fp.read()
            if len(raw_text) >= combine:
                tokens = np.stack(enc.encode(raw_text))
                token_chunks.append(tokens)
                raw_text = ''
            else:
                raw_text += '<|endoftext|>'
    if raw_text:
        tokens = np.stack(enc.encode(raw_text))
        token_chunks.append(tokens)
    return token_chunks
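
# Illustrative usage (not part of the original file): `enc` is assumed to be a
# BPE encoder object exposing encode(text) -> list of token ids; the path and
# combine value below are placeholders.
#
#     chunks = load_dataset(enc, 'data/corpus/', combine=50000, encoding='utf-8')
#     print(len(chunks), 'chunks,', sum(c.shape[0] for c in chunks), 'tokens')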


def binary_search(f, lo, hi):
    """Return the smallest index in (lo, hi] where the monotone predicate `f`
    is True, or None if f(lo) is already True or f(hi) is still False."""
    if f(lo) or not f(hi):
        return None
    while hi > lo + 1:
        mid = (lo + hi) // 2
        if f(mid):
            hi = mid
        else:
            lo = mid
    return hi
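
# Worked example (added note, not in the original file): for a monotone
# predicate, binary_search(lambda j: j > 3, 0, 10) == 4 (the first index where
# the predicate flips to True), while binary_search(lambda j: j > 10, 0, 10)
# returns None because the predicate never becomes True inside the range.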


class Sampler(object):
    """Fairly samples a slice from a set of variable sized chunks.

    'Fairly' means that the distribution is the same as sampling from one
    concatenated chunk, but without crossing chunk boundaries."""

    def __init__(self, chunks, seed=None):
        self.chunks = chunks
        self.total_size = sum(chunk.shape[0] for chunk in chunks)
        self.boundaries = [0]
        for i in range(len(chunks)):
            self.boundaries.append(self.boundaries[-1] + chunks[i].shape[0])
        self.rs = np.random.RandomState(seed=seed)

    def sample(self, length):
        assert length < self.total_size // len(
            self.chunks
        ), "Dataset files are too small to sample {} tokens at a time".format(
            length)
        while True:
            # Pick a uniform start position over the concatenated dataset, map
            # it to its chunk via binary search on the boundaries, and retry if
            # the window would cross into the next chunk.
            index = self.rs.randint(0, self.total_size - length - 1)
            i = binary_search(lambda j: self.boundaries[j] > index, 0,
                              len(self.boundaries) - 1) - 1
            if self.boundaries[i + 1] > index + length:
                within_chunk = index - self.boundaries[i]
                return self.chunks[i][within_chunk:within_chunk + length]
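

if __name__ == '__main__':
    # Minimal self-contained sketch (not part of the original file): exercises
    # load_dataset and Sampler with a toy whitespace "encoder" standing in for
    # the real BPE encoder, which this module does not define.
    import tempfile

    class ToyEncoder(object):
        def encode(self, text):
            # Toy stand-in: map each whitespace-separated word to an arbitrary
            # id in GPT-2's vocabulary range (not a real BPE encoding).
            return [hash(word) % 50257 for word in text.split()]

    with tempfile.TemporaryDirectory() as tmpdir:
        with open(os.path.join(tmpdir, 'sample.txt'), 'w') as f:
            f.write('hello world ' * 200)
        chunks = load_dataset(ToyEncoder(), tmpdir, combine=50, encoding='utf-8')
        sampler = Sampler(chunks, seed=0)
        print('chunks:', len(chunks), 'sampled tokens:', sampler.sample(8))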