Skip to content

Commit

Permalink
Merge pull request #227 from mldiego/master
Browse files Browse the repository at this point in the history
vnncomp2024
  • Loading branch information
mldiego authored Jul 15, 2024
2 parents a2c79f6 + 95fea61 commit 50da012
Show file tree
Hide file tree
Showing 450 changed files with 10,691 additions and 1,103 deletions.
6 changes: 3 additions & 3 deletions code/nnv/engine/nn/layers/ReluLayer.m
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
dp = in_image.depth;
c = in_image.numChannel;
% transform to star and compute relu reachability
Y = PosLin.reach(in_image.toStar, method, [], relaxFactor); % reachable set computation with ReLU
Y = PosLin.reach(in_image.toStar, method, [], relaxFactor, dis_opt, lp_solver); % reachable set computation with ReLU
n = length(Y);
% transform back to VolumeStar
images(n) = VolumeStar;
Expand All @@ -105,15 +105,15 @@
w = in_image.width;
c = in_image.numChannel;

Y = PosLin.reach(in_image.toStar, method, [], relaxFactor); % reachable set computation with ReLU
Y = PosLin.reach(in_image.toStar, method, [], relaxFactor, dis_opt, lp_solver); % reachable set computation with ReLU
n = length(Y);
images(n) = ImageStar;
% transform back to ImageStar
for i=1:n
images(i) = Y(i).toImageStar(h,w,c);
end
else % star
images = PosLin.reach(in_image, method, [], relaxFactor); % reachable set computation with ReLU
images = PosLin.reach(in_image, method, [], relaxFactor, dis_opt, lp_solver); % reachable set computation with ReLU
end

end
Expand Down
32 changes: 8 additions & 24 deletions code/nnv/engine/nn/layers/ReshapeLayer.m
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,13 @@
methods

% evaluate
function reshape_xx = evaluate(obj,image)
function reshaped = evaluate(obj,image)
%@image: an multi-channels image
%@flatten_im: flatten image

try
ipDim = prod(size(image));
idx = find(obj.targetDim < 0);
obj.targetDim(idx) = ipDim/prod(obj.targetDim(1:end ~= idx));
catch
% do nothing
end
reshape_x = reshape(image, flip(obj.targetDim));
for i = 1:size(reshape_x,3)
reshape_xx(:,:,i) = reshape_x(:,:,i)';
end
idx = find(obj.targetDim < 0);
obj.targetDim(idx) = 1;
reshaped = reshape(image, obj.targetDim);

end
end
Expand All @@ -106,18 +98,10 @@
function image = reach_single_input(obj, in_image)
% @in_image: input imagestar
% @image: output set

% TODO: implement this function, just need to modify the
% dimensions of an ImageStar or convert a Star to an ImageStar
% Should also support ImageZono and Zono
%error("TODO, Working on adding support for this layer.")
try
ipDim = numel(size(in_image));
idx = find(obj.targetDim < 0);
obj.targetDim(idx) = ipDim/prod(obj.targetDim(1:end ~= idx));
catch
% do nothing
end

idx = find(obj.targetDim < 0);
obj.targetDim(idx) = 1;

image = in_image.reshapeImagestar(obj.targetDim);

end
Expand Down
4 changes: 4 additions & 0 deletions code/nnv/engine/nncs/NonLinearODE.m
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,10 @@ function set_output_mat(obj, output_mat)

I = init_set.getZono;
U = input_set.getZono;
if isempty(U)
U = Star(zeros(input_set.dim,1), zeros(input_set.dim,1));
U = U.getZono;
end

if ~isempty(varargin)
if string(varargin{1}) == "poly" || string(varargin{1}) == "lin" || string(varargin{1}) == "lin-adaptive" || string(varargin{1}) == "poly-adaptive"
Expand Down
36 changes: 33 additions & 3 deletions code/nnv/engine/utils/lpsolver.m
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,50 @@

dataType = class(f);

if strcmp(dataType, "gpuArray") % ensure it is not a gpuArray
if strcmp(dataType, "gpuArray") || isa(A, "gpuArray") % ensure it is not a gpuArray
f = gather(f); A = gather(A); b = gather(b); lb = gather(lb);
Aeq = gather(Aeq); Beq = gather(Beq); ub = gather(ub);
end

dataType = class(f);

if strcmp(dataType, "single") % ensure variables are all of type double
if strcmp(dataType, "single") || isa(A, "single") % ensure variables are all of type double
f = double(f); A = double(A); b = double(b); lb = double(lb);
Aeq = double(Aeq); Beq = double(Beq); ub = double(ub);
end

