Skip to content

Commit

Permalink
Playing with custom reach options for each benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
mldiego committed Jul 9, 2024
1 parent 3a89629 commit fe8f98e
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 70 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
%% Run as many benchmarks as possible from 2024

% vnncomp_path = "C:\Users\diego\Documents\Research\vnncomp2023_benchmarks\benchmarks\";
vnncomp_path = "/home/manzand/Documents/MATLAB/vnncomp2024_benchmarks/benchmarks/";
% vnncomp_path = "/home/manzand/Documents/MATLAB/vnncomp2024_benchmarks/benchmarks/";
vnncomp_path = "/home/dieman95/Documents/MATLAB/vnncomp2024_benchmarks/benchmarks/";

benchmarks = dir(vnncomp_path);

notSupported = {'test'; 'traffic_signs_recognition'; 'cctsdb_yolo'}; % skip for now
notSupported = {'test'; 'traffic_signs_recognition'; 'cctsdb_yolo', 'linearizenn'}; % skip for now, not even for falsification

regularTrack = {'acasxu'; 'nn4sys'; 'cora'; 'linearizenn'; 'safenlp'; 'dist-shift'; 'cifar100';...
'tinyimagenet'; 'cgan'; 'metaroom'; 'tllverifybench'; 'collins_rul';
Expand All @@ -15,8 +16,8 @@
'traffic_signs_recognition'; 'vggnet'; 'vit';
}; % we don't really care much about this one

for i=3:length(benchmarks)
% for i=3 % only do acasxu
% for i=3:length(benchmarks)
for i=3 % only do acasxu

name_noyear = split(benchmarks(i).name, "_");
name_noyear = strjoin(name_noyear(1:end-1), '_');
Expand Down
184 changes: 118 additions & 66 deletions code/nnv/examples/Submission/VNN_COMP2024/run_vnncomp2024_instance.m
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

% Load networks

[net, nnvnet, needReshape, reachOptions] = load_vnncomp_network(category, onnx, vnnlib);
[net, nnvnet, needReshape, reachOptionsList] = load_vnncomp_network(category, onnx, vnnlib);

inputSize = net.Layers(1, 1).InputSize;

Expand Down Expand Up @@ -60,45 +60,6 @@

%% 3) UNSAT?

% Define reachability options
% Let's try to choose this better from the get-go

% Option 1
% reachOptions_relax100 = struct;
% reachOptions_relax100.reachMethod = 'relax-star-range';
% reachOptions_relax100.relaxFactor = 1;
%
% % Option 2
% reachOptions_relax50 = struct;
% reachOptions_relax50.reachMethod = 'relax-star-range';
% reachOptions_relax50.relaxFactor = 0.5;
%
% % Option 3
% reachOptions_exact = struct;
% reachOptions_exact.reachMethod = 'exact-star';
% reachOptions_exact.reachOption = 'parallel';
% reachOptions_exact.numCores = feature('numcores');
%
% % Option 4
% reachOptions_approx.reachMethod = 'approx-star';
%
% % Choosing reachOptions (based on size, but not sure how to decice really...)
% if prod(inputSize) > 3000 % [32 32 3]
% reachOptions = {reachOptions_relax100; reachOptions_relax50; reachOptions_approx};
% else
% reachOptions = {reachOptions_relax50; reachOptions_approx; reachOptions_exact};
% end

% reachOptions = struct;
% reachOptions.lp_solver = "linprog"; % glpk is the worst, gurobi works better for some of the larger benchmarks, linprog faster for simple LPs
% reachOptions.reachMethod = 'approx-star';
% reachOptions.device = "gpu";
% reachOptions.reachMethod = 'exact-star';
% reachOptions.device = 'cpu';
% numCores = feature('numcores');
% reachOptions.numCores = numCores; % physical cores


% Check if property was violated earlier
if iscell(counterEx)
status = 0;
Expand All @@ -111,36 +72,64 @@
% Choose how to verify based on vnnlib file
if ~isa(lb, "cell") && length(prop) == 1 % one input, one output

IS = create_input_set(lb, ub, inputSize, needReshape);
while ~isempty(reachOptionsList)

