# get_vtab1k.py
from task_adaptation.data import caltech
from task_adaptation.data import cifar
from task_adaptation.data import clevr
from task_adaptation.data import diabetic_retinopathy
from task_adaptation.data import dmlab
from task_adaptation.data import dsprites
from task_adaptation.data import dtd
from task_adaptation.data import eurosat
from task_adaptation.data import kitti
from task_adaptation.data import oxford_flowers102
from task_adaptation.data import oxford_iiit_pet
from task_adaptation.data import patch_camelyon
from task_adaptation.data import resisc45
from task_adaptation.data import smallnorb
from task_adaptation.data import sun397
from task_adaptation.data import svhn
from task_adaptation.data.base import compose_preprocess_fn
from task_adaptation.registry import Registry
import os
import os.path as osp
from PIL import Image
import numpy as np
import tensorflow as tf
import logging
import resource
"""
Copyright (c) 2022 Yuanhan Zhang
"""
# Show DEBUG-and-above records on the root logger for the whole script.
logging.basicConfig(level=logging.DEBUG)

# Export table: one ``[registry_name, constructor_kwargs]`` pair per task.
# ``registry_name`` is looked up as ``data.<registry_name>`` and the kwargs are
# passed to the class found there. The optional ``dataset_postfix`` key is not
# a constructor argument: the export loop removes it and appends it to the
# output directory name, so two tasks built from the same dataset stay apart.
dataset_config = [
    ['caltech101', {}],
    ['cifar', {'num_classes': 100}],
    ['dtd', {}],
    ['oxford_flowers102', {}],
    ['oxford_iiit_pet', {}],
    ['patch_camelyon', {}],
    ['sun397', {}],
    ['svhn', {}],
    ['resisc45', {}],
    ['eurosat', {}],
    ['dmlab', {}],
    ['kitti', {'task': 'closest_vehicle_distance'}],
    ['smallnorb', {
        'predicted_attribute': 'label_azimuth',
        'dataset_postfix': 'azi',
    }],
    ['smallnorb', {
        'predicted_attribute': 'label_elevation',
        'dataset_postfix': 'ele',
    }],
    ['dsprites', {
        'predicted_attribute': 'label_x_position',
        'num_classes': 16,
        'dataset_postfix': 'loc',
    }],
    ['dsprites', {
        'predicted_attribute': 'label_orientation',
        'num_classes': 16,
        'dataset_postfix': 'ori',
    }],
    ['clevr', {
        'task': 'closest_object_distance',
        'dataset_postfix': 'dist',
    }],
    ['clevr', {'task': 'count_all', 'dataset_postfix': 'count'}],
    ['diabetic_retinopathy', {'config': 'btgraham-300'}],
]
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script does when run from the command line; passing a list lets
            callers and tests drive the parser without touching ``sys.argv``.

    Returns:
        argparse.Namespace: has one attribute, ``seed`` (int).

    Raises:
        SystemExit: raised by argparse when ``--seed`` is missing or is not
            an integer.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--seed',
        type=int,
        required=True,
        help='integer random seed (required)',
    )
    return parser.parse_args(argv)
def _export_split(data_cls, preprocess_fn, dataset_dir, split_name):
    """Write one split of a dataset as image files plus a text index.

    For every element of the split, the image is saved to
    ``<dataset_dir>/images/<split_name>/<index:06d>.jpg`` and one line
    ``images/<split_name>/<index:06d>.jpg <label>`` is appended to
    ``<dataset_dir>/<split_name>.txt``.

    Args:
        data_cls: Dataset object providing ``_get_dataset_split`` and
            ``_num_preprocessing_threads``.
        preprocess_fn: Function mapped over the split before export.
        dataset_dir: Output directory of the dataset.
        split_name: Name of the split to export (e.g. ``'train800'``).
    """
    # shuffle_files=False is passed so files are read unshuffled.
    data = data_cls._get_dataset_split(split_name=split_name, shuffle_files=False)
    data = data.map(preprocess_fn, data_cls._num_preprocessing_threads)

    os.makedirs(f'{dataset_dir}/images/{split_name}', exist_ok=True)
    with open(f'{dataset_dir}/{split_name}.txt', 'w') as f:
        for i, item in enumerate(data):
            # The path stored in the index is relative to the dataset directory.
            image_path = f'images/{split_name}/{i:06d}.jpg'
            label = item['label'].numpy().item()
            f.write(f'{image_path} {label}\n')
            # NOTE(review): Image.fromarray presumably receives an HxWx3 uint8
            # array from the base preprocessing — confirm for grayscale data.
            image = Image.fromarray(item['image'].numpy())
            image.save(f'{dataset_dir}/{image_path}')


if __name__ == '__main__':
    args = parse_args()
    seed = args.seed
    logging.info(f'Seed: {seed}')

    # Set random seed
    np.random.seed(seed)
    tf.random.set_seed(seed)

    # Raise the soft open-file limit to the hard limit.
    # NOTE(review): presumably needed because the dataset readers keep many
    # files open at once — confirm.
    _, high = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))

    data_root = osp.expanduser('./vtab-1k')
    split_names = ['train800', 'val200', 'test', 'train800val200']

    for dataset_name, dataset_params in dataset_config:
        # Pop from a copy so the module-level dataset_config keeps its
        # 'dataset_postfix' entries and is not mutated by running the export.
        params = dict(dataset_params)
        dataset_postfix = params.pop('dataset_postfix', None)
        data_cls = Registry.lookup(f'data.{dataset_name}')(**params)
        if dataset_postfix is not None:
            dataset_name = dataset_name + '_' + dataset_postfix

        dataset_dir = f'{data_root}/{dataset_name}'
        os.makedirs(dataset_dir, exist_ok=True)
        os.makedirs(f'{dataset_dir}/images', exist_ok=True)
        logging.info(f'{dataset_name} started.')

        # Depends only on the dataset object, so build it once per dataset
        # instead of once per split.
        base_preprocess_fn = compose_preprocess_fn(
            data_cls._image_decoder, data_cls._base_preprocess_fn)
        for split_name in split_names:
            _export_split(data_cls, base_preprocess_fn, dataset_dir, split_name)

        logging.info(f'{dataset_name} is done.')