if strcmp(lp_solver, 'gurobi') % no backup solver, should be better than the others
% Create gurobi model
model.obj = f; % objective function
model.A = [sparse(A); sparse(Aeq)]; % A must be sparse
model.sense = [repmat('<',size(A,1),1); repmat('=',size(Aeq,1),1)];
model.rhs = full([b(:); Beq(:)]); % rhs must be dense
if ~isempty(lb)
model.lb = lb;
else
model.lb = -inf(size(model.A,2),1); % default lb for MATLAB is -inf
end
if ~isempty(ub)
model.ub = ub;
end
% Define solver parameters
params = struct; % for now, leave default options/params
params.OutputFlag = 0; % no display
result = gurobi(model, params);
fval = result.objval; % get fval value from results
% get exitflag and match those of linprog for easier parsing
if strcmp(result.status,'OPTIMAL')
exitflag = "l1"; % converged to a solution
elseif strcmp(result.status,'UNBOUNDED')
exitflag = "l-5"; % problem is unbounded
elseif strcmp(result.status,'ITERATION_LIMIT')
exitflag = "l-2"; % maximum number of iterations reached
else
exitflag = "l-2"; % no feasible point found
end

% Solve using linprog (glpk as backup)
if strcmp(lp_solver, 'linprog')
elseif strcmp(lp_solver, 'linprog')
options = optimoptions(@linprog, 'Display','none');
options.OptimalityTolerance = 1e-10; % set tolerance
% first try solving using linprog
Expand Down
214 changes: 214 additions & 0 deletions code/nnv/examples/NN/Fair/adult_exact_verify.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
%% Fairness Verification of Adult Classification Model (NN)
% Comparison for the models used in Fairify

% Suppress warnings
warning('off', 'nnet_cnn_onnx:onnx:WarnAPIDeprecation');
warning('off', 'nnet_cnn_onnx:onnx:FillingInClassNames');

%% Load data into NNV
warning('on', 'verbose')

%% Setup
clear; clc;
modelDir = './adult_onnx'; % Directory containing ONNX models
onnxFiles = dir(fullfile(modelDir, '*.onnx')); % List all .onnx files

load("adult_fairify2_data.mat", 'X', 'y'); % Load data once


%% Loop through each model
for k = 1:length(2)
% onnx_model_path = fullfile(onnxFiles(k).folder, onnxFiles(k).name);
onnx_model_path = fullfile("adult_my_models2/model_0.onnx");
% onnx_model_path = fullfile("adult_onnx/AC-1.onnx");

% Load the ONNX file as DAGNetwork
netONNX = importONNXNetwork(onnx_model_path, 'OutputLayerType', 'classification', 'InputDataFormats', {'BC'});

% analyzeNetwork(netONNX)

% Convert the DAGNetwork to NNV format
net = matlab2nnv(netONNX);

% Jimmy Rigged Fix: manually edit ouput size
net.OutputSize = 2;

% disp(net)

X_test_loaded = permute(X, [2, 1]);
y_test_loaded = y+1; % update labels

% Normalize features in X_test_loaded
min_values = min(X_test_loaded, [], 2);
max_values = max(X_test_loaded, [], 2);

% Ensure no division by zero for constant features
variableFeatures = max_values - min_values > 0;
min_values(~variableFeatures) = 0; % Avoids changing constant features
max_values(~variableFeatures) = 1; % Avoids division by zero

% Normalizing X_test_loaded
X_test_loaded = (X_test_loaded - min_values) ./ (max_values - min_values);

% % Print normalized values for a few samples
% disp('First few normalized inputs in MATLAB:');
% disp(X_test_loaded(:, 1:5));
%
% % Print model outputs for a few samples
% disp('First few model outputs in MATLAB:');
% for i = 1:5
% im = X_test_loaded(:, i);
% predictedLabels = net.evaluate(im);
% disp(predictedLabels);
% end

% Count total observations
total_obs = size(X_test_loaded, 2);
% disp(['There are total ', num2str(total_obs), ' observations']);

% %
% % Test accuracy --> verify matches with python
% %
% total_corr = 0;
% for i=1:total_obs
% im = X_test_loaded(:, i);
% predictedLabels = net.evaluate(im);
% [~, Pred] = min(predictedLabels);
% disp(Pred)
% TrueLabel = y_test_loaded(i);
% disp(TrueLabel)
% if Pred == TrueLabel
% total_corr = total_corr + 1;
% end
% end
% disp(['Test Accuracy: ', num2str(total_corr/total_obs)]);

% Number of observations we want to test
numObs = 100;

%% Verification

% to save results (robustness and time)
results = zeros(numObs,2);

