Skip to content

Commit

Permalink
Fix precision error on CAV23, add tests and other small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mldiego committed Jan 12, 2024
1 parent 044b59a commit d75dfde
Show file tree
Hide file tree
Showing 11 changed files with 219 additions and 61 deletions.
23 changes: 22 additions & 1 deletion code/nnv/engine/nn/NN.m
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,15 @@
if strcmp(obj.dis_opt, 'display')
fprintf('\nPerform reachability analysis for the network %s \n', obj.Name);
end

% ensure NN parameters and input set share same precision
inputSet = obj.consistentPrecision(inputSet); % change only input, this can be changed in the future

% Perform reachability based on connections or assume no skip/sparse connections
if isempty(obj.Connections)
outputSet = obj.reach_noConns(inputSet);
else
outputSet = obj.reach_withConns(inputSet);

end

end
Expand Down Expand Up @@ -945,6 +947,25 @@
reachOptions.numCores = 1;
end
end

% Ensure input and parameter precision is the same
function inputSet = consistentPrecision(obj, inputSet)
% (assume parameters have same precision across layers)
% approach: change input precision based on network parameters
inputPrecision = class(inputSet.V);
netPrecision = 'double'; % default
for i=1:length(obj.Layers)
if isa(obj.Layers{i}, "FullyConnectedLayer") || isa(obj.Layers{i}, "Conv2DLayer")
netPrecision = class(obj.Layers{i}.Weights);
break;
end
end
if ~strcmp(inputPrecision, netPrecision)
% input and parameter precision does not match
warning("Changing input set precision to "+string(netPrecision));
inputSet = inputSet.changeVarsPrecision(netPrecision);
end
end

% Create input set based on input vector and bounds
function R = create_input_set(obj, x_in, disturbance, lb_allowable, ub_allowable) % assume tol is applied to every vale of the input
Expand Down
24 changes: 24 additions & 0 deletions code/nnv/engine/set/ImageStar.m
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,30 @@

end

% change variable precision
function S = changeVarsPrecision(obj, precision)
S = obj;
if strcmp(precision, 'single')
S.V = single(S.V);
S.C = single(S.C);
S.d = single(S.d);
S.pred_lb = single(S.pred_lb);
S.pred_ub = single(S.pred_lb);
S.im_lb = single(S.im_lb);
S.im_ub = single(S.im_ub);
elseif strcmp(precision, 'double')
S.V = double(S.V);
S.C = double(S.C);
S.d = double(S.d);
S.pred_lb = double(S.pred_lb);
S.pred_ub = double(S.pred_lb);
S.im_lb = double(S.im_lb);
S.im_ub = double(S.im_ub);
else
error("Only single or double precision arrays allowed. GpuArray/dlarray are coming.")
end
end

end


Expand Down
17 changes: 16 additions & 1 deletion code/nnv/engine/set/ImageZono.m
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,22 @@
S = obj.toImageStar;
S = S.toStar;
end


% change variable precision
function S = changeVarsPrecision(obj, precision)
S = obj;
if strcmp(precision, 'single')
S.V = single(S.V);
S.lb_image = single(S.lb_image);
S.ub_image = single(S.ub_image);
elseif strcmp(precision, 'double')
S.V = double(S.V);
S.lb_image = double(S.lb_image);
S.ub_image = double(S.ub_image);
else
error("Only single or double precision arrays allowed. GpuArray/dlarray are coming.")
end
end
end


Expand Down
34 changes: 32 additions & 2 deletions code/nnv/engine/set/Star.m
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,30 @@

end

% change variable precision
function S = changeVarsPrecision(obj, precision)
S = obj;
if strcmp(precision, 'single')
S.V = single(S.V);
S.C = single(S.C);
S.d = single(S.d);
S.predicate_lb = single(S.predicate_lb);
S.predicate_ub = single(S.predicate_lb);
S.state_lb = single(S.state_lb);
S.state_ub = single(S.state_ub);
elseif strcmp(precision, 'double')
S.V = double(S.V);
S.C = double(S.C);
S.d = double(S.d);
S.predicate_lb = double(S.predicate_lb);
S.predicate_ub = double(S.predicate_lb);
S.state_lb = double(S.state_lb);
S.state_ub = double(S.state_ub);
else
error("Only single or double precision arrays allowed. GpuArray/dlarray are coming.")
end
end

end


