-
Notifications
You must be signed in to change notification settings - Fork 4
/
get_train_test_quick_draw.py
30 lines (26 loc) · 1.01 KB
/
get_train_test_quick_draw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import pickle
import os
from tqdm import tqdm
import random
# Build a random train / valid / test split over the per-sample .pkl files of
# the processed Quick, Draw! dataset and save the three path lists to a pickle.

ROOT = '/scratche/home/<user>/processed_quick_draw/'
OUT_PATH = 'processed_quick_draw_paths.pkl'


def collect_pkl_paths(root):
    """Return the paths of every ``.pkl`` file below the subdirectories of *root*.

    Only files inside subdirectories of *root* (at any depth) are collected;
    ``.pkl`` files sitting directly in *root* are ignored.

    Each path is joined onto the directory ``os.walk`` actually reports, so
    files in nested folders get their real location. Top-level entries,
    subdirectories and file names are visited in sorted order. The result
    therefore does not depend on filesystem enumeration order, which a seeded
    split needs in order to be reproducible.

    Args:
        root: directory whose subdirectories hold the ``.pkl`` files.

    Returns:
        list of path strings.

    Raises:
        FileNotFoundError / NotADirectoryError: if *root* cannot be listed
        (rather than silently producing an empty split).
    """
    paths = []
    for entry in sorted(os.listdir(root)):
        top = os.path.join(root, entry)
        if not os.path.isdir(top):
            continue  # loose files directly in root are not samples
        for dirpath, dirnames, filenames in os.walk(top):
            dirnames.sort()  # in-place, so os.walk descends in sorted order
            paths.extend(
                os.path.join(dirpath, name)
                for name in sorted(filenames)
                if name.endswith('.pkl')  # not `'.pkl' in name`: skips e.g. x.pkl.bak
            )
    return paths


def split_paths(paths, rng=random, train_frac=0.8, valid_split=0.5):
    """Randomly assign each path to train, valid or test.

    For every path one draw ``rng.random()`` is taken. Values ``<= train_frac``
    go to train. Otherwise a second draw sends the path to valid if it is
    ``> valid_split`` and to test if not. With the defaults this is roughly
    80 / 10 / 10.

    Args:
        paths: iterable of path strings.
        rng: object with a ``random()`` method returning floats in [0, 1).
            Defaults to the ``random`` module; pass ``random.Random(seed)``
            for a reproducible split.
        train_frac: threshold on the first draw; draws above it are held out.
        valid_split: threshold on the second draw; held-out draws above it go
            to valid, the rest to test.

    Returns:
        tuple ``(train_x, valid_x, test_x)`` of lists, each in input order.
    """
    train_x, valid_x, test_x = [], [], []
    for path in paths:
        if rng.random() > train_frac:
            # Second draw is taken only for held-out items.
            (valid_x if rng.random() > valid_split else test_x).append(path)
        else:
            train_x.append(path)
    return train_x, valid_x, test_x


def main(root=ROOT, out_path=OUT_PATH, seed=None):
    """Collect, split and pickle the Quick, Draw! sample paths.

    Prints the sizes as ``train test valid`` and writes
    ``{'train_x': ..., 'test_x': ..., 'valid_x': ...}`` to *out_path*.

    Args:
        root: dataset root passed to ``collect_pkl_paths``.
        out_path: destination of the pickled dict.
        seed: optional RNG seed. ``None`` seeds from OS entropy, so the split
            is not reproducible.

    Returns:
        the dict that was pickled.
    """
    rng = random.Random(seed)
    paths = collect_pkl_paths(root)
    train_x, valid_x, test_x = split_paths(tqdm(paths), rng=rng)
    print(len(train_x), len(test_x), len(valid_x))
    data = {'train_x': train_x, 'test_x': test_x, 'valid_x': valid_x}
    with open(out_path, 'wb') as fh:  # context manager closes the handle
        pickle.dump(data, fh)
    return data


if __name__ == "__main__":
    main()