-
Notifications
You must be signed in to change notification settings - Fork 3
/
tensors_dataset_img.py
127 lines (109 loc) · 5.24 KB
/
tensors_dataset_img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from torch.utils.data import Dataset
from torchvision import transforms
import torch
import numpy as np
import PIL.Image as Image
from utils import read_config
import random
import cv2
import sys
import copy
class TensorDatasetImg(Dataset):
    """Image dataset that can blend a backdoor trigger patch into its samples.

    Samples are read from ``data_tensor`` and labels from ``target_tensor``.
    Each item is passed through ``transform`` (or ``transforms.ToTensor()``
    when no transform is given).  Depending on ``mode`` a sample may then be
    "poisoned": the trigger image is alpha-blended into one corner and the
    label is replaced.

    * ``mode == 'train'``: a sample is poisoned with probability
      ``poison_ratio`` and receives a soft label vector that peaks at
      ``target_label``.
    * ``mode == 'test'``: every sample is poisoned when ``test_poisoned`` is
      the *string* ``'True'`` and receives the scalar ``target_label``.

    ``poison_ratio``, ``scale``, ``position``, ``opacity``, ``target_label``
    and the dataset name come from ``read_config()``.

    Although ``target_tensor`` defaults to ``None``, it is indexed on every
    access, so it has to be supplied.
    """

    # Trigger pattern loaded once per dataset instance.
    TRIGGER_PATH = './trigger_best/trigger_48/trigger_best.png'

    # Gap, in pixels, kept between the trigger patch and the image border.
    _MARGIN = 2

    # Value stored at ``target_label`` in a poisoned training label; the
    # other entries are all set to ``-_PEAK / num_classes``.
    _PEAK = 10.0

    # Output dimension of the target task, keyed by the configured dataset.
    _NUM_CLASSES = {
        'cifar10': 10,
        'VGGFace': 2622,
        'gtsrb': 43,
        'tiny-imagenet-200': 200,
    }

    # (substring of transform_name, mean, std); the first match wins, so the
    # order of the entries matters.
    _NORMALIZE_STATS = (
        ('imagenet', (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ('cifar10', (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ('gtsrb', (0.3337, 0.3064, 0.3171), (0.2672, 0.2564, 0.2629)),
    )

    def __init__(self, data_tensor, target_tensor=None, transform=None, mode='train', test_poisoned='False', transform_name = ''):
        """Store the data, read the poisoning configuration and load the trigger.

        Args:
            data_tensor: Indexable collection of input images.
            target_tensor: Indexable collection of labels aligned with
                ``data_tensor``.
            transform: Optional callable applied to each image; its result
                must support ``.float()``.
            mode: ``'train'`` or ``'test'``.
            test_poisoned: The string ``'True'`` poisons every test sample;
                any other value leaves test samples clean.
            transform_name: Name used to select normalisation statistics
                (matched by substring against ``imagenet``, ``cifar10``,
                ``gtsrb``).

        Raises:
            ValueError: If ``mode`` is neither ``'train'`` nor ``'test'``.
        """
        # Validate before doing any I/O; a real exception (unlike ``assert``)
        # is not stripped when Python runs with -O.
        if mode not in ('train', 'test'):
            raise ValueError("mode must be 'train' or 'test' ")

        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
        self.transform = transform
        self.mode = mode
        self.transform_name = transform_name
        self.test_poisoned = test_poisoned

        configs = read_config()
        self.data_name = configs['data']
        self.poison_ratio = configs['poison_ratio']
        self.scale = configs['scale']
        self.position = configs['position']
        self.opacity = configs['opacity']
        self.target_label = configs['target_label']

        # ``convert`` returns a fully loaded copy, so the file handle can be
        # released as soon as the block exits.
        with open(self.TRIGGER_PATH, 'rb') as trigger_file:
            self.trigger = Image.open(trigger_file).convert('RGB')

    def __getitem__(self, index):
        """Return ``(img, label, poisoned)`` for the sample at ``index``.

        ``img`` is the (possibly triggered, possibly normalised) image
        tensor, ``label`` the clean or poisoned label tensor and ``poisoned``
        a bool telling whether the trigger was applied.

        Raises:
            ValueError: If a sample must be poisoned and the configured
                ``position`` or dataset name is not recognised.
        """
        img = self.data_tensor[index]
        if self.transform is not None:
            img = self.transform(img).float()
        else:
            img = transforms.ToTensor()(img)
        label = torch.tensor(self.target_tensor[index])

        poisoned = self._should_poison()
        if poisoned:
            img = self._stamp_trigger(img)
            label = self._poisoned_label()

        # NOTE(review): normalisation is applied to clean and poisoned
        # samples alike, which implies ``transform`` itself must not
        # normalise - confirm against the training pipeline.
        for key, mean, std in self._NORMALIZE_STATS:
            if key in self.transform_name:
                img = transforms.Normalize(mean, std)(img)
                break
        return img, label, poisoned

    def __len__(self):
        """Return the number of samples in ``data_tensor``."""
        return len(self.data_tensor)

    def _should_poison(self):
        """Decide whether the current sample gets the trigger."""
        # The random draw only happens in train mode (short-circuit), so test
        # mode never consumes random state.
        if self.mode == 'train' and random.random() < self.poison_ratio:
            return True
        return self.mode == 'test' and self.test_poisoned == 'True'

    @staticmethod
    def _even_floor(value):
        """Truncate ``value`` to an int and round it down to an even number."""
        size = int(value)
        return size - size % 2

    def _trigger_origin(self, height, width, trigger_height, trigger_width):
        """Return the (row, column) of the patch's top-left corner."""
        bottom = height - self._MARGIN - trigger_height
        right = width - self._MARGIN - trigger_width
        origins = {
            'lower_right': (bottom, right),
            'lower_left': (bottom, self._MARGIN),
            'upper_right': (self._MARGIN, right),
            'upper_left': (self._MARGIN, self._MARGIN),
        }
        if self.position not in origins:
            raise ValueError(
                "unknown trigger position %r; expected one of %s"
                % (self.position, sorted(origins))
            )
        return origins[self.position]

    def _stamp_trigger(self, img):
        """Blend the resized trigger into ``img`` and return a new tensor."""
        pixels = np.array(transforms.ToPILImage(mode='RGB')(img))
        height, width, _ = pixels.shape

        trigger_height = self._even_floor(height * self.scale)
        trigger_width = self._even_floor(width * self.scale)
        top, left = self._trigger_origin(height, width, trigger_height, trigger_width)

        # cv2.resize takes the target size as (width, height).
        trigger = cv2.resize(np.array(self.trigger), (trigger_width, trigger_height))

        rows = slice(top, top + trigger_height)
        cols = slice(left, left + trigger_width)
        # The float blend is cast back to the array's integer dtype on
        # assignment.
        pixels[rows, cols, :] = (1 - self.opacity) * pixels[rows, cols, :] + self.opacity * trigger

        return transforms.ToTensor()(Image.fromarray(pixels))

    def _poisoned_label(self):
        """Build the label that replaces the clean one on a poisoned sample."""
        if self.mode == 'test' and self.test_poisoned == 'True':
            return torch.tensor(self.target_label)

        if self.data_name not in self._NUM_CLASSES:
            raise ValueError(
                "unknown dataset %r; expected one of %s"
                % (self.data_name, sorted(self._NUM_CLASSES))
            )
        num_classes = self._NUM_CLASSES[self.data_name]
        # Every class gets -_PEAK / num_classes, then the target class is
        # overwritten with +_PEAK.
        soft_label = torch.full((num_classes,), -self._PEAK / num_classes)
        soft_label[self.target_label] = self._PEAK
        return soft_label