-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
64 lines (44 loc) · 1.75 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 26 08:45:25 2020
@author: mooselumph
"""
import logging
from glob import escape as glob_escape
from glob import glob
from os import listdir
from os.path import basename
from os.path import join
from os.path import splitext

import numpy as np
import torch
from torch.utils.data import Dataset
class BasicDataset(Dataset):
    """Dataset of "model" and/or "gather" arrays loaded with ``np.load``.

    Sample IDs are the extension-less names of the non-hidden entries of
    ``model_dir`` (or of ``gather_dir`` when no model directory is given).
    For every ID the matching file is looked up in each configured directory.

    Items are returned as:

    * a ``float32`` tensor of the model, with a leading axis added and values
      rescaled to ``[-1, 1]``, when only ``model_dir`` is set;
    * a ``float32`` tensor of the gather array, unscaled, when only
      ``gather_dir`` is set;
    * ``{'model': ..., 'gather': ...}`` when both directories are set.
    """

    def __init__(self, model_dir=None, gather_dir=None):
        """Collect the sample IDs.

        Args:
            model_dir: Directory containing the model arrays, or ``None``.
            gather_dir: Directory containing the gather arrays, or ``None``.

        Raises:
            ValueError: If neither directory is given.
        """
        if not model_dir and not gather_dir:
            # listdir(None) would silently enumerate the current working
            # directory, producing a dataset of unrelated files.
            raise ValueError(
                'At least one of model_dir or gather_dir must be provided')

        self.model_dir = model_dir
        self.gather_dir = gather_dir

        # IDs are taken from the model directory when it is available.
        source_dir = model_dir if model_dir else gather_dir
        # listdir order is arbitrary; sorting keeps index -> sample stable
        # across runs and machines.
        self.ids = sorted(splitext(name)[0] for name in listdir(source_dir)
                          if not name.startswith('.'))
        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        """Return the number of sample IDs."""
        return len(self.ids)

    @staticmethod
    def _matching_files(directory, idx):
        """Return the paths in *directory* that belong to sample *idx*."""
        # join() makes a trailing separator on the directory optional, and
        # escaping keeps characters such as '[' from acting as wildcards.
        pattern = join(glob_escape(directory), glob_escape(idx) + '*')
        candidates = glob(pattern)
        # Prefer files whose stem is exactly the ID so that ID "1" is not
        # confused with "10"; otherwise keep the plain prefix matches.
        exact = [path for path in candidates
                 if splitext(basename(path))[0] == idx]
        return exact or candidates

    @staticmethod
    def _rescale(array):
        """Linearly map *array* so its minimum is -1 and its maximum is 1."""
        if not np.issubdtype(array.dtype, np.floating):
            # Integer (or bool) data cannot hold the fractional result.
            array = array.astype(np.float64)
        array = array - np.min(array)
        peak = np.max(array)
        if peak > 0:
            # A constant array has peak == 0; skipping the division avoids
            # 0/0 = NaN and leaves every value at the minimum (-1 below).
            array = array / peak
        return array * 2 - 1

    def __getitem__(self, i):
        """Load sample *i*.

        Raises:
            AssertionError: If a configured directory holds no file, or more
                than one file, for the sample's ID.
        """
        idx = self.ids[i]
        model = None
        gather = None

        if self.model_dir:
            model_file = self._matching_files(self.model_dir, idx)
            if len(model_file) != 1:
                # Raised explicitly (not via ``assert``) so the check also
                # runs under ``python -O``; the exception type is unchanged.
                raise AssertionError(
                    f'Either no model or multiple models found for the ID {idx}: {model_file}')
            model = np.load(model_file[0])[np.newaxis, :, :]
            model = torch.tensor(self._rescale(model), dtype=torch.float32)

        if self.gather_dir:
            gather_file = self._matching_files(self.gather_dir, idx)
            if len(gather_file) != 1:
                raise AssertionError(
                    f'Either no gather file or multiple files found for the ID {idx}: {gather_file}')
            gather = torch.tensor(np.load(gather_file[0]), dtype=torch.float32)

        if self.model_dir and self.gather_dir:
            return {'model': model, 'gather': gather}
        return model if self.model_dir else gather