-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataset.py
129 lines (98 loc) · 4.75 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import numpy as np
import torch
# import pandas as pd
from torch.utils.data import Dataset
from util import *
import os.path
from os import path
class FeaturesCls(Dataset):
    """Map-style dataset of ``(feature, label)`` pairs for a classifier.

    Features/labels are either supplied directly, or -- when ``split`` contains
    ``'test'`` -- loaded from ``{opt.dataroot}/{opt.testsplit}_feats.npy`` and
    ``{opt.dataroot}/{opt.testsplit}_labels.npy``, replacing anything passed in.

    Args:
        opt: options object; ``dataroot`` is read here, ``testsplit`` when loading.
        features: indexable collection of features, aligned with ``labels``.
        labels: indexable collection of raw class labels.
        val: accepted but not used.
        split: any split name containing ``'test'`` triggers loading from disk.
        classes_to_train: optional ordered class ids; when given, the raw label
            returned by ``__getitem__`` is remapped to its position in this
            sequence.
    """

    def __init__(self, opt, features=None, labels=None, val=False, split='seen', classes_to_train=None):
        self.opt = opt
        self.root = f"{opt.dataroot}"
        self.features, self.labels = features, labels
        self.classes_to_train = classes_to_train
        if classes_to_train is None:
            self.classid_tolabels = None
        else:
            # raw class id -> contiguous index 0..len(classes_to_train)-1
            self.classid_tolabels = {cid: pos for pos, cid in enumerate(classes_to_train)}
            print(f"class ids for unseen classifier {self.classes_to_train}")
        if 'test' in split:
            self.loadRealFeats(syn_feature=features, syn_label=labels, split=split)

    def loadRealFeats(self, syn_feature=None, syn_label=None, split='train'):
        """Load real test features/labels from disk when ``split`` contains ``'test'``.

        ``syn_feature`` and ``syn_label`` are accepted but not used; for any
        split without ``'test'`` in it this method does nothing.
        """
        if 'test' not in split:
            return
        prefix = f"{self.root}/{self.opt.testsplit}"
        self.features = np.load(f"{prefix}_feats.npy")
        self.labels = np.load(f"{prefix}_labels.npy")
        print(f"{len(self.labels)} testsubset {self.opt.testsplit} features loaded")

    def replace(self, features=None, labels=None):
        """Swap in a new batch of features/labels and record its size in ``ntrain``."""
        self.features, self.labels = features, labels
        self.ntrain = len(labels)
        print("\n=== Replaced new batch of Syn Feats === \n")

    def __getitem__(self, idx):
        """Return ``(feature, label)`` at ``idx``; the label is remapped if a class map is set."""
        feature = self.features[idx]
        label = self.labels[idx]
        mapping = self.classid_tolabels
        return feature, (label if mapping is None else mapping[label])

    def __len__(self):
        # Raises TypeError if no labels were supplied or loaded (labels is None).
        return len(self.labels)
class FeaturesGAN():
    """Pool of training features/labels from which GAN training data is sampled.

    Loads ``{opt.dataroot}/{opt.trainsplit}_feats.npy`` and
    ``{opt.dataroot}/{opt.trainsplit}_labels.npy``. Rows with label 0 are
    treated as background (bg); rows with label > 0 as foreground (fg).

    Args:
        opt: options object; ``dataroot`` and ``trainsplit`` are read here,
            ``gan_epoch_budget`` in :meth:`epochData`.

    Raises:
        ValueError: if the labels file contains no labels.
    """

    def __init__(self, opt):
        self.root = f"{opt.dataroot}"
        self.opt = opt
        print("loading numpy arrays")
        self.all_features = np.load(f"{self.root}/{self.opt.trainsplit}_feats.npy")
        self.all_labels = np.load(f"{self.root}/{self.opt.trainsplit}_labels.npy")
        print(f'loaded data from {self.opt.trainsplit}')
        if len(self.all_labels) == 0:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError(
                f"no labels found in {self.root}/{self.opt.trainsplit}_labels.npy"
            )
        # Row indices of foreground (label > 0) and background (label == 0) samples.
        self.pos_inds = np.where(self.all_labels > 0)[0]
        self.neg_inds = np.where(self.all_labels == 0)[0]
        unique_labels = np.unique(self.all_labels)
        # bg samples per epoch = fg count / number of distinct labels
        # (the bg label 0 is counted among them when present).
        self.num_bg_to_take = len(self.pos_inds) // len(unique_labels)
        print(f"loaded {len(self.pos_inds)} fg labels")
        print(f"loaded {len(self.neg_inds)} bg labels ")
        print(f"bg indexes for each epoch {self.num_bg_to_take}")
        # NOTE: per-class feature means are not computed (nor read from a cached
        # {trainsplit}_mean.npy); this is an all-zero placeholder with one row
        # per label id 0..max(label).
        self.features_mean = np.zeros((max(unique_labels) + 1, self.all_features.shape[1]))

    def epochData(self, include_bg=False):
        """Sample one epoch's worth of training rows.

        Args:
            include_bg: if True, mix ``num_bg_to_take`` random bg rows in with
                all fg rows before shuffling and truncating.

        Returns:
            ``(features, labels)`` for at most ``int(opt.gan_epoch_budget)``
            randomly ordered rows.
        """
        budget = int(self.opt.gan_epoch_budget)
        fg_inds = np.random.permutation(self.pos_inds)
        # The second shuffle is redundant, but it is kept so that the sequence of
        # RNG draws, and hence seeded runs, stays unchanged.
        inds = np.random.permutation(fg_inds)[:budget]
        if include_bg:
            bg_inds = np.random.permutation(self.neg_inds)[:self.num_bg_to_take]
            inds = np.random.permutation(np.concatenate((fg_inds, bg_inds)))[:budget]
        return self.all_features[inds], self.all_labels[inds]

    def getBGfeats(self, num=1000):
        """Return up to ``num`` randomly chosen background ``(features, labels)``."""
        bg_inds = np.random.permutation(self.neg_inds)[:num]
        print(f"{len(bg_inds)} ")
        return self.all_features[bg_inds], self.all_labels[bg_inds]

    def __len__(self):
        return len(self.all_labels)

    def getseenfeats(self, num=1000, seen_classes=None):
        """Sample up to ``num`` random rows from each seen class.

        Args:
            num: maximum number of rows taken per class.
            seen_classes: non-empty iterable of class labels to sample from.
                Defaults to ``range(1, 17)`` (labels 1..16), the set previously
                hard-coded here -- presumably the dataset's seen classes.

        Returns:
            ``(features, labels)`` concatenated in the order of ``seen_classes``.
        """
        if seen_classes is None:
            seen_classes = range(1, 17)
        feats_per_class = []
        labels_per_class = []
        for cls in seen_classes:
            cls_inds = np.random.permutation(np.where(self.all_labels == cls)[0])[:num]
            feats_per_class.append(self.all_features[cls_inds])
            labels_per_class.append(self.all_labels[cls_inds])
        return np.concatenate(feats_per_class, 0), np.concatenate(labels_per_class, 0)

    def getBGfeats_memory(self, num=1000):
        """Return every background ``(features, labels)`` row, shuffled (bg memory bank).

        NOTE: ``num`` is ignored and all bg rows are always returned. It is
        presumably kept for call-compatibility with :meth:`getBGfeats`.
        """
        bg_inds = np.random.permutation(self.neg_inds)
        print(f"{len(bg_inds)} ")
        return self.all_features[bg_inds], self.all_labels[bg_inds]