-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset_cifar10c.py
52 lines (40 loc) · 1.48 KB
/
dataset_cifar10c.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import random

import numpy as np
import PIL
import torch
import torchvision
from PIL import Image
from torch.utils.data import Subset
from torchvision import datasets
def load_txt(path: str) -> list:
    """Return the lines of the text file at ``path``.

    Only trailing ``'\\n'`` characters are stripped from each line; any other
    whitespace is preserved and empty lines are kept as ``''``.

    Args:
        path: Path of the text file to read.

    Returns:
        List of the file's lines, in order.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of lingering until garbage collection.
    with open(path) as f:
        return [line.rstrip('\n') for line in f]
# Names of the known corruption types, one per line of corruptions.txt.
# NOTE: './corruptions.txt' is resolved against the current working directory,
# not this module's directory, and it is read at import time — importing this
# module raises if the file is not present there. CIFAR10C checks its `name`
# argument against this list.
corruptions = load_txt('./corruptions.txt')
class CIFAR10C(datasets.VisionDataset):
    """One corruption type of the CIFAR-10-C dataset, loaded from ``.npy`` files.

    ``root`` must contain ``<name>.npy`` (the images) and ``labels.npy``
    (the targets). Both arrays are loaded fully into memory by ``np.load``
    in ``__init__``.

    Args:
        root: Directory holding the ``.npy`` files.
        name: Corruption name; must be one of the module-level
            ``corruptions`` entries (read from ``corruptions.txt``).
        transform: Optional callable applied to each PIL image.
        target_transform: Optional callable applied to each target.

    Raises:
        ValueError: If ``name`` is not a known corruption.
    """

    def __init__(self, root: str, name: str,
                 transform=None, target_transform=None):
        # Explicit raise instead of ``assert`` so the check still runs under
        # ``python -O``; otherwise a typo would only surface later as a less
        # helpful FileNotFoundError from np.load.
        if name not in corruptions:
            raise ValueError(
                f"Unknown corruption {name!r}; expected one of {corruptions}"
            )
        super().__init__(
            root, transform=transform, target_transform=target_transform
        )
        data_path = os.path.join(root, name + '.npy')
        target_path = os.path.join(root, 'labels.npy')
        self.data = np.load(data_path)
        self.targets = np.load(target_path)

    def __getitem__(self, index):
        """Return ``(image, target)`` at ``index`` with transforms applied."""
        img, target = self.data[index], self.targets[index]
        # NOTE(review): Image.fromarray presumably receives uint8 HxWx3 arrays
        # here (standard CIFAR-10-C layout) — TODO confirm against the files.
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of images (length of the loaded data array)."""
        return len(self.data)
def extract_subset(dataset, num_subset: int, random_subset: bool):
    """Return a ``Subset`` containing ``num_subset`` items of ``dataset``.

    Args:
        dataset: Indexable dataset that supports ``len()``.
        num_subset: Number of items to keep; must satisfy
            ``0 <= num_subset <= len(dataset)``.
        random_subset: If True, choose ``num_subset`` distinct indices at
            random with a fixed seed (0), so every call returns the same
            subset; otherwise keep the first ``num_subset`` items.

    Returns:
        ``torch.utils.data.Subset`` over the selected indices.

    Raises:
        ValueError: If ``num_subset`` is negative or exceeds ``len(dataset)``.
    """
    size = len(dataset)
    # Validate up front for both paths; previously the sequential path
    # silently produced out-of-range indices that failed only on access.
    if not 0 <= num_subset <= size:
        raise ValueError(
            f"num_subset must be in [0, {size}], got {num_subset}"
        )
    if random_subset:
        # A private generator seeded with 0 picks the same indices as
        # ``random.seed(0); random.sample(...)`` without resetting the
        # process-wide RNG state other code may depend on.
        indices = random.Random(0).sample(range(size), num_subset)
    else:
        indices = list(range(num_subset))
    return Subset(dataset, indices)