-
Notifications
You must be signed in to change notification settings - Fork 1
/
ood_utils.py
164 lines (143 loc) · 7.36 KB
/
ood_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Helper functions or OOD experiments.
"""
import torch
from torchvision.datasets import CelebA, CIFAR10, SVHN, DTD, CIFAR100
from torch.utils.data import Subset, DataLoader
from torchvision import transforms
from mpi4py import MPI
import argparse
def get_interpolation_mode(mode):
    """Map a mode name to the corresponding torchvision InterpolationMode.

    Args:
        mode: One of 'bilinear', 'nearest', 'nearest_exact', 'bicubic',
            'box', 'hamming', 'lanczos'.

    Returns:
        The matching ``transforms.InterpolationMode`` member.

    Raises:
        ValueError: If ``mode`` is not a recognized interpolation mode.
            (The original printed a message and called ``exit()``, which
            kills the whole process from library code; raising lets the
            caller handle or report the error.)
    """
    # Table lookup replaces the long if/elif chain.
    interpolation_modes = {
        'bilinear': transforms.InterpolationMode.BILINEAR,
        'nearest': transforms.InterpolationMode.NEAREST,
        'nearest_exact': transforms.InterpolationMode.NEAREST_EXACT,
        'bicubic': transforms.InterpolationMode.BICUBIC,
        'box': transforms.InterpolationMode.BOX,
        'hamming': transforms.InterpolationMode.HAMMING,
        'lanczos': transforms.InterpolationMode.LANCZOS,
    }
    try:
        return interpolation_modes[mode]
    except KeyError:
        raise ValueError(f'not a valid interpolation mode: {mode!r}') from None
def build_subset_per_process(dataset):
    """Partition ``dataset`` so each MPI process (GPU) sees a unique slice.

    The full index range is split into ``world_size`` contiguous chunks
    (via ``torch.chunk``) and this rank keeps only its own chunk.

    Args:
        dataset: Any dataset supporting ``len()`` and integer indexing.

    Returns:
        A ``torch.utils.data.Subset`` covering this rank's index chunk.
    """
    comm = MPI.COMM_WORLD
    world_size = comm.Get_size()
    rank = comm.Get_rank()
    all_indices = torch.arange(0, len(dataset), dtype=int)
    # torch.chunk yields world_size near-equal contiguous pieces.
    rank_indices = torch.chunk(all_indices, chunks=world_size)[rank]
    return Subset(dataset, rank_indices)
def yield_(loader):
    """Yield items from ``loader`` forever, restarting it when exhausted."""
    while True:
        for item in loader:
            yield item
def load_celeba(data_dir, batch_size, image_size, train=False, interpolation_mode='bilinear', shuffle=True):
    """Build a per-rank CelebA DataLoader.

    Images are center-cropped to 140px, resized to ``image_size``,
    converted to tensors, and normalized with mean/std 0.5 per channel
    (mapping pixel values into [-1, 1]). The dataset is downloaded if
    missing and partitioned across MPI ranks.

    Args:
        data_dir: Root directory for the CelebA data.
        batch_size: Batch size for the loader.
        image_size: Target square image size after resizing.
        train: If True use the 'train' split, else 'test'.
        interpolation_mode: Resize interpolation name (see get_interpolation_mode).
        shuffle: Whether the loader shuffles each epoch.

    Returns:
        A DataLoader over this rank's subset of CelebA.
    """
    interp = get_interpolation_mode(interpolation_mode)
    preprocess = transforms.Compose([
        transforms.CenterCrop(140),
        transforms.Resize((image_size, image_size), interpolation=interp),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    split_name = 'train' if train else 'test'
    celeba = CelebA(data_dir, download=True, transform=preprocess, split=split_name)
    rank_subset = build_subset_per_process(celeba)
    return DataLoader(rank_subset, batch_size=batch_size, shuffle=shuffle, num_workers=1, drop_last=False)
def load_celeba_resized(data_dir, batch_size, image_size, train=False, interpolation_mode='bilinear', shuffle=True):
    """Build a per-rank CelebA DataLoader with an extra down/up-resize step.

    Like :func:`load_celeba`, but after the 140px center crop the image is
    first resized to 32x32 and then up to ``image_size`` — the intermediate
    downscale deliberately discards high-frequency detail (presumably to
    match CIFAR-resolution statistics for OOD comparison — confirm with the
    experiment setup). Normalization maps pixels into [-1, 1].

    Args mirror :func:`load_celeba`.

    Returns:
        A DataLoader over this rank's subset of CelebA.
    """
    interp = get_interpolation_mode(interpolation_mode)
    preprocess = transforms.Compose([
        transforms.CenterCrop(140),
        transforms.Resize((32, 32), interpolation=interp),
        transforms.Resize((image_size, image_size), interpolation=interp),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    split_name = 'train' if train else 'test'
    celeba = CelebA(data_dir, download=True, transform=preprocess, split=split_name)
    rank_subset = build_subset_per_process(celeba)
    return DataLoader(rank_subset, batch_size=batch_size, shuffle=shuffle, num_workers=1, drop_last=False)
def load_cifar10(data_dir, batch_size, image_size, train=False, interpolation_mode='bilinear', shuffle=True):
    """Build a per-rank CIFAR-10 DataLoader.

    Images are resized to ``image_size`` and normalized with mean/std 0.5
    per channel (into [-1, 1]). The dataset is downloaded if missing and
    partitioned across MPI ranks.

    Args:
        data_dir: Root directory for the CIFAR-10 data.
        batch_size: Batch size for the loader.
        image_size: Target square image size after resizing.
        train: If True use the training set, else the test set.
        interpolation_mode: Resize interpolation name (see get_interpolation_mode).
        shuffle: Whether the loader shuffles each epoch.

    Returns:
        A DataLoader over this rank's subset of CIFAR-10.
    """
    interp = get_interpolation_mode(interpolation_mode)
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=interp),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    cifar = CIFAR10(data_dir, download=True, transform=preprocess, train=train)
    rank_subset = build_subset_per_process(cifar)
    return DataLoader(rank_subset, batch_size=batch_size, shuffle=shuffle, num_workers=1, drop_last=False)
def load_svhn(data_dir, batch_size, image_size, train=False, interpolation_mode='bilinear', shuffle=True):
    """Build a per-rank SVHN DataLoader.

    Images are resized to ``image_size`` and normalized with mean/std 0.5
    per channel (into [-1, 1]). The dataset is downloaded if missing and
    partitioned across MPI ranks.

    Args mirror :func:`load_cifar10`; ``train`` selects the 'train' or
    'test' split.

    Returns:
        A DataLoader over this rank's subset of SVHN.
    """
    interp = get_interpolation_mode(interpolation_mode)
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=interp),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    split_name = 'train' if train else 'test'
    svhn = SVHN(data_dir, download=True, transform=preprocess, split=split_name)
    rank_subset = build_subset_per_process(svhn)
    return DataLoader(rank_subset, batch_size=batch_size, shuffle=shuffle, num_workers=1, drop_last=False)
def load_textures(data_dir, batch_size, image_size, train=False, interpolation_mode='bilinear', shuffle=True):
    """Build a per-rank DTD (Describable Textures) DataLoader.

    Images are resized to ``image_size`` and normalized with mean/std 0.5
    per channel (into [-1, 1]). The dataset is downloaded if missing and
    partitioned across MPI ranks.

    Args mirror :func:`load_cifar10`; ``train`` selects the 'train' or
    'test' split.

    Returns:
        A DataLoader over this rank's subset of DTD.
    """
    interp = get_interpolation_mode(interpolation_mode)
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=interp),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    split_name = 'train' if train else 'test'
    dtd = DTD(data_dir, download=True, transform=preprocess, split=split_name)
    rank_subset = build_subset_per_process(dtd)
    return DataLoader(rank_subset, batch_size=batch_size, shuffle=shuffle, num_workers=1, drop_last=False)
def load_textures_resized(data_dir, batch_size, image_size, train=False, interpolation_mode='bilinear', shuffle=True):
    """Build a per-rank DTD DataLoader with an extra down/up-resize step.

    Like :func:`load_textures`, but each image is first resized to 32x32
    and then up to ``image_size`` — the intermediate downscale deliberately
    discards high-frequency detail (presumably to match CIFAR-resolution
    statistics for OOD comparison — confirm with the experiment setup).
    Normalization maps pixels into [-1, 1].

    Args mirror :func:`load_textures`.

    Returns:
        A DataLoader over this rank's subset of DTD.
    """
    interp = get_interpolation_mode(interpolation_mode)
    preprocess = transforms.Compose([
        transforms.Resize((32, 32), interpolation=interp),
        transforms.Resize((image_size, image_size), interpolation=interp),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    split_name = 'train' if train else 'test'
    dtd = DTD(data_dir, download=True, transform=preprocess, split=split_name)
    rank_subset = build_subset_per_process(dtd)
    return DataLoader(rank_subset, batch_size=batch_size, shuffle=shuffle, num_workers=1, drop_last=False)
def load_cifar100(data_dir, batch_size, image_size, train=False, interpolation_mode='bilinear', shuffle=True):
    """Build a per-rank CIFAR-100 DataLoader.

    Images are resized to ``image_size`` and normalized with mean/std 0.5
    per channel (into [-1, 1]). The dataset is downloaded if missing and
    partitioned across MPI ranks.

    Args mirror :func:`load_cifar10`; ``train`` selects the training or
    test set.

    Returns:
        A DataLoader over this rank's subset of CIFAR-100.
    """
    interp = get_interpolation_mode(interpolation_mode)
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=interp),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    cifar = CIFAR100(data_dir, download=True, transform=preprocess, train=train)
    rank_subset = build_subset_per_process(cifar)
    return DataLoader(rank_subset, batch_size=batch_size, shuffle=shuffle, num_workers=1, drop_last=False)
def load_data(dataset, data_dir, batch_size, image_size, train, interpolation_mode='bilinear', shuffle=True):
    """Dispatch to the dataset-specific loader by name.

    Args:
        dataset: One of 'cifar10', 'celeba', 'celeba_resized', 'svhn',
            'textures', 'textures_resized', 'cifar100'.
        data_dir: Root directory for the dataset files.
        batch_size: Batch size for the loader.
        image_size: Target square image size after resizing.
        train: Whether to load the train split.
        interpolation_mode: Resize interpolation name (see get_interpolation_mode).
        shuffle: Whether the loader shuffles each epoch.

    Returns:
        The DataLoader produced by the matching ``load_*`` helper.

    Raises:
        ValueError: If ``dataset`` is not a recognized name.
            (The original printed "Wrong ID dataset!" and called ``exit()``,
            which kills the process from library code; raising lets the
            caller handle or report the error.)
    """
    # Dispatch table replaces the if/elif ladder; all loaders share a signature.
    loaders = {
        "cifar10": load_cifar10,
        "celeba": load_celeba,
        "celeba_resized": load_celeba_resized,
        "svhn": load_svhn,
        "textures": load_textures,
        "textures_resized": load_textures_resized,
        "cifar100": load_cifar100,
    }
    if dataset not in loaders:
        raise ValueError(f"Wrong ID dataset! Unknown dataset: {dataset!r}")
    return loaders[dataset](data_dir, batch_size, image_size, train, interpolation_mode, shuffle)
def dict2namespace(config):
    """Recursively convert a (possibly nested) dict into an argparse.Namespace.

    Nested dict values become nested Namespaces; all other values are
    attached unchanged.

    Args:
        config: Dictionary of configuration keys to values.

    Returns:
        An ``argparse.Namespace`` mirroring ``config``'s structure.
    """
    namespace = argparse.Namespace()
    for key, value in config.items():
        # Recurse only into plain dicts; everything else is stored as-is.
        setattr(namespace, key, dict2namespace(value) if isinstance(value, dict) else value)
    return namespace