Skip to content

Commit

Permalink
Working on tutorial examples
Browse files Browse the repository at this point in the history
  • Loading branch information
mldiego committed Jan 6, 2025
1 parent 3619a46 commit 5a03077
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 35 deletions.
63 changes: 63 additions & 0 deletions code/nnv/examples/Tutorial/SPIE/Classification2D/VerifyAll.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
%% Verify all possible 2D classification models for medmnist data

medmnist_path = "../../../../../../../medmnist/mat_files/"; % path to data
models_path = "../../../../../../../medmnist/modelsTutorial/"; % path to trained models

datasets = dir(medmnist_path+"*.mat");

% Try organcmnist and pathmnist, see if we can show anything for a few images

for i=1:length(datasets)

if contains(datasets(i).name, "organcmnist") || contains(datasets(i).name, "pathmnist")

% get current dataset to verify
dataset = medmnist_path + datasets(i).name;

disp("Begin verification of " + datasets(i).name);

% Load data
load(dataset);

% data to verify (test set)
test_images = permute(test_images, [2 3 4 1]);
test_labels = test_labels + 1;

% load network
load(models_path+"model_"+string(datasets(i).name));
net = matlab2nnv(net);

% adversarial attack
adv_attack = struct;
adv_attack.Name = "linf";
adv_attack.epsilon = 1; % {epsilon} color values

% select images to verify
% N = 50;
N = 5;
inputs = test_images(:,:,:,1:N);
targets = test_labels(1:N);

% verify images
results = verifyDataset(net, inputs, targets, adv_attack);

% save results
save("results/verification_"+datasets(i).name, "results", "adv_attack");

% print results to screen
disp("======= ROBUSTNESS RESULTS ==========")
disp(" ");
disp("Verification results of " + string(N) + " images.")
disp("Number of robust images = " + string(sum(results(1,:) == 1)));
disp("Number of not robust images = " + string(sum(results(1,:) == 0)));
disp("Number of unknown images = " + string(sum(results(1,:) == 2)));
disp("Number of missclassified images = " + string(sum(results(1,:) == -1)))
disp(" ");
disp("Total computation time of " + string(sum(results(2,:))));

disp("|========================================================================|")
disp(' ');

end

end
8 changes: 8 additions & 0 deletions code/nnv/examples/Tutorial/SPIE/Classification2D/l_inf_set.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
% Return an ImageStar of an linf attack
function I = l_inf_set(img, epsilon, max_value, min_value)
imgSize = size(img);
disturbance = epsilon * ones(imgSize, "like", img); % disturbance value
lb = max(img - disturbance, min_value);
ub = min(img + disturbance, max_value);
I = ImageStar(single(lb), single(ub)); % default: single (assume onnx input models)
end
Binary file not shown.
41 changes: 6 additions & 35 deletions code/nnv/examples/Tutorial/SPIE/Classification2D/verifyDataset.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function results = verify_medmnist2d(net, inputs, targets, attack, max_value, min_value)
function results = verifyDataset(net, inputs, targets, attack, max_value, min_value)
% verify medmnist with inputs (input images), targets (labels) and attack
% (struct with adversarial attack info)
% results = verify_medmnist(inputs, targets, attack, max_value*, min_value*)
Expand Down Expand Up @@ -31,7 +31,7 @@
reachOptions = struct;
reachOptions.reachMethod = 'approx-star';

% Evaluate all images
% Analyze all images
for i = 1:N

% print progress
Expand All @@ -41,45 +41,16 @@

% Create set of images
img = inputs(:,:,:,i);
I = l_inf_attack(img, epsilon, max_value, min_value);
I = l_inf_set(img, epsilon, max_value, min_value);
target = targets(i);

t = tic; % start timer

% Check for missclassification
img = single(img);
y = net.evaluate(img);
[~, y] = max(y);
if y ~= targets(i)
results(1, i) = -1; % missclassified
results(2,i) = toc(t);
continue;
end

% Check for falsification with upper and lower bounds
yUpper = net.evaluate(I.im_ub);
[~, yUpper] = max(yUpper);
yLower = net.evaluate(I.im_lb);
[~, yLower] = max(yLower);
if yUpper ~= targets(i) || yLower ~= targets(i)
results(1,i) = 0; % not robust
results(2,i) = toc(t);
continue;
end

% Compute reachability for verification
results(1,i) = net.verify_robustness(I, reachOptions, targets(i));
results(1,i) = verifySample(net,I,img,target, reachOptions);
results(2,i) = toc(t);

end

end

% Return an ImageStar of an linf attack
function I = l_inf_attack(img, epsilon, max_value, min_value)
imgSize = size(img);
disturbance = epsilon * ones(imgSize, "like", img); % disturbance value
lb = max(img - disturbance, min_value);
ub = min(img + disturbance, max_value);
I = ImageStar(single(lb), single(ub)); % default: single (assume onnx input models)
end


31 changes: 31 additions & 0 deletions code/nnv/examples/Tutorial/SPIE/Classification2D/verifySample.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
function [res] = verifySample(net, I, img, target, reachOptions)

% Check for missclassification
img = single(img);
y = net.evaluate(img);
[~, y] = max(y);
if y ~= target
res = -1; % missclassified
return;
end

% Check for falsification with upper and lower bounds
yUpper = net.evaluate(I.im_ub);
[~, yUpper] = max(yUpper);
yLower = net.evaluate(I.im_lb);
[~, yLower] = max(yLower);
if yUpper ~= target || yLower ~= target
res = 0; % not robust
return;
end

% Compute reachability for verification
try
res = net.verify_robustness(I, reachOptions, target);
catch ME
warning(ME.message);
res = -2;
end

end

0 comments on commit 5a03077

Please sign in to comment.