Expand Down Expand Up @@ -1526,8 +1550,14 @@ function plot(varargin)
end
end
else
P = obj.toPolyhedron;
P.plot('color', color);
if isa(obj.V, 'single') || isa(obj.C,'single') || isa(obj.d, 'single')
S = obj.changeVarsPrecision('double');
P = S.toPolyhedron;
P.plot('color', color);
else
P = obj.toPolyhedron;
P.plot('color', color);
end
end

else
Expand Down
25 changes: 24 additions & 1 deletion code/nnv/engine/set/VolumeStar.m
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,30 @@

end

% TODO: add a 3D projection to ImageStar
% change variable precision
function S = changeVarsPrecision(obj, precision)
S = obj;
if strcmp(precision, 'single')
S.V = single(S.V);
S.C = single(S.C);
S.d = single(S.d);
S.pred_lb = single(S.pred_lb);
S.pred_ub = single(S.pred_lb);
S.vol_lb = single(S.vol_lb);
S.vol_ub = single(S.vol_ub);
elseif strcmp(precision, 'double')
S.V = double(S.V);
S.C = double(S.C);
S.d = double(S.d);
S.pred_lb = double(S.pred_lb);
S.pred_ub = double(S.pred_lb);
S.vol_lb = double(S.vol_lb);
S.vol_ub = double(S.vol_ub);
else
error("Only single or double precision arrays allowed. GpuArray/dlarray are coming.")
end
end

end


Expand Down
14 changes: 14 additions & 0 deletions code/nnv/engine/set/Zono.m
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,20 @@
imageStar = im1.toImageStar(height, width, numChannels);

end

% change variable precision
function S = changeVarsPrecision(obj, precision)
S = obj;
if strcmp(precision, 'single')
S.V = single(S.V);
S.c = single(S.c);
elseif strcmp(precision, 'double')
S.V = double(S.V);
S.c = double(S.c);
else
error("Only single or double precision arrays allowed. GpuArray/dlarray are coming.")
end
end

end

Expand Down
14 changes: 7 additions & 7 deletions code/nnv/engine/utils/load_vnnlib.m
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
dim = dim + 1; % only have seen inputs defined as vectors, so this should work
% the more general approach would require some extra work, but should be easy as well
elseif contains(tline, "declare-const") && contains(tline, "Y_")
lb_template = zeros(dim,1);
ub_template = zeros(dim,1);
lb_template = zeros(dim,1,'single');
ub_template = zeros(dim,1,'single');
dim = 0; % reset dimension counter
phase = "DeclareOutput";
continue; % redo this line in correct phase
Expand All @@ -51,7 +51,7 @@
dim = 1; % reset dimension counter
phase = "DefineInput";
% Initialize variables
lb_input = lb_template;
lb_input = lb_template;
ub_input = ub_template;
continue; % redo this line in correct phase
end
Expand Down Expand Up @@ -240,7 +240,7 @@
H(idx2) = -1;
else
var2 = split(var2, ')');
g = str2double(var2{1});
g = single(str2double(var2{1}));
end
else
H(idx1) = -1;
Expand All @@ -252,7 +252,7 @@
H(idx2) = 1;
else
var2 = split(var2, ')');
g = -str2double(var2{1});
g = -single(str2double(var2{1}));
end
end
% Add constraint (H, g) to assertion variable (ast)
Expand Down Expand Up @@ -346,7 +346,7 @@
dim = split(t{2},'_');
dim = str2double(dim{2})+1;
value = split(t{3},')');
value = str2double(value{1});
value = single(str2double(value{1}));
if contains(t{1},">=")
lb_input(dim) = value;
else
Expand Down Expand Up @@ -401,7 +401,7 @@
dim = split(t{2},'_');
dim = str2double(dim{2})+1;
value = split(t{3},')');
value = str2double(value{1});
value = single(str2double(value{1}));
if contains(t{1},">=") || contains(t{1}, ">")
Hvec(dim) = -1;
gval = -value;
Expand Down
5 changes: 3 additions & 2 deletions code/nnv/examples/Tutorial/other/set_representations.m
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,12 @@

S1 = S.affineMap([1 0 0 0 0; 0 1 0 0 0], []); % 2D, dims 1 and 2
figure;
Star.plot(S1);
Star.plotBoxes_2D(S1,1,2,'r');
hold on;

S3 = S.affineMap([1 0 0 0 0; 0 1 0 0 0; 0 0 1 0 0], []); % 3D, dims 1,2 and 3
figure;
Star.plot(S3);
Star.plotBoxes_3D(S3,1,2,3,'r');

% 2) Plot an overapproximation (box around Star) of the set and plot

Expand Down
Loading

0 comments on commit d75dfde

Please sign in to comment.