-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathasyn_dataloader.py
97 lines (77 loc) · 2.62 KB
/
asyn_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
'''
Prefetch pre-training batches onto the GPU from a background thread to
accelerate data loading; in practice the observed speedup is modest.
'''
import torch
from queue import Queue
from threading import Thread
class CudaDataLoader:
    """Wrap a DataLoader so batches are copied to a CUDA device by a
    background thread, overlapping host-to-device transfers with compute.

    A daemon worker endlessly re-iterates the wrapped loader, moves every
    sample to ``device`` on a dedicated CUDA stream, and parks the results
    in a bounded queue; ``__next__`` pops one prefetched batch per call.
    """

    def __init__(self, loader, device, queue_size=2):
        """
        Args:
            loader: torch DataLoader (or any sized iterable of samples) to wrap.
            device: target CUDA device for the copies.
            queue_size: maximum number of prefetched batches held at once.
        """
        self.device = device
        self.queue_size = queue_size
        self.loader = loader
        # Dedicated stream so H2D copies can overlap with default-stream compute.
        self.load_stream = torch.cuda.Stream(device=device)
        self.queue = Queue(maxsize=self.queue_size)
        self.idx = 0
        self.worker = Thread(target=self.load_loop)
        # Assign the attribute directly: Thread.setDaemon() is deprecated
        # since Python 3.10.
        self.worker.daemon = True
        self.worker.start()

    def load_loop(self):
        """Background loop: re-iterate the loader forever and enqueue
        device-resident copies of each sample (blocks while the queue is full)."""
        torch.cuda.set_device(self.device)
        while True:
            for i, sample in enumerate(self.loader):
                self.queue.put(self.load_instance(sample))

    def load_instance(self, sample):
        """Recursively move the tensors inside ``sample`` to the target device.

        Tensors are copied on ``self.load_stream`` with non-blocking copies;
        ``None`` and strings pass through unchanged; dicts and sequences are
        handled element-wise.
        """
        if torch.is_tensor(sample):
            with torch.cuda.stream(self.load_stream):
                return sample.to(self.device, non_blocking=True)
        elif sample is None or isinstance(sample, str):
            # isinstance() instead of type(...) == str: also accepts subclasses.
            return sample
        elif isinstance(sample, dict):
            return {k: self.load_instance(v) for k, v in sample.items()}
        else:
            return [self.load_instance(s) for s in sample]

    def __iter__(self):
        self.idx = 0
        return self

    def __next__(self):
        """Return the next prefetched batch; raise StopIteration after one
        full pass over the wrapped loader."""
        if not self.worker.is_alive() and self.queue.empty():
            # Worker died (the wrapped loader raised) and nothing is left.
            self.idx = 0
            self.queue.join()
            self.worker.join()
            raise StopIteration
        elif self.idx >= len(self.loader):
            self.idx = 0
            raise StopIteration
        else:
            out = self.queue.get()
            self.queue.task_done()
            # Make the consumer's stream wait for the async copy issued on
            # load_stream before the batch is used downstream; without this
            # the non_blocking copy can race with kernels that read the batch.
            torch.cuda.current_stream().wait_stream(self.load_stream)
            self.idx += 1
            return out

    def next(self):
        """Py2-style alias for ``__next__``."""
        return self.__next__()

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        # Delegate to the wrapped loader so callers see the usual interface.
        return self.loader.sampler

    @property
    def dataset(self):
        # Delegate to the wrapped loader so callers see the usual interface.
        return self.loader.dataset
class _RepeatSampler(object):
def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True:
yield from iter(self.sampler)
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that keeps its worker processes alive across epochs by
    drawing every epoch's batches from one persistent iterator over a
    never-ending batch sampler."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader forbids rebinding batch_sampler after __init__, so the
        # repeating wrapper is swapped in via object.__setattr__.
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        # A single iterator reused for every epoch (workers stay warm).
        self.iterator = super().__iter__()

    def __len__(self):
        # One epoch = the number of batches the wrapped batch sampler yields.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        """Yield exactly one epoch's worth of batches from the shared iterator."""
        remaining = len(self)
        while remaining > 0:
            remaining -= 1
            yield next(self.iterator)