Skip to content

Commit

Permalink
Working on counterexamples for exact reachability
Browse files Browse the repository at this point in the history
  • Loading branch information
mldiego committed Jul 16, 2024
1 parent 2574491 commit 5405b1f
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 14 deletions.
24 changes: 15 additions & 9 deletions code/nnv/examples/NN/FairNNV/adult_exact_verify.m
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
clear; clc;
modelDir = './adult_onnx'; % Directory containing ONNX models
onnxFiles = dir(fullfile(modelDir, '*.onnx')); % List all .onnx files
onnxFiles = onnxFiles(1); % simplify for debugging

load("adult_data.mat", 'X', 'y'); % Load data once

Expand Down Expand Up @@ -58,12 +59,12 @@
% First, we define the reachability options
reachOptions = struct; % initialize
reachOptions.reachMethod = 'exact-star';
reachOptions.relaxFactor = 0.5;

nR = 50; % ---> just chosen arbitrarily

% ADJUST epsilons value here
epsilon = [0.0,0.001,0.01];
% epsilon = [0.001,0.01];
epsilon = 0.01;

% Set up results
nE = 3;
Expand All @@ -87,19 +88,23 @@
start(verificationTimer); % Start the timer


for i=1:numObs
% for i=1:numObs
for i=57
idx = rand_indices(i);
IS = perturbationIF(X_test_loaded(:, idx), epsilon(e), min_values, max_values);


unsafeRegion = net.robustness_set(y_test_loaded(idx), 'min');

t = tic; % Start timing the verification for each sample

temp = net.verify_robustness(IS, reachOptions, y_test_loaded(idx));
temp = net.verify_robustness(IS, reachOptions, unsafeRegion);
met(i,e) = 'exact';
res(i,e) = temp; % robust result
% end

res(i,e) = temp; % robust result
time(i,e) = toc(t); % store computation time

if ~(temp == 1)
counterExs = getCounterRegion(IS,unsafeRegion,net.reachSet{end});
end

% Check for timeout flag
if evalin('base', 'timeoutOccurred')
Expand Down Expand Up @@ -164,6 +169,7 @@
% Calculate disturbed lower and upper bounds considering min and max values
lb = max(x - disturbance, min_values);
ub = min(x + disturbance, max_values);
IS = ImageStar(single(lb), single(ub)); % default: single (assume onnx input models)
IS = Star(single(lb), single(ub)); % default: single (assume onnx input models)

end

11 changes: 6 additions & 5 deletions code/nnv/examples/NN/FairNNV/adult_verifiy.m
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
nR = 50; % ---> just chosen arbitrarily

% ADJUST epsilon values here
epsilon = [0.0,0.001,0.01];
epsilon = [0.001,0.01];


% Set up results
Expand All @@ -90,7 +90,7 @@
start(verificationTimer); % Start the timer

% Iterate through observations
for i=1:numObs
for i=38
idx = rand_indices(i);
[IS, xRand] = perturbationIF(X_test_loaded(:, idx), epsilon(e), nR, min_values, max_values);

Expand All @@ -109,6 +109,7 @@
time(i,e) = toc(t);
met(i,e) = "counterexample";
skipTryCatch = true; % Set the flag to skip try-catch block
disp('Counter example found');
continue;
end
end
Expand Down Expand Up @@ -198,7 +199,7 @@
% Calculate disturbed lower and upper bounds considering min and max values
lb = max(x - disturbance, min_values);
ub = min(x + disturbance, max_values);
IS = ImageStar(single(lb), single(ub)); % default: single (assume onnx input models)
IS = Star(single(lb), single(ub)); % default: single (assume onnx input models)

% Create random samples from initial set
% Adjusted reshaping according to specific needs
Expand All @@ -208,7 +209,7 @@
xRand = xB.sample(nR);
xRand = reshape(xRand,[13,nR]);
xRand(:,nR+1) = x; % add original image
xRand(:,nR+2) = IS.im_lb; % add lower bound image
xRand(:,nR+3) = IS.im_ub; % add upper bound image
xRand(:,nR+2) = xB.lb; % add lower bound image
xRand(:,nR+3) = xB.ub; % add upper bound image
end

38 changes: 38 additions & 0 deletions code/nnv/examples/NN/FairNNV/getCounterRegion.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
function counterExamples = getCounterRegion(inputSet, unsafeRegion, reachSet)
% counterExamples = getCounterRegion(inputSet, unsafeRegion, reachSet)
% NOTE: This is only to be used with exact-star method
% unsafeRegion = HalfSpace (unsafe/undesired region)
% inputSet = ImageStar/Star
% reachSet = Star
%
% check the "safety" of the reachSet
% Then, generate counterexamples

% Initialize variables
counterExamples = [];

% Get halfspace variables
G = unsafeRegion.G;
g = unsafeRegion.g;

% Check for valid inputs
if ~isa(inputSet, "Star")
error("Must be a Star");
end
if ~isa(reachSet, "Star")
error("Must be Star or ImageStar");
end

% Begin counterexample computation
n = length(reachSet); % number of stars in the output set
V = inputSet.V;
for i=1:n
% Check for safety, if unsafe, add to counter
if ~isempty(reachSet(i).intersectHalfSpace(G, g))
counterExamples = [counterExamples Star(V, reachSet(i).C, reachSet(i).d,...
reachSet(i).predicate_lb, reachSet(i).predicate_ub)];
end
end

end

0 comments on commit 5405b1f

Please sign in to comment.