-
Notifications
You must be signed in to change notification settings - Fork 0
/
varying_size_dataloader.py
66 lines (53 loc) · 2.16 KB
/
varying_size_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import random
import torch
from baumbauen.utils import pad_collate_fn
from torch.utils.data.sampler import BatchSampler
from baumbauen.utils import SimilarSizeSampler
from baumbauen.utils import BucketingSampler
class TestDataset(Dataset):
    """Toy in-memory dataset of random tensors for trying out batch samplers.

    Every item is a ``(data, labels)`` pair of tensors filled by
    ``torch.rand``. All items are generated eagerly in ``__init__``.
    """

    def __init__(self, random_lengths=False, num_items=100):
        """Generate ``num_items`` random ``(data, labels)`` pairs.

        Args:
            random_lengths: If true, each item's middle dimension is drawn
                from ``random.randrange(2, 60)`` and data has shape
                ``(50, length, 6)``. Otherwise every item has data of shape
                ``(50, 1, 10)``.
            num_items: Number of items to generate; also the value returned
                by ``len()``.
        """
        self.num_items = num_items
        self.data = []
        self.labels = []

        # Both modes must yield exactly ``num_items`` data/label pairs so that
        # ``__len__`` and ``__getitem__`` agree. Previously the fixed-length
        # mode stored a single data tensor and no labels, so any indexing
        # raised IndexError.
        for _ in range(num_items):
            if random_lengths:
                seq_len = random.randrange(2, 60)
                num_features = 6
            else:
                seq_len = 1
                num_features = 10
            self.data.append(torch.rand((50, seq_len, num_features)))
            # Labels share the first two dimensions of their data tensor.
            self.labels.append(torch.rand(50, seq_len, 1))

        print("initialized")

    def __len__(self):
        """Return the number of items in the dataset."""
        return self.num_items

    def __getitem__(self, index):
        """Return the ``(data, labels)`` tensor pair stored at ``index``."""
        return self.data[index], self.labels[index]
def main():
    """Build the example samplers and loaders, then print sampler batches.

    Creates a variable-length ``TestDataset``, wraps it in the two
    ``baumbauen`` samplers, builds a ``DataLoader`` around each, and finally
    prints whatever the ``BatchSampler`` wrappers yield.
    """
    dataset = TestDataset(random_lengths=True)

    sim_size_sampler = SimilarSizeSampler(dataset, replacement=False, batch_size=3)
    bucketing_sampler = BucketingSampler(dataset, batch_size=3)

    # Prep dataloaders
    # ``batch_size=1`` and ``drop_last=False`` are the DataLoader defaults;
    # they have to stay at their defaults when ``batch_sampler`` is supplied.
    # NOTE(review): the loaders are constructed but never iterated here.
    sim_dataloader = DataLoader(dataset,
                                batch_size=1,
                                batch_sampler=sim_size_sampler,
                                num_workers=0,
                                collate_fn=pad_collate_fn,
                                drop_last=False,
                                pin_memory=False)
    buck_dataloader = DataLoader(dataset,
                                 batch_size=1,
                                 batch_sampler=bucketing_sampler,
                                 num_workers=0,
                                 collate_fn=pad_collate_fn,
                                 drop_last=False,
                                 pin_memory=False)

    # Prep Batch samplers
    # NOTE(review): both samplers are also passed above as ``batch_sampler``,
    # so they presumably already yield batches of indices; wrapping them in
    # BatchSampler would then group batches of batches -- confirm intent.
    batch_sampler_sim = BatchSampler(sim_size_sampler, batch_size=3, drop_last=False)
    batch_sampler_bucket = BatchSampler(bucketing_sampler, batch_size=3, drop_last=False)

    # Iterate the sampler *instances*. The previous loop iterated the
    # ``BatchSampler`` class itself, which raises TypeError.
    for batch_sampler in (batch_sampler_sim, batch_sampler_bucket):
        for batch in batch_sampler:
            print(batch)
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()