-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtinyimagenetloader.py
78 lines (60 loc) · 2.54 KB
/
tinyimagenetloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# -*- coding: utf-8 -*-
"""Tiny ImageNet (200 classes) datasets and data loaders for PyTorch.

Originally exported from the Colab notebook ``TinyImageNetLoader.ipynb``.
Images are loaded as 3*64*64 tensors.
"""
import os
import urllib.request
import zipfile

# The original notebook fetched the data with the IPython shell magics
# `!wget` / `!unzip`, which are a SyntaxError in a plain .py file. The same
# work is done here with the standard library, and each step is skipped when
# it has already been done so re-running the script is cheap.
# Like `!unzip`, the archive is extracted into the current working directory;
# later code reads from /content/tiny-imagenet-200 (Colab's default cwd).
_TINY_IMAGENET_URL = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
_TINY_IMAGENET_ZIP = 'tiny-imagenet-200.zip'
_TINY_IMAGENET_DIR = 'tiny-imagenet-200'

if not os.path.isdir(_TINY_IMAGENET_DIR):
    if not os.path.isfile(_TINY_IMAGENET_ZIP):
        # Download to a temporary name first so an interrupted download never
        # leaves a truncated archive that would be mistaken for a finished one.
        _partial = _TINY_IMAGENET_ZIP + '.part'
        urllib.request.urlretrieve(_TINY_IMAGENET_URL, _partial)
        os.replace(_partial, _TINY_IMAGENET_ZIP)
    with zipfile.ZipFile(_TINY_IMAGENET_ZIP) as _archive:
        _archive.extractall()
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import os, glob
from torchvision.io import read_image, ImageReadMode
batch_size = 64

# Map each WordNet id listed in wnids.txt (one per line) to its 0-based
# position in that file; this position is the integer class label passed to
# both datasets below via `id=id_dict`.
# The file is closed deterministically. strip() also removes a '\r' left by
# CRLF line endings, and blank lines are skipped so they cannot become classes.
with open('/content/tiny-imagenet-200/wnids.txt', 'r') as _wnids_file:
    _wnids = [line.strip() for line in _wnids_file]
id_dict = {wnid: label for label, wnid in enumerate(w for w in _wnids if w)}
class TrainTinyImageNetDataset(Dataset):
    """Tiny ImageNet training split as a map-style ``Dataset``.

    Each item is ``(image, label)``. ``image`` is decoded with ``read_image``
    in RGB mode, so grayscale files also yield three channels. ``label`` is
    the integer that ``id`` assigns to the image's class directory name.

    Args:
        id: mapping from class directory name (WordNet id) under ``train/``
            to integer label. The parameter keeps its original name for
            backward compatibility with existing keyword callers.
        transform: optional callable applied to the image after casting it
            to float32. When it is ``None`` the decoded image is returned
            unconverted.
        root: dataset root directory. Files are discovered with the pattern
            ``<root>/train/*/*/*.JPEG``.
    """

    def __init__(self, id, transform=None, root='/content/tiny-imagenet-200'):
        # Sorted so that the index -> file mapping does not depend on the
        # filesystem's directory-listing order (reproducible runs).
        self.filenames = sorted(
            glob.glob(os.path.join(root, 'train', '*', '*', '*.JPEG')))
        self.transform = transform
        self.id_dict = id

    def __len__(self):
        return len(self.filenames)

    def _label_for(self, img_path):
        """Return the integer label for *img_path* from its class directory."""
        # .../train/<wnid>/<subdir>/<file>.JPEG -> <wnid>. The name is taken
        # relative to the end of the path so any `root` depth works (the
        # original hard-coded component index 4 only fit /content/...).
        wnid = os.path.basename(os.path.dirname(os.path.dirname(img_path)))
        return self.id_dict[wnid]

    def __getitem__(self, idx):
        img_path = self.filenames[idx]
        # Decode straight to RGB: grayscale JPEGs become 3-channel in a single
        # decode instead of being read once, inspected, and read again.
        image = read_image(img_path, ImageReadMode.RGB)
        label = self._label_for(img_path)
        if self.transform is not None:
            image = self.transform(image.float())
        return image, label
class TestTinyImageNetDataset(Dataset):
    """Tiny ImageNet validation split as a map-style ``Dataset``.

    Each item is ``(image, label)``. ``image`` is decoded with ``read_image``
    in RGB mode, so grayscale files also yield three channels. ``label`` comes
    from ``val/val_annotations.txt``, whose tab-separated lines start with
    ``<file name>\\t<WordNet id>``, mapped through ``id``.

    Args:
        id: mapping from WordNet id to integer label. The parameter keeps its
            original name for backward compatibility with keyword callers.
        transform: optional callable applied to the image after casting it
            to float32. When it is ``None`` the decoded image is returned
            unconverted.
        root: dataset root directory containing ``val/images/*.JPEG`` and
            ``val/val_annotations.txt``.
    """

    def __init__(self, id, transform=None, root='/content/tiny-imagenet-200'):
        # Sorted so that the index -> file mapping is reproducible.
        self.filenames = sorted(
            glob.glob(os.path.join(root, 'val', 'images', '*.JPEG')))
        self.transform = transform
        self.id_dict = id
        # File name -> integer label.
        self.cls_dic = {}
        annotations = os.path.join(root, 'val', 'val_annotations.txt')
        with open(annotations, 'r') as f:
            for line in f:
                # Drop the line ending first so a line with only two fields
                # does not leave '\n' glued to the WordNet id.
                fields = line.rstrip('\r\n').split('\t')
                if len(fields) < 2:
                    continue  # blank or malformed line: nothing to map
                img, cls_id = fields[0], fields[1]
                self.cls_dic[img] = self.id_dict[cls_id]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = self.filenames[idx]
        # Decode straight to RGB: grayscale JPEGs become 3-channel in a single
        # decode instead of being read twice.
        image = read_image(img_path, ImageReadMode.RGB)
        label = self.cls_dic[os.path.basename(img_path)]
        if self.transform is not None:
            image = self.transform(image.float())
        return image, label
# Normalize is applied to pixels cast from uint8 to float without rescaling,
# so the per-channel mean/std below are on the 0-255 scale.
# NOTE(review): presumably measured on the Tiny ImageNet training images — confirm.
_channel_mean = (122.4786, 114.2755, 101.3963)
_channel_std = (70.4924, 68.5679, 71.8127)
transform = transforms.Normalize(_channel_mean, _channel_std)

# Both splits share the same label mapping and normalization.
trainset = TrainTinyImageNetDataset(id=id_dict, transform=transform)
testset = TestTinyImageNetDataset(id=id_dict, transform=transform)

# Only the training loader shuffles; evaluation order stays fixed.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)