% First, we define the reachability options
reachOptions = struct; % initialize
reachOptions.reachMethod = 'exact-star';
reachOptions.relaxFactor = 0.5;

nR = 50; % ---> just chosen arbitrarily

% ADJUST epsilon value here
% epsilon = [0.01];
epsilon = [0.0,0.001,0.01];
% epsilon = [0.00001];

%
% Set up results
%
nE = 3; %% will need to update later
res = zeros(numObs,nE); % robust result
time = zeros(numObs,nE); % computation time
met = repmat("exact", [numObs, nE]); % method used to compute result


% Randomly select observations
rng(500); % Set a seed for reproducibility
rand_indices = randsample(total_obs, numObs);

for e=1:length(epsilon)
% Reset the timeout flag
assignin('base', 'timeoutOccurred', false);

% Create and configure the timer
verificationTimer = timer;
verificationTimer.StartDelay = 600; % Set timer for 10 minutes
verificationTimer.TimerFcn = @(myTimerObj, thisEvent) ...
assignin('base', 'timeoutOccurred', true);
start(verificationTimer); % Start the timer

ce_count = 0;
exact_count = 0;
ap_count = 0;

for i=1:numObs
idx = rand_indices(i);
IS = perturbation(X_test_loaded(:, idx), epsilon(e), min_values, max_values);


t = tic; % Start timing the verification for each sample

temp = net.verify_robustness(IS, reachOptions, y_test_loaded(idx));
% disp(string(i)+" Exact: "+string(temp))
met(i,e) = 'exact';
res(i,e) = temp; % robust result
% end

time(i,e) = toc(t); % store computation time

% Check for timeout flag
if evalin('base', 'timeoutOccurred')
disp(['Timeout reached for epsilon = ', num2str(epsilon(e)), ': stopping verification for this epsilon.']);
res(i+1:end,e) = 2; % Mark remaining as unknown
break; % Exit the inner loop after timeout
end
end

% Summary results, stopping, and deleting the timer should be outside the inner loop
stop(verificationTimer);
delete(verificationTimer);

% Get summary results
N = numObs;
rob = sum(res(:,e)==1);
not_rob = sum(res(:,e) == 0);
unk = sum(res(:,e) == 2);
totalTime = sum(time(:,e));
avgTime = totalTime/N;

% Print results to screen
% fprintf('Model: %s\n', onnxFiles(k).name);
disp("======= ROBUSTNESS RESULTS e: "+string(epsilon(e))+" ==========")
disp(" ");
disp("Number of fair samples = "+string(rob)+ ", equivalent to " + string(100*rob/N) + "% of the samples.");
disp("Number of non-fair samples = " +string(not_rob)+ ", equivalent to " + string(100*not_rob/N) + "% of the samples.")
disp("Number of unknown samples = "+string(unk)+ ", equivalent to " + string(100*unk/N) + "% of the samples.");
disp(" ");
disp("It took a total of "+string(totalTime) + " seconds to compute the verification results, an average of "+string(avgTime)+" seconds per sample");
end
end


%% Helper Function
% Adjusted for fairness check -> only apply perturbation to desired feature.
function IS = perturbation(x, epsilon, min_values, max_values)
% Applies perturbations on selected features of input sample x
% Return an ImageStar (IS) and random images from initial set
SampleSize = size(x);

disturbance = zeros(SampleSize, "like", x);
sensitive_rows = [9];
nonsensitive_rows = [1,10,11,12];

% Flip the sensitive attribute
if x(sensitive_rows) == 1
x(sensitive_rows) = 0;
else
x(sensitive_rows) = 1;
end

% Apply epsilon perturbation to non-sensitive numerical features
for i = 1:length(nonsensitive_rows)
if nonsensitive_rows(i) <= size(x, 1)
disturbance(nonsensitive_rows(i), :) = epsilon;
else
error('The input data does not have enough rows.');
end
end

% Calculate disturbed lower and upper bounds considering min and max values
lb = max(x - disturbance, min_values);
ub = min(x + disturbance, max_values);
IS = ImageStar(single(lb), single(ub)); % default: single (assume onnx input models)
end

Binary file added code/nnv/examples/NN/Fair/adult_fair_data.mat
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_fairify2_data.mat
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_fairify_data.mat
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_model_fc.mat
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-1.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-10.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-11.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-12.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-2.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-3.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-4.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-5.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-6.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-7.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-8.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-9.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-1.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-10.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-11.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-12.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-2.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-3.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-4.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-5.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-6.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-7.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-8.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-9.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/test_data.mat
Binary file not shown.
Loading

0 comments on commit 50da012

Please sign in to comment.