Skip to content

Commit

Permalink
Merge pull request #50 from atumlin/main
Browse files Browse the repository at this point in the history
fairness examples
  • Loading branch information
mldiego authored Jul 2, 2024
2 parents bd92ee7 + c54e85b commit 1267379
Show file tree
Hide file tree
Showing 58 changed files with 1,373 additions and 1 deletion.
214 changes: 214 additions & 0 deletions code/nnv/examples/NN/Fair/adult_exact_verify.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
%% Fairness Verification of Adult Classification Model (NN)
% Comparison for the models used in Fairify

% Suppress warnings
warning('off', 'nnet_cnn_onnx:onnx:WarnAPIDeprecation');
warning('off', 'nnet_cnn_onnx:onnx:FillingInClassNames');

%% Load data into NNV
warning('on', 'verbose')

%% Setup
clear; clc;
modelDir = './adult_onnx'; % Directory containing ONNX models
onnxFiles = dir(fullfile(modelDir, '*.onnx')); % List all .onnx files

load("adult_fairify2_data.mat", 'X', 'y'); % Load data once


%% Loop through each model
for k = 1:length(2)
% onnx_model_path = fullfile(onnxFiles(k).folder, onnxFiles(k).name);
onnx_model_path = fullfile("adult_my_models2/model_0.onnx");
% onnx_model_path = fullfile("adult_onnx/AC-1.onnx");

% Load the ONNX file as DAGNetwork
netONNX = importONNXNetwork(onnx_model_path, 'OutputLayerType', 'classification', 'InputDataFormats', {'BC'});

% analyzeNetwork(netONNX)

% Convert the DAGNetwork to NNV format
net = matlab2nnv(netONNX);

% Jimmy Rigged Fix: manually edit ouput size
net.OutputSize = 2;

% disp(net)

X_test_loaded = permute(X, [2, 1]);
y_test_loaded = y+1; % update labels

% Normalize features in X_test_loaded
min_values = min(X_test_loaded, [], 2);
max_values = max(X_test_loaded, [], 2);

% Ensure no division by zero for constant features
variableFeatures = max_values - min_values > 0;
min_values(~variableFeatures) = 0; % Avoids changing constant features
max_values(~variableFeatures) = 1; % Avoids division by zero

% Normalizing X_test_loaded
X_test_loaded = (X_test_loaded - min_values) ./ (max_values - min_values);

% % Print normalized values for a few samples
% disp('First few normalized inputs in MATLAB:');
% disp(X_test_loaded(:, 1:5));
%
% % Print model outputs for a few samples
% disp('First few model outputs in MATLAB:');
% for i = 1:5
% im = X_test_loaded(:, i);
% predictedLabels = net.evaluate(im);
% disp(predictedLabels);
% end

% Count total observations
total_obs = size(X_test_loaded, 2);
% disp(['There are total ', num2str(total_obs), ' observations']);

% %
% % Test accuracy --> verify matches with python
% %
% total_corr = 0;
% for i=1:total_obs
% im = X_test_loaded(:, i);
% predictedLabels = net.evaluate(im);
% [~, Pred] = min(predictedLabels);
% disp(Pred)
% TrueLabel = y_test_loaded(i);
% disp(TrueLabel)
% if Pred == TrueLabel
% total_corr = total_corr + 1;
% end
% end
% disp(['Test Accuracy: ', num2str(total_corr/total_obs)]);

% Number of observations we want to test
numObs = 100;

%% Verification

% to save results (robustness and time)
results = zeros(numObs,2);

% First, we define the reachability options
reachOptions = struct; % initialize
reachOptions.reachMethod = 'exact-star';
reachOptions.relaxFactor = 0.5;

nR = 50; % ---> just chosen arbitrarily

% ADJUST epsilon value here
% epsilon = [0.01];
epsilon = [0.0,0.001,0.01];
% epsilon = [0.00001];

%
% Set up results
%
nE = 3; %% will need to update later
res = zeros(numObs,nE); % robust result
time = zeros(numObs,nE); % computation time
met = repmat("exact", [numObs, nE]); % method used to compute result


% Randomly select observations
rng(500); % Set a seed for reproducibility
rand_indices = randsample(total_obs, numObs);

for e=1:length(epsilon)
% Reset the timeout flag
assignin('base', 'timeoutOccurred', false);

% Create and configure the timer
verificationTimer = timer;
verificationTimer.StartDelay = 600; % Set timer for 10 minutes
verificationTimer.TimerFcn = @(myTimerObj, thisEvent) ...
assignin('base', 'timeoutOccurred', true);
start(verificationTimer); % Start the timer

ce_count = 0;
exact_count = 0;
ap_count = 0;

for i=1:numObs
idx = rand_indices(i);
IS = perturbation(X_test_loaded(:, idx), epsilon(e), min_values, max_values);


t = tic; % Start timing the verification for each sample

temp = net.verify_robustness(IS, reachOptions, y_test_loaded(idx));
% disp(string(i)+" Exact: "+string(temp))
met(i,e) = 'exact';
res(i,e) = temp; % robust result
% end

time(i,e) = toc(t); % store computation time

% Check for timeout flag
if evalin('base', 'timeoutOccurred')
disp(['Timeout reached for epsilon = ', num2str(epsilon(e)), ': stopping verification for this epsilon.']);
res(i+1:end,e) = 2; % Mark remaining as unknown
break; % Exit the inner loop after timeout
end
end