reachOptions = reachOptionsList{1};

IS = create_input_set(lb, ub, inputSize, needReshape);

% Compute reachability
ySet = nnvnet.reach(IS, reachOptions);
% Compute reachability
ySet = nnvnet.reach(IS, reachOptions);

% Verify property
status = verify_specification(ySet, prop);
% Verify property
status = verify_specification(ySet, prop);

if status == 1 % verified, then stop
break
else
reachOptionsList = reachOptionsList(2:end);
end

end

elseif isa(lb, "cell") && length(lb) == length(prop) % multiple inputs, multiple outputs

local_status = ones(length(lb),1); % track status for each specification in the vnnlib
local_status = 2*ones(length(lb),1); % track status for each specification in the vnnlib

parfor spc = 1:length(lb) % We can compute these in parallel for faster computation

lb_spc = lb{spc};
ub_spc = ub{spc};

reachOptPar = reachOptionsList;

IS = create_input_set(lb_spc, ub_spc, inputSize, needReshape);

% Compute reachability
ySet = nnvnet.reach(IS, reachOptions);
while ~isempty(reachOptPar)

reachOptions = reachOptPar{1};

IS = create_input_set(lb_spc, ub_spc, inputSize, needReshape);

% Compute reachability
ySet = nnvnet.reach(IS, reachOptions);

% Verify property
if isempty(ySet.C)
dd = ySet.V; DD = ySet.V;
ySet = Star(dd,DD);
end

% Verify property
if isempty(ySet.C)
dd = ySet.V; DD = ySet.V;
ySet = Star(dd,DD);
% Add verification status
tempStatus = verify_specification(ySet, prop(spc));

if tempStatus ~= 2 % verified, then stop (or falsified)
break
else
reachOptPar = reachOptPar(2:end);
end

end

% Add verification status
local_status(spc) = verify_specification(ySet, prop(spc));
local_status(spc) = tempStatus;

end

Expand All @@ -153,20 +142,36 @@

elseif isa(lb, "cell") && length(prop) == 1 % one specification, multiple input definitions

local_status = ones(length(lb),1); % track status for each specification in the vnnlib
local_status = 2*ones(length(lb),1); % track status for each specification in the vnnlib, initialize as unknown

parfor spc = 1:length(lb) % We can compute these in parallel for faster computation

reachOptPar = reachOptionsList;

lb_spc = lb{spc};
ub_spc = ub{spc};

IS = create_input_set(lb_spc, ub_spc, inputSize, needReshape);
while ~isempty(reachOptPar)

reachOptions = reachOptPar{1};

IS = create_input_set(lb_spc, ub_spc, inputSize, needReshape);

% Compute reachability
ySet = nnvnet.reach(IS, reachOptions);

% Compute reachability
ySet = nnvnet.reach(IS, reachOptions);
% Add verification status
tempStatus = verify_specification(ySet, prop(spc));

% Verify property
local_status(spc) = verify_specification(ySet, prop);
if tempStatus ~= 2 % verified, then stop (or falsified)
break
else
reachOptPar = reachOptPar(2:end);
end

local_status(spc) = tempStatus;

end

end

Expand Down Expand Up @@ -250,7 +255,7 @@

end

function [net,nnvnet,needReshape,reachOptions] = load_vnncomp_network(category, onnx, vnnlib)
function [net,nnvnet,needReshape,reachOptionsList] = load_vnncomp_network(category, onnx, vnnlib)
% load participating vnncomp 2024 benchmark NNs
%
% Regular Track Benchmarks
Expand Down Expand Up @@ -286,14 +291,17 @@
%

needReshape = 0; % default is to use MATLAB reshape, otherwise use the python reshape
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
% reachOptions = struct;
% reachOptions.reachMethod = 'approx-star'; % default parameters
numCores = feature('numcores'); % in case we select exact method

if contains(category, 'collins_rul')
net = importNetworkFromONNX(onnx);
nnvnet = matlab2nnv(net);
needReshape = 2;
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;

elseif contains(category, "nn4sys")
% nn4sys: onnx to matlab:
Expand All @@ -304,11 +312,15 @@
else
error("We don't have those");
end
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;

