-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloaders.py
101 lines (89 loc) · 3.11 KB
/
dataloaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import tiktoken
import torch
import os
from utils import log, IS_MASTER
import numpy as np
from tqdm import tqdm
def load_np(file: str):
    """Load a saved numpy token array and return it as a ``torch.long`` tensor."""
    return torch.tensor(np.load(file), dtype=torch.long)
class DataLoader:
    """Serves (x, y) token batches from pre-tokenized ``.npy`` shards on disk.

    In a distributed run each rank reads a disjoint interleaved slice of every
    shard: rank ``r`` starts at offset ``r * B * T`` and strides forward by
    ``B * T * world_size`` per batch.
    """

    def __init__(
        self,
        path: str,
        batch_size: int,
        block_size: int,
        split: str,
        rank: int = 0,
        world_size: int = 1,
        limit_files: int = -1,
    ) -> None:
        """Load every shard file under *path* whose name contains *split*.

        limit_files: if positive, cap the number of shard files loaded.
        """
        assert split in ("train", "validation", "val")
        # A shard belongs to this split if the split name appears in its filename.
        shard_files = [name for name in os.listdir(path) if split in name]
        log(f"found {len(shard_files)} file(s) for split {split}. Loading...")
        if limit_files > 0:
            log(
                f"will use a max of {limit_files} file(s) since it is explicitly asked."
            )
            shard_files = shard_files[:limit_files]
        # Only the master process shows a progress bar, avoiding duplicated output.
        iterable = tqdm(shard_files) if IS_MASTER else shard_files
        self.shards = [load_np(os.path.join(path, name)) for name in iterable]
        log("done loading")
        self.batch_size = batch_size
        self.block_size = block_size
        self.rank = rank
        self.world_size = world_size
        self.reset()

    def next_batch(self):
        """Return the next (inputs, targets) pair, each shaped (B, T)."""
        B, T = self.batch_size, self.block_size
        shard = self.shards[self.current_shard_ix]
        # Take B*T + 1 tokens so targets are the inputs shifted right by one.
        window = shard[self.current_pos : self.current_pos + B * T + 1]
        x = window[:-1].view(-1, T)
        y = window[1:].view(-1, T)
        # Advance past the slices consumed by every rank this step.
        self.current_pos += B * T * self.world_size
        # Move to the next shard once the remainder can't fill a full batch.
        if self.current_pos + (B * T * self.world_size + 1) > shard.size(0):
            self._reset_pos()
            self.current_shard_ix = (self.current_shard_ix + 1) % len(self.shards)
        return x, y

    def reset(self):
        """Rewind to the first shard at this rank's starting offset."""
        self._reset_pos()
        self.current_shard_ix = 0

    def _reset_pos(self):
        # Each rank starts at its own interleaved offset within a shard.
        self.current_pos = self.rank * self.batch_size * self.block_size
class DataLoaderTokenizer:
    """Serves (x, y) token batches from a raw text file, tokenizing it up front.

    The whole file is encoded once with a tiktoken encoding and kept in memory
    as a single tensor; batches then cycle over it, with each rank in a
    distributed run reading its own interleaved slice (rank ``r`` starts at
    ``r * B * T`` and strides by ``B * T * world_size``).
    """

    def __init__(
        self,
        file_name: str,
        batch_size: int,
        block_size: int,
        model_name: str = "gpt2",
        device: str = None,
        rank: int = 0,
        world_size: int = 1,
    ) -> None:
        """Read and tokenize *file_name*.

        model_name: tiktoken encoding name (default "gpt2").
        device: optional torch device for the token tensor; None keeps it on CPU.
        rank / world_size: distributed interleaving parameters.
        """
        # Read as UTF-8 explicitly: the platform-default encoding (e.g. cp1252
        # on Windows) is unreliable and would corrupt or reject non-ASCII text.
        with open(file_name, "r", encoding="utf-8") as f:
            text = f.read()
        tokenizer = tiktoken.get_encoding(model_name)
        self.data = torch.tensor(tokenizer.encode(text))
        if device:
            self.data = self.data.to(device)
        log(f"Loaded {self.data.size(0)} tokens.")
        self.n_tokens = self.data.size(0)
        self.batch_size = batch_size
        self.block_size = block_size
        self.rank = rank
        self.world_size = world_size
        self.reset()

    def next_batch(self):
        """Return the next (inputs, targets) pair, each shaped (B, T)."""
        B = self.batch_size
        T = self.block_size
        # Take B*T + 1 tokens so targets are the inputs shifted right by one.
        buf = self.data[self.current_pos : B * T + self.current_pos + 1]
        x = buf[:-1].view(-1, T)
        y = buf[1:].view(-1, T)
        self.current_pos += B * T * self.world_size
        # Wrap back to this rank's starting offset once the remaining tokens
        # can no longer fill a full batch for every rank.
        if self.current_pos + (B * T * self.world_size + 1) > self.data.size(0):
            self.reset()
        return x, y

    def reset(self):
        """Rewind to this rank's starting offset (mirrors DataLoader.reset)."""
        self.current_pos = self.rank * self.batch_size * self.block_size