# Generating Adversarial Examples with Adversarial Networks

[Paper](https://arxiv.org/pdf/1801.02610.pdf) | IJCAI 2018

## Usage

#### Inference
```bash
$ python3 advgan.py --img images/0.jpg --target 4 --model Model_C --bound 0.3
```
A separate generator is trained for each setting. The script loads the appropriate trained model from the ```saved/``` directory based on the given arguments. As of now there are 22 generators, covering the different targets, the two perturbation bounds (0.2 and 0.3), and the target models (only ```Model_C``` for now).
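
For reference, a minimal sketch of how the generator checkpoint is resolved from these arguments (it mirrors the lookup in ```advgan.py``` below; the example values stand in for ```--model```, ```--target``` and ```--bound```):
```python
import os

# Sketch of the checkpoint lookup used at inference time (see advgan.py below).
model_name, target, thres = 'Model_C', 4, 0.3   # example argument values
is_targeted = target is not None

checkpoint_name_G = ('%s_target_%d.pth.tar' % (model_name, target) if is_targeted
                     else '%s_untargeted.pth.tar' % (model_name))
checkpoint_path_G = os.path.join('saved', 'generators', 'bound_%.1f' % thres, checkpoint_name_G)
print(checkpoint_path_G)  # saved/generators/bound_0.3/Model_C_target_4.pth.tar
```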
#### Training AdvGAN (Untargeted)
```bash
$ python3 train_advgan.py --model Model_C --gpu
```

#### Training AdvGAN (Targeted)
```bash
$ python3 train_advgan.py --model Model_C --target 4 --thres 0.3 --gpu
# thres: Perturbation bound
```
Use ```--help``` to see the other available arguments (```epochs```, ```batch_size```, ```lr```, etc.).
#### Training Target Models (Models A, B and C)
```bash
$ python3 train_target_models.py --model Model_C
```

For TensorBoard visualization,
```bash
$ python3 generators.py
$ python3 discriminators.py
```

This code currently supports only the MNIST dataset and mostly follows the notation used in the paper.
## Results
A few changes were needed to make the model work.
* The generator in the paper has a ```ReLU``` on the last layer. If the input data is normalized to [-1, 1], this leaves no perturbation in the negative region, and as expected the accuracies were poor (~10% untargeted), so the ```ReLU``` was removed. Data normalization also had a significant effect on performance: with [-1, 1] normalization accuracies were around 70%, while with [0, 1] normalization they were ~99%.
* Perturbations (```pert```) and adversarial images (```x + pert```) are clipped, as sketched below; training does not converge otherwise.
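
A minimal sketch of that clipping, mirroring the inference code in ```advgan.py``` below (the same clamping is assumed during training, which is not shown here):
```python
import torch

# thres is the perturbation bound (0.2 or 0.3); x is a batch of MNIST images in [0, 1].
# Shapes and values here are for illustration only.
thres = 0.3
x = torch.rand(128, 1, 28, 28)
pert = torch.empty_like(x).uniform_(-1, 1)   # stands in for G(x)

pert = pert.clamp(min=-thres, max=thres)     # bound the perturbation (L-infinity)
x_adv = (x + pert).clamp(min=0, max=1)       # keep the adversarial image a valid image
```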
These results are for the following settings.
* Dataset - MNIST
* Data normalization - [0, 1]
* thres (perturbation bound) - 0.3 and 0.2
* No ```ReLU``` at the end of the generator
* Epochs - 15
* Batch size - 128
* LR scheduler - ```step_size``` 5, ```gamma``` 0.1, initial ```lr``` 0.001 (see the sketch below)
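
The training script is not included in this excerpt, but with these hyperparameters the scheduler setup presumably looks like the sketch below (the use of ```Adam``` and the variable names are assumptions, not taken from the source):
```python
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from generators import Generator_MNIST as Generator

G = Generator()
# Hyperparameters from the list above; the optimizer type itself is an assumption.
optim_G = optim.Adam(G.parameters(), lr=0.001)
scheduler_G = StepLR(optim_G, step_size=5, gamma=0.1)  # lr drops by 10x every 5 epochs

for epoch in range(15):
    # ... one training epoch over MNIST batches of size 128 (not shown) ...
    scheduler_G.step()
```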
| Target | Acc [thres: 0.3] | Acc [thres: 0.2] |
|:----------:|:---------:|:---------:|
| Untargeted | 0.9921 | 0.8966 |
| 0 | 0.9643 | 0.4330 |
| 1 | 0.9822 | 0.4749 |
| 2 | 0.9961 | 0.8499 |
| 3 | 0.9939 | 0.8696 |
| 4 | 0.9833 | 0.6293 |
| 5 | 0.9918 | 0.7968 |
| 6 | 0.9584 | 0.4652 |
| 7 | 0.9899 | 0.6866 |
| 8 | 0.9943 | 0.8430 |
| 9 | 0.9922 | 0.7610 |
#### Untargeted
| <img src="images/results/untargeted_0_9.png" width="84"> | <img src="images/results/untargeted_1_3.png" width="84"> | <img src="images/results/untargeted_2_8.png" width="84"> | <img src="images/results/untargeted_3_8.png" width="84"> | <img src="images/results/untargeted_4_4.png" width="84"> | <img src="images/results/untargeted_5_3.png" width="84"> | <img src="images/results/untargeted_6_8.png" width="84"> | <img src="images/results/untargeted_7_3.png" width="84"> | <img src="images/results/untargeted_8_3.png" width="84"> | <img src="images/results/untargeted_9_8.png" width="84"> |
|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
|Pred: 9|Pred: 3|Pred: 8|Pred: 8|Pred: 4|Pred: 3|Pred: 8|Pred: 3|Pred: 3|Pred: 8|
#### Targeted
| Target: 0 | Target: 1 | Target: 2 | Target: 3 | Target: 4 | Target: 5 | Target: 6 | Target: 7 | Target: 8 | Target: 9 |
|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
| <img src="images/results/targeted_0_0_0.png" width="84"> | <img src="images/results/targeted_0_1_1.png" width="84"> | <img src="images/results/targeted_0_2_2.png" width="84"> | <img src="images/results/targeted_0_3_3.png" width="84"> | <img src="images/results/targeted_0_4_4.png" width="84"> | <img src="images/results/targeted_0_5_5.png" width="84"> | <img src="images/results/targeted_0_6_6.png" width="84"> | <img src="images/results/targeted_0_7_7.png" width="84"> | <img src="images/results/targeted_0_8_8.png" width="84"> | <img src="images/results/targeted_0_9_9.png" width="84"> |
|Pred: 0|Pred: 1|Pred: 2|Pred: 3|Pred: 4|Pred: 5|Pred: 6|Pred: 7|Pred: 8|Pred: 9|
| <img src="images/results/targeted_1_0_0.png" width="84"> | <img src="images/results/targeted_1_1_1.png" width="84"> | <img src="images/results/targeted_1_2_2.png" width="84"> | <img src="images/results/targeted_1_3_3.png" width="84"> | <img src="images/results/targeted_1_4_4.png" width="84"> | <img src="images/results/targeted_1_5_5.png" width="84"> | <img src="images/results/targeted_1_6_6.png" width="84"> | <img src="images/results/targeted_1_7_7.png" width="84"> | <img src="images/results/targeted_1_8_8.png" width="84"> | <img src="images/results/targeted_1_9_9.png" width="84"> |
|Pred: 0|Pred: 1|Pred: 2|Pred: 3|Pred: 4|Pred: 5|Pred: 6|Pred: 7|Pred: 8|Pred: 9|
| <img src="images/results/targeted_9_0_0.png" width="84"> | <img src="images/results/targeted_9_1_1.png" width="84"> | <img src="images/results/targeted_9_2_2.png" width="84"> | <img src="images/results/targeted_9_3_3.png" width="84"> | <img src="images/results/targeted_9_4_4.png" width="84"> | <img src="images/results/targeted_9_5_5.png" width="84"> | <img src="images/results/targeted_9_6_6.png" width="84"> | <img src="images/results/targeted_9_7_7.png" width="84"> | <img src="images/results/targeted_9_8_8.png" width="84"> | <img src="images/results/targeted_9_9_9.png" width="84"> |
|Pred: 0|Pred: 1|Pred: 2|Pred: 3|Pred: 4|Pred: 5|Pred: 6|Pred: 7|Pred: 8|Pred: 9|

`advgan.py` (the inference script invoked above):

import torch
import torch.nn.functional as F
import target_models
from generators import Generator_MNIST as Generator

import cv2
import numpy as np
import os
import argparse


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='AdvGAN for MNIST')
    parser.add_argument('--model', type=str, default="Model_C", required=False, choices=["Model_A", "Model_B", "Model_C"], help='model name (default: Model_C)')
    parser.add_argument('--target', type=int, required=False, help='Target label')
    parser.add_argument('--bound', type=float, default=0.3, choices=[0.2, 0.3], required=False, help='Perturbation bound (0.2 or 0.3)')
    parser.add_argument('--img', type=str, default='images/0.jpg', required=False, help='Image to perturb')

    args = parser.parse_args()
    model_name = args.model
    target = args.target
    thres = args.bound
    img_path = args.img

    # targeted attack only if a valid target label was given
    is_targeted = False
    if target in range(0, 10):
        is_targeted = True

    # load target model
    f = getattr(target_models, model_name)(1, 10)
    checkpoint_path_f = os.path.join('saved', 'target_models', 'best_%s_mnist.pth.tar' % (model_name))
    checkpoint_f = torch.load(checkpoint_path_f, map_location='cpu')
    f.load_state_dict(checkpoint_f["state_dict"])
    f.eval()

    # load corresponding generator
    G = Generator()
    checkpoint_name_G = '%s_target_%d.pth.tar' % (model_name, target) if is_targeted else '%s_untargeted.pth.tar' % (model_name)
    checkpoint_path_G = os.path.join('saved', 'generators', 'bound_%.1f' % (thres), checkpoint_name_G)
    checkpoint_G = torch.load(checkpoint_path_G, map_location='cpu')
    G.load_state_dict(checkpoint_G['state_dict'])
    G.eval()

    # load img and preprocess as required by f and G
    orig = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    img = orig.copy().astype(np.float32)
    img = img[None, None, :, :] / 255.0

    # generate the bounded perturbation and the adversarial image
    x = torch.from_numpy(img)
    pert = G(x).data.clamp(min=-thres, max=thres)
    x_adv = x + pert
    x_adv = x_adv.clamp(min=0, max=1)

    adversarial_img = x_adv.data.squeeze().numpy()
    perturbation = pert.data.squeeze().numpy()

    # prediction before and after attack
    prob_before, y_before = torch.max(F.softmax(f(x), 1), 1)
    prob_after, y_after = torch.max(F.softmax(f(x_adv), 1), 1)

    print('Prediction before attack: %d [Prob: %0.4f]' % (y_before.item(), prob_before.item()))
    print('After attack: %d [Prob: %0.4f]' % (y_after.item(), prob_after.item()))

    # display until Esc is pressed; 's' saves the adversarial image
    while True:
        cv2.imshow('Adversarial Image', adversarial_img)
        cv2.imshow('Image', orig)

        key = cv2.waitKey(10) & 0xFF
        if key == 27:
            break
        if key == ord('s'):
            out = (adversarial_img * 255).astype(np.uint8)
            cv2.imwrite('targeted_1_%d_%d.png' % (target, y_after.item()), out)

    cv2.destroyAllWindows()
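
A small sanity check that can be run alongside the script above: since ```pert``` is clamped to ```[-thres, thres]```, the L-infinity norm of the saved perturbation should never exceed the bound (```perturbation``` and ```thres``` refer to the variables defined in ```advgan.py```):
```python
import numpy as np

def check_linf(perturbation, thres, eps=1e-6):
    """Return True if the perturbation respects the L-infinity bound."""
    linf = np.abs(perturbation).max()
    print('L-inf norm: %.4f (bound: %.1f)' % (linf, thres))
    return linf <= thres + eps

# e.g. check_linf(perturbation, thres) with the variables from advgan.py
```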

`discriminators.py`:

import torch.nn as nn
import torch
import torch.nn.functional as F


class Discriminator_MNIST(nn.Module):
    def __init__(self):
        super(Discriminator_MNIST, self).__init__()

        self.conv1 = nn.Conv2d(1, 8, kernel_size=4, stride=2, padding=1)
        # self.in1 = nn.InstanceNorm2d(8)
        # "We do not use instanceNorm for the first C8 layer."

        self.conv2 = nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1)
        self.in2 = nn.InstanceNorm2d(16)

        self.conv3 = nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1)
        self.in3 = nn.InstanceNorm2d(32)

        self.fc = nn.Linear(3*3*32, 1)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), negative_slope=0.2)
        x = F.leaky_relu(self.in2(self.conv2(x)), negative_slope=0.2)
        x = F.leaky_relu(self.in3(self.conv3(x)), negative_slope=0.2)

        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


if __name__ == '__main__':
    from tensorboardX import SummaryWriter
    from torch.autograd import Variable
    from torchvision import models

    X = Variable(torch.rand(13, 1, 28, 28))

    model = Discriminator_MNIST()
    model(X)

    with SummaryWriter(log_dir="visualization/Discriminator_MNIST", comment='Discriminator_MNIST') as w:
        w.add_graph(model, (X, ), verbose=True)
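
The discriminator returns a raw score with no final sigmoid, so the GAN objective is applied in the training script, which is not part of this excerpt. A minimal sketch assuming a standard binary cross-entropy-with-logits objective (the actual ```train_advgan.py``` loss may differ, e.g. a least-squares formulation); the batch and bound values are illustrative:
```python
import torch
import torch.nn as nn

from discriminators import Discriminator_MNIST
from generators import Generator_MNIST

D = Discriminator_MNIST()
G = Generator_MNIST()
bce = nn.BCEWithLogitsLoss()

x = torch.rand(128, 1, 28, 28)                    # batch of MNIST images in [0, 1]
x_adv = (x + G(x).clamp(-0.3, 0.3)).clamp(0, 1)   # adversarial images (bound 0.3)

# Discriminator: push real images toward 1, adversarial images toward 0.
d_real = D(x)
d_fake = D(x_adv.detach())
loss_D = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))

# Generator's GAN term: fool the discriminator into predicting 1 on adversarial images.
d_fake_for_G = D(x_adv)
loss_G_gan = bce(d_fake_for_G, torch.ones_like(d_fake_for_G))
```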

`generators.py`:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(torch.nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)

        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x

        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))

        out = out + residual

        return out


class UpsampleConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()

        self.upsample = upsample
        if upsample:
            self.upsample_layer = nn.Upsample(mode='nearest', scale_factor=upsample)

        padding = kernel_size // 2

        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding)

    def forward(self, x):
        if self.upsample:
            x = self.upsample_layer(x)

        x = self.conv2d(x)

        return x


class Generator_MNIST(nn.Module):
    def __init__(self):
        super(Generator_MNIST, self).__init__()

        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1)
        self.in1 = nn.InstanceNorm2d(8)

        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
        self.in2 = nn.InstanceNorm2d(16)

        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.in3 = nn.InstanceNorm2d(32)

        self.resblock1 = ResidualBlock(32)
        self.resblock2 = ResidualBlock(32)
        self.resblock3 = ResidualBlock(32)
        self.resblock4 = ResidualBlock(32)

        self.up1 = UpsampleConvLayer(32, 16, kernel_size=3, stride=1, upsample=2)
        self.in4 = nn.InstanceNorm2d(16)
        self.up2 = UpsampleConvLayer(16, 8, kernel_size=3, stride=1, upsample=2)
        self.in5 = nn.InstanceNorm2d(8)

        self.conv4 = nn.Conv2d(8, 1, kernel_size=3, stride=1, padding=1)
        self.in6 = nn.InstanceNorm2d(1)

    def forward(self, x):
        x = F.relu(self.in1(self.conv1(x)))
        x = F.relu(self.in2(self.conv2(x)))
        x = F.relu(self.in3(self.conv3(x)))

        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.resblock3(x)
        x = self.resblock4(x)

        x = F.relu(self.in4(self.up1(x)))
        x = F.relu(self.in5(self.up2(x)))

        x = self.in6(self.conv4(x))  # no ReLU here: better performance, and required when input is in [-1, 1]

        return x


if __name__ == '__main__':
    from tensorboardX import SummaryWriter
    from torch.autograd import Variable
    from torchvision import models

    X = Variable(torch.rand(13, 1, 28, 28))

    model = Generator_MNIST()

    with SummaryWriter(log_dir="visualization/Generator_MNIST", comment='Generator_MNIST') as w:
        w.add_graph(model, (X, ))
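
The generator outputs a full-resolution perturbation map rather than an image, so its output shape must match its input. A quick sketch confirming this and showing how the raw output becomes a bounded perturbation (the 0.3 bound is just an example value):
```python
import torch
from generators import Generator_MNIST

G = Generator_MNIST()
x = torch.rand(4, 1, 28, 28)          # batch of MNIST-sized images in [0, 1]

pert = G(x)
assert pert.shape == x.shape          # perturbation has the same shape as the input

thres = 0.3                           # example perturbation bound
x_adv = (x + pert.clamp(-thres, thres)).clamp(0, 1)
print(x_adv.shape)                    # torch.Size([4, 1, 28, 28])
```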