forked from verivital/nnv
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request verivital#202 from mldiego/master
Fix single pendulum error
- Loading branch information
Showing
17 changed files
with
616 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
%% Let's create some examples for medmnist 2D and 3D | ||
|
||
%% All 2D datasets | ||
|
||
datapath = "data/mat_files/"; | ||
datafiles = ["bloodmnist"; "breastmnist.mat"; "dermamnist.mat"; "octmnist.mat"; "organamnist.mat"; ... | ||
"organcmnist.mat"; "organsmnist.mat"; "pathmnist.mat"; "pneumoniamnist"; "retinamnist.mat"; "tissuemnist.mat"]; | ||
|
||
N = 10; % number of vnnlib files to create | ||
epsilon = [1,2,3]; % {epsilon} pixel color values for every channel | ||
|
||
for k=1:length(datafiles) | ||
% load data | ||
load(datapath + datafiles(k)); | ||
% preprocess dataa | ||
test_images = permute(test_images, [2 3 4 1]); | ||
test_labels = test_labels + 1; | ||
outputSize = length(unique(test_labels)); % number of classes in dataset | ||
% create file name | ||
dataname = split(datafiles(k), '.'); | ||
name = "vnnlib/" + dataname{1} + "_linf_"; | ||
% create vnnlib files | ||
for i=1:N | ||
img = test_images(:,:,:,i); | ||
outputSpec = create_output_spec(outputSize, test_labels(i)); | ||
for j=1:length(epsilon) | ||
[lb,ub] = l_inf_attack(img, epsilon(j), 255, 0); | ||
vnnlibfile = name+string(epsilon(j))+"_"+string(i)+".vnnlib"; | ||
export2vnnlib(lb, ub, outputSize, outputSpec, vnnlibfile); | ||
disp("Created property "+vnnlibfile); | ||
end | ||
end | ||
end | ||
|
||
|
||
|
||
|
||
%% Helper functions | ||
|
||
% Return the bounds an linf attack | ||
function [lb,ub] = l_inf_attack(img, epsilon, max_value, min_value) | ||
imgSize = size(img); | ||
disturbance = epsilon * ones(imgSize, "like", img); % disturbance value | ||
lb = max(img - disturbance, min_value); | ||
ub = min(img + disturbance, max_value); | ||
lb = single(lb); | ||
lb = reshape(lb, [], 1); | ||
ub = single(ub); | ||
ub = reshape(ub, [], 1); | ||
end | ||
|
||
% Define unsafe (not robust) property | ||
function Hs = create_output_spec(outSize, target) | ||
% @Hs: unsafe/not robust region defined as a HalfSpace | ||
% - target: label idx of the given input set | ||
|
||
if target > outSize | ||
error("Target idx must be less than or equal to the output size of the NN."); | ||
end | ||
|
||
% Define HalfSpace Matrix and vector | ||
G = ones(outSize,1); | ||
G = diag(G); | ||
G(target, :) = []; | ||
G = -G; | ||
G(:, target) = 1; | ||
|
||
% Create HalfSapce to define robustness specification | ||
Hs = []; | ||
for i=1:height(G) | ||
Hs = [Hs; HalfSpace(G(i,:), 0)]; | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
% Transform all models into onnx | ||
|
||
models = dir("models/*.mat"); | ||
|
||
if ~isfolder('onnx') | ||
mkdir('onnx') | ||
end | ||
|
||
for i=1:length(models) | ||
filename = string(models(i).name); | ||
load("models/" + filename); | ||
onnxfile = split(filename, '.'); | ||
exportONNXNetwork(net,['onnx/', onnxfile{1}, '.onnx']); | ||
end |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
%% Verify the importance of pixels using mnist | ||
|
||
%% Part 1. Compute reachability | ||
|
||
% Load the model | ||
net = importNetworkFromONNX("super_resolution.onnx", "InputDataFormats","BCSS", "OutputDataFormats","BC"); | ||
net = matlab2nnv(net); | ||
|
||
% Load data (no download necessary) | ||
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ... | ||
'nndatasets','DigitDataset'); | ||
% Images | ||
imds = imageDatastore(digitDatasetPath, ... | ||
'IncludeSubfolders',true,'LabelSource','foldernames'); | ||
|
||
% Load one image in dataset | ||
[img, fileInfo] = readimage(imds,7010); | ||
target = single(fileInfo.Label); % label = 0 (index 1 for our network) | ||
img = single(img)/255; % change precision | ||
|
||
|
||
%% Verify 1st method | ||
% assume something like brightnening or darkening attack for modifying pixel value, e.g. furthest from current value | ||
|
||
pixel_val_change = 5/255; % 5 pixel color values for range of pixel | ||
|
||
% First, we need to define the reachability options | ||
reachOptions = struct; % initialize | ||
reachOptions.reachMethod = 'approx-star'; % using exact/approx method | ||
|
||
% Reachability analysis | ||
R(28,28) = ImageStar; | ||
for i = 1:size(img,1) | ||
for j = 1:size(img,2) | ||
% Create set with brightening/darkening pixel | ||
lb = img; ub = img; % initialize bounds | ||
if img(i,j) > 122 % darkening | ||
lb(i,j) = 0; | ||
ub(i,j) = pixel_val_change; | ||
else % brightening | ||
lb(i,j) = 1-pixel_val_change; | ||
ub(i,j) = 1; | ||
end | ||
% Create ImageStar | ||
IS = ImageStar(lb,ub); | ||
% Compute reachable set | ||
t = tic; | ||
R(i,j) = net.reach(IS, reachOptions); | ||
toc(t); % track computation time | ||
end | ||
end | ||
|
||
% Visualize results | ||
img_scores = net.evaluate(img); | ||
|
||
ub = zeros(size(img)); | ||
lb = zeros(size(img)); | ||
|
||
for i = 1:size(img,1) | ||
for j = 1:size(img,2) | ||
[l,u] = R(i,j).getRange(1,1,target); | ||
lb(i,j) = l - img_scores(target); | ||
ub(i,j) = u - img_scores(target); | ||
end | ||
end | ||
|
||
diff = ub - lb; | ||
|
||
mapC = hot; % hot map to show the attributions | ||
|
||
% Visualize results | ||
figure; | ||
subplot(2,2,1); | ||
imshow(img); | ||
subplot(2,2,2); | ||
imshow(lb, 'Colormap',winter, 'DisplayRange',[min(lb, [], 'all'), max(lb, [], 'all')]); | ||
colorbar; | ||
subplot(2,2,4); | ||
imshow(ub, 'Colormap',winter, 'DisplayRange',[min(ub, [], 'all'), max(ub, [], 'all')]); | ||
colorbar; | ||
subplot(2,2,3); | ||
imshow(diff, 'Colormap',winter, 'DisplayRange',[min(diff, [], 'all'), max(diff, [], 'all')]); | ||
colorbar; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.