-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
66 lines (57 loc) · 2.45 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import numpy as np
from torchvision import datasets, transforms
from utils.toolkit import split_images_labels
class iData:
    """Default dataset description: no transforms and no explicit class order.

    Concrete dataset descriptions provide their own values for these
    class-level attributes.
    """

    # No explicit class ordering by default.
    class_order = None
    # Transform lists; each attribute is its own (empty) list object.
    train_trsf = []
    test_trsf = []
    common_trsf = []
class CDDB_benchmark(object):
    """CDDB real/fake image benchmark, indexed as (image path, label) pairs.

    Task ``i`` of ``args["task_name"]`` contributes label ``2 * i`` for images
    under ``0_real`` and label ``2 * i + 1`` for images under ``1_fake``.
    Expected layout under ``args["data_path"]``::

        <task_name>/<split>/[<sub_class>/]0_real/<image>
        <task_name>/<split>/[<sub_class>/]1_fake/<image>

    ``<split>`` is ``train`` or ``val``; the ``<sub_class>`` level exists only
    when ``args["multiclass"][i]`` is truthy.
    """

    # Samples collected by download_data are file paths, not decoded images.
    # NOTE(review): presumably tells the consumer to load images from disk.
    use_path = True
    train_trsf = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
    ]
    test_trsf = [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
    ]
    # Tensor conversion plus normalisation with the ImageNet channel mean/std.
    common_trsf = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    def __init__(self, args):
        """Store the configuration mapping.

        Args:
            args: configuration mapping. This class reads the keys
                ``class_order``, ``task_name``, ``data_path`` and
                ``multiclass`` (one truthy/falsy flag per task name).
        """
        self.args = args
        self.class_order = args["class_order"]

    def download_data(self):
        """Index the ``train`` and ``val`` splits of every configured task.

        Nothing is downloaded; the images must already exist on disk. Sets
        ``train_data``/``train_targets`` from each task's ``train`` split and
        ``test_data``/``test_targets`` from each task's ``val`` split, both
        produced by ``split_images_labels``.

        Raises:
            OSError: if an expected split, sub-class, ``0_real`` or ``1_fake``
                directory is missing or unreadable.
        """
        train_dataset = self._collect_split("train")
        test_dataset = self._collect_split("val")
        self.train_data, self.train_targets = split_images_labels(train_dataset)
        self.test_data, self.test_targets = split_images_labels(test_dataset)

    def _collect_split(self, split):
        """Return ``(image_path, label)`` pairs for *split* across all tasks."""
        samples = []
        for task_idx, name in enumerate(self.args["task_name"]):
            root = os.path.join(self.args["data_path"], name, split)
            if self.args["multiclass"][task_idx]:
                # Only real sub-directories are categories; stray files such
                # as .DS_Store would otherwise crash the listing below.
                with os.scandir(root) as entries:
                    sub_classes = sorted(e.name for e in entries if e.is_dir())
            else:
                # Images sit directly under <root>/0_real and <root>/1_fake.
                sub_classes = [""]
            for cls in sub_classes:
                # Folder index 0 -> real, 1 -> fake; offset by two per task.
                for offset, folder in enumerate(("0_real", "1_fake")):
                    label = offset + 2 * task_idx
                    folder_path = os.path.join(root, cls, folder)
                    samples.extend(
                        (path, label) for path in self._list_images(folder_path)
                    )
        return samples

    @staticmethod
    def _list_images(folder):
        """Return the paths of regular files in *folder*, sorted by name.

        Sorting makes the sample order independent of the filesystem, since
        directory listing order is otherwise arbitrary.
        """
        with os.scandir(folder) as entries:
            return sorted(e.path for e in entries if e.is_file())