elseif contains(category, "ml4acopf")
% ml4acopf: onnx to matlab
net = importNetworkFromONNX(onnx, "InputDataFormats", "BC");
nnvnet = "";
reachOptionsList = {};

elseif contains(category, "dist_shift")
% dist_shift: onnx to matlab, , matlab to nnv?
Expand All @@ -318,6 +330,9 @@
catch
nnvnet = "";
end
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;

elseif contains(category, "cgan")
% cgan: onnx to nnv
Expand All @@ -327,23 +342,33 @@
else
error("We don't have those");
end
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;

elseif contains(category, "vggnet16")
% vgg16: onnx to matlab
net = importNetworkFromONNX(onnx); % flattenlayer
nnvnet = "";
needReshape = 1;
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;

elseif contains(category, "tllverify")
% tllverify: onnx to nnv
net = importNetworkFromONNX(onnx,"InputDataFormats", "BC", 'OutputDataFormats',"BC");
nnvnet = matlab2nnv(net);
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;

elseif contains(category, "vit")
% vit: onnx to matlab
net = importNetworkFromONNX(onnx, "InputDataFormats", "BCSS", 'OutputDataFormats',"BC");
nnvnet = "";
needReshape= 1;
reachOptionsList = {};

elseif contains(category, "cctsdb_yolo")
% cctsdb_yolo: onnx to matlab
Expand All @@ -357,32 +382,50 @@
% collins_yolo: onnx to matlab
net = importNetworkFromONNX(onnx);
nnvnet = "";
reachOptionsList = {};

elseif contains(category, "yolo")
% yolo: onnx to nnv
net = importNetworkFromONNX(onnx); % padlayer
nnvnet = matlab2nnv(net);
% needReshape = ?
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;

elseif contains(category, "acasxu")
% acasxu: onnx to nnv
net = importNetworkFromONNX(onnx, "InputDataFormats","BCSS");
nnvnet = matlab2nnv(net);
if contains(vnnlib, "prop_1.") || contains(vnnlib, "prop_2.")
if ~contains(vnnlib, "prop_3.") && ~contains(vnnlib, "prop_4.")
reachOptions.reachMethod = 'exact-star';
reachOptions.numCores = numCores;
reachOptionsList{1} = reachOptions;
else
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;
reachOptions.reachMethod = 'exact-star';
reachOptions.numCores = numCores;
reachOptionsList{2} = reachOptions;
end

elseif contains(category, "cifar100")
% cifar100: onnx to nnv
net = importNetworkFromONNX(onnx, "InputDataFormats","BCSS", "OutputDataFormats","BC");
nnvnet = matlab2nnv(net);
needReshape = 1;
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;

elseif contains(category, "tinyimagenet")
% tinyimagenet: onnx to nnv
net = importNetworkFromONNX(onnx, "InputDataFormats","BCSS", "OutputDataFormats","BC");
nnvnet = matlab2nnv(net);
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;
% needReshpae = ?

% elseif contains(category, "linearizenn")% we do not support the current version
Expand All @@ -396,12 +439,18 @@
net = importNetworkFromONNX(onnx, "InputDataFormats","BC", "OutputDataFormats","BC");
nnvnet = matlab2nnv(net);
% needReshape = ?
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;

elseif contains(category, "cora")
% cora benchmark: onnx 2 nnv
net = importNetworkFromONNX(onnx, "InputDataFormats","BC", "OutputDataFormats","BC");
nnvnet = matlab2nnv(net);
% needReshape = ?
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;

elseif contains(category, "lsnc")
% lyapunov benchmark: onnx to nnv (barely, some IR and opset version differences)
Expand All @@ -420,6 +469,9 @@
net = importNetworkFromONNX(onnx, "InputDataFormats","BCSS", "OutputDataFormats","BC");
nnvnet = matlab2nnv(net);
needReshape = 2;
reachOptions = struct;
reachOptions.reachMethod = 'approx-star'; % default parameters
reachOptionsList{1} = reachOptions;

else % all other benchmarks
% traffic: onnx to matlab: opset15 issues
Expand Down

0 comments on commit fe8f98e

Please sign in to comment.