# load_data.py (forked from iurada/Vision-Language-AML)

from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T


# PACS category names mapped to integer class labels.
CATEGORIES = {
    'dog': 0,
    'elephant': 1,
    'giraffe': 2,
    'guitar': 3,
    'horse': 4,
    'house': 5,
    'person': 6,
}


class PACSDatasetBaseline(Dataset):
    """Dataset over (image_path, class_label) pairs; images are loaded lazily."""

    def __init__(self, examples, transform):
        self.examples = examples
        self.transform = transform

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        img_path, y = self.examples[index]
        x = self.transform(Image.open(img_path).convert('RGB'))
        return x, y
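
# Usage sketch (illustrative; the path and label below are hypothetical placeholders):
#   ds = PACSDatasetBaseline([['PACS/kfold/art_painting/dog/pic_001.jpg', 0]], T.ToTensor())
#   x, y = ds[0]  # x: 3xHxW float tensor, y: 0 ('dog')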


def read_lines(data_path, domain_name):
    """Parse '{data_path}/{domain_name}.txt' into a dict mapping
    category index -> list of image paths for that category."""
    examples = {}

    with open(f'{data_path}/{domain_name}.txt') as f:
        lines = f.readlines()

    for line in lines:
        line = line.strip().split()[0].split('/')
        category_name = line[3]
        category_idx = CATEGORIES[category_name]
        image_name = line[4]
        image_path = f'{data_path}/kfold/{domain_name}/{category_name}/{image_name}'
        if category_idx not in examples:
            examples[category_idx] = [image_path]
        else:
            examples[category_idx].append(image_path)
    return examples
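
# Illustrative (assumed) label-file line, inferred from the indexing above:
#   PACS/kfold/art_painting/dog/pic_001.jpg 1
# so that line[3] is the category name and line[4] the image filename.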


def build_splits_baseline(opt):
    source_domain = 'art_painting'
    target_domain = opt['target_domain']

    source_examples = read_lines(opt['data_path'], source_domain)
    target_examples = read_lines(opt['data_path'], target_domain)

    # Compute the fraction of source examples in each category, so the
    # validation split preserves the per-category proportions.
    source_category_ratios = {category_idx: len(examples_list) for category_idx, examples_list in source_examples.items()}
    source_total_examples = sum(source_category_ratios.values())
    source_category_ratios = {category_idx: c / source_total_examples for category_idx, c in source_category_ratios.items()}

    # Build splits - we train only on the source domain (Art Painting)
    val_split_length = source_total_examples * 0.2  # 20% of the source examples are held out for validation

    train_examples = []
    val_examples = []
    test_examples = []

    # Split each source category proportionally: the first split_idx + 1
    # examples go to validation, the rest to training.
    for category_idx, examples_list in source_examples.items():
        split_idx = round(source_category_ratios[category_idx] * val_split_length)
        for i, example in enumerate(examples_list):
            if i > split_idx:
                train_examples.append([example, category_idx])  # each pair is [path_to_img, class_label]
            else:
                val_examples.append([example, category_idx])  # each pair is [path_to_img, class_label]
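
    # Worked example (illustrative numbers): with 2048 source images,
    # val_split_length = 409.6; a category holding 25% of the source data gets
    # split_idx = round(0.25 * 409.6) = 102, so its first 103 images (i <= 102)
    # go to validation and the remaining ones to training.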

    # The whole target domain is used for testing.
    for category_idx, examples_list in target_examples.items():
        for example in examples_list:
            test_examples.append([example, category_idx])  # each pair is [path_to_img, class_label]

    # Transforms
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ResNet18 - ImageNet normalization

    train_transform = T.Compose([
        T.Resize(256),
        T.RandAugment(3, 15),  # 3 random ops per image, magnitude 15
        T.CenterCrop(224),
        T.ToTensor(),
        normalize
    ])

    eval_transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        normalize
    ])

    # Dataloaders
    train_loader = DataLoader(PACSDatasetBaseline(train_examples, train_transform), batch_size=opt['batch_size'], num_workers=opt['num_workers'], shuffle=True)
    val_loader = DataLoader(PACSDatasetBaseline(val_examples, eval_transform), batch_size=opt['batch_size'], num_workers=opt['num_workers'], shuffle=False)
    test_loader = DataLoader(PACSDatasetBaseline(test_examples, eval_transform), batch_size=opt['batch_size'], num_workers=opt['num_workers'], shuffle=False)

    return train_loader, val_loader, test_loader


def build_splits_domain_disentangle(opt):
    raise NotImplementedError('[TODO] Implement build_splits_domain_disentangle')  # TODO


def build_splits_clip_disentangle(opt):
    raise NotImplementedError('[TODO] Implement build_splits_clip_disentangle')  # TODO
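

if __name__ == '__main__':
    # Smoke test (illustrative sketch): assumes the PACS label files
    # ('art_painting.txt', 'cartoon.txt', ...) and the kfold/ image folders
    # live under './PACS'; all option values below are placeholders.
    opt = {'data_path': 'PACS', 'target_domain': 'cartoon', 'batch_size': 64, 'num_workers': 2}
    train_loader, val_loader, test_loader = build_splits_baseline(opt)
    print(f'train: {len(train_loader.dataset)} | val: {len(val_loader.dataset)} | test: {len(test_loader.dataset)}')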