% Summary results, stopping, and deleting the timer should be outside the inner loop
stop(verificationTimer);
delete(verificationTimer);

% Get summary results
N = numObs;
rob = sum(res(:,e)==1);
not_rob = sum(res(:,e) == 0);
unk = sum(res(:,e) == 2);
totalTime = sum(time(:,e));
avgTime = totalTime/N;

% Print results to screen
% fprintf('Model: %s\n', onnxFiles(k).name);
disp("======= ROBUSTNESS RESULTS e: "+string(epsilon(e))+" ==========")
disp(" ");
disp("Number of fair samples = "+string(rob)+ ", equivalent to " + string(100*rob/N) + "% of the samples.");
disp("Number of non-fair samples = " +string(not_rob)+ ", equivalent to " + string(100*not_rob/N) + "% of the samples.")
disp("Number of unknown samples = "+string(unk)+ ", equivalent to " + string(100*unk/N) + "% of the samples.");
disp(" ");
disp("It took a total of "+string(totalTime) + " seconds to compute the verification results, an average of "+string(avgTime)+" seconds per sample");
end
end


%% Helper Function
% Adjusted for fairness check -> only apply perturbation to desired feature.
function IS = perturbation(x, epsilon, min_values, max_values)
% Applies perturbations on selected features of input sample x
% Return an ImageStar (IS) and random images from initial set
SampleSize = size(x);

disturbance = zeros(SampleSize, "like", x);
sensitive_rows = [9];
nonsensitive_rows = [1,10,11,12];

% Flip the sensitive attribute
if x(sensitive_rows) == 1
x(sensitive_rows) = 0;
else
x(sensitive_rows) = 1;
end

% Apply epsilon perturbation to non-sensitive numerical features
for i = 1:length(nonsensitive_rows)
if nonsensitive_rows(i) <= size(x, 1)
disturbance(nonsensitive_rows(i), :) = epsilon;
else
error('The input data does not have enough rows.');
end
end

% Calculate disturbed lower and upper bounds considering min and max values
lb = max(x - disturbance, min_values);
ub = min(x + disturbance, max_values);
IS = ImageStar(single(lb), single(ub)); % default: single (assume onnx input models)
end

Binary file added code/nnv/examples/NN/Fair/adult_fair_data.mat
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_fairify2_data.mat
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_fairify_data.mat
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_model_fc.mat
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-1.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-10.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-11.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-12.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-2.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-3.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-4.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-5.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-6.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-7.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-8.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx/AC-9.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-1.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-10.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-11.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-12.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-2.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-3.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-4.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-5.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-6.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-7.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-8.onnx
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/adult_onnx2/AC-9.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added code/nnv/examples/NN/Fair/test_data.mat
Binary file not shown.
78 changes: 78 additions & 0 deletions code/nnv/examples/NN/Fair/training.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
%% Training of an ADULT classifier (NN)
% Code based on a few examples
% Data Preprocessing: https://github.com/jonathanlxy/MLproject-UCI-Adult-Income-Classification
% Ideas: https://github.com/LebronX/DeepGemini-public/blob/main/src/German/german_fairness_training.py
% Implementattion/Traihttps://github.com/LebronX/DeepGemini-public/blob/main/src/German/german_fairness_training.pyning

t = tic; % track total time for training

%% Read data
Train = csvread('finalset_cleaned_train.csv', 1, 0);
Test = csvread('finalset_cleaned_test.csv', 1, 0);

% For the training dataset
XTrain = Train(:, 1:end-1); % All rows, but exclude the last column
YTrain = Train(:, end); % All rows, only the last column

% For the testing dataset
XTest = Test(:, 1:end-1); % All rows, but exclude the last column
YTest = Test(:, end); % All rows, only the last column

YTrain = categorical(YTrain);
YTest = categorical(YTest);

N = 13; % Number of features after preprocessing
numClasses = 2; % For binary classification

%% Neural Network
layers = [
featureInputLayer(N)

fullyConnectedLayer(50)
reluLayer;

fullyConnectedLayer(100)
reluLayer;

fullyConnectedLayer(50)
reluLayer;

fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer];

% Training options
options = trainingOptions('adam', ...
'InitialLearnRate', 0.01, ...
'MaxEpochs', 20, ...
'LearnRateSchedule', 'piecewise', ...
'LearnRateDropFactor', 0.2, ...
'LearnRateDropPeriod', 5, ...
'MiniBatchSize', 64, ...
'Shuffle', 'every-epoch', ...
'ValidationData', {XTest, YTest}, ...
'ValidationFrequency', 30, ...
'Verbose', true, ...
'Plots', 'training-progress');

% Train network
net = trainNetwork(XTrain,YTrain,layers,options);

% Get Accuracy
YPred = predict(net, XTest);
YPred = round(YPred);

% Convert YPred to categorical for comparison
YPredCategorical = categorical(YPred);

accuracy = sum(YPredCategorical == YTest) / numel(YTest);
disp("Validation accuracy = "+string(accuracy));

% Save model
disp("Saving model...");
save('adult_model_fc.mat', 'net', 'accuracy');

% Save test data for verification
save('test_data.mat', 'XTest', 'YTest');

toc(t);
Loading

0 comments on commit 1267379

Please sign in to comment.