forked from verivital/nnv
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
337 changed files
with
632 additions
and
177 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
function I = remove_voxels(vol, voxels, noise_disturbance) | ||
% noise_disturnamce can be kept fixed here, more interesting on number | ||
% of voxels changed | ||
|
||
% Return a VolumeStar of a brightening attack on a few pixels | ||
|
||
% Initialize vars | ||
ct = 0; % keep track of pixels modified | ||
flag = 0; % determine when to stop modifying pixels | ||
vol = single(vol); | ||
at_vol = vol; | ||
|
||
% Like darkening attack | ||
for i=1:size(vol,1) | ||
for j=1:size(vol,2) | ||
for k=1:size(vol,3) | ||
if vol(i,j,k) < threshold | ||
at_vol(i,j,k) = 255; | ||
ct = ct + 1; | ||
if ct >= voxels | ||
flag = 1; | ||
break; | ||
end | ||
end | ||
end | ||
if flag == 1 | ||
break | ||
end | ||
end | ||
if flag == 1 | ||
break; | ||
end | ||
end | ||
|
||
% Define input set as VolumeStar | ||
dif_vol = -vol + at_vol; | ||
noise = dif_vol; | ||
V(:,:,:,:,1) = vol; % center of set | ||
V(:,:,:,:,2) = noise; % basis vectors | ||
C = [1; -1]; % constraints | ||
d = [1; -1]; % constraints | ||
I = VolumeStar(V, C, d, 1-noise_disturbance, 1); % input set | ||
|
||
|
||
end |
2 changes: 0 additions & 2 deletions
2
code/nnv/examples/Submission/WiP_3d/functions/summarize_results.m
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
|
||
# NAV Benchmark | ||
|
||
## Property: | ||
The control goal is to navigate a robot to a goal region while avoiding an obstacle. | ||
Time horizon: `t = 6s`. Control period: `0.2s`. | ||
|
||
Initial states: | ||
|
||
x1 = [2.9, 3.1] | ||
x2 = [2.9, 3.1] | ||
x3 = [0, 0] | ||
x4 = [0, 0] | ||
|
||
Dynamic system: [dynamics.m](./dynamics.m) | ||
|
||
Goal region ( t=6 ): | ||
|
||
x1 = [-0.5, 0.5] | ||
x2 = [-0.5, 0.5] | ||
x3 = [-Inf, Inf] | ||
x4 = [-Inf, Inf] | ||
|
||
Obstacle ( always ): | ||
|
||
x1 = [1, 2] | ||
x2 = [1, 2] | ||
x3 = [-Inf, Inf] | ||
x4 = [-Inf, Inf] | ||
|
||
## Networks: | ||
|
||
We provide two networks: | ||
- The first network is trained with standard (point-based) reinforcement learning: `nn-nav-point.onnx` | ||
- The second network is trained set-based to improve its verifiable robustness by integrating reachability analysis into the training process: `nn-nav-set.onnx` | ||
|
||
Reference set-based training: https://arxiv.org/abs/2401.14961 | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
function dx = dynamics(x,u) | ||
|
||
dx = [ | ||
x(3)*cos(x(4)); | ||
x(3)*sin(x(4)); | ||
u(1); | ||
u(2) | ||
]; | ||
|
||
end | ||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file added
BIN
+10.4 KB
code/nnv/examples/Submission/WiP_3d/other/NAV/networks/nn-nav-point.onnx
Binary file not shown.
Binary file added
BIN
+10.4 KB
code/nnv/examples/Submission/WiP_3d/other/NAV/networks/nn-nav-set.onnx
Binary file not shown.
135 changes: 135 additions & 0 deletions
135
code/nnv/examples/Submission/WiP_3d/other/NAV/reach_point.m
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
function rT = reach_point() | ||
|
||
%% Reachability analysis of NAV Benchmark | ||
|
||
%% Load Components | ||
|
||
% Load the controller | ||
%netonnx = importNetworkFromONNX('networks/nn-nav-point.onnx', "InputDataFormats", "BC"); | ||
netonnx = importONNXNetwork('networks/nn-nav-point.onnx', "InputDataFormats", "BC"); | ||
|
||
% Load plant | ||
reachStep = 0.02; | ||
controlPeriod = 0.2; | ||
% plant = NonLinearODE(4, 2, @dynamics, reachStep, controlPeriod, eye(4)); | ||
% plant.set_tensorOrder(2); | ||
% plant.set_taylorTerms(3); | ||
% plant.set_zonotopeOrder(100); | ||
% plant.set_intermediateOrder(50); | ||
|
||
%% Reachability analysis | ||
|
||
% Initial set | ||
lb = [2.9; 2.9; 0; 0]; | ||
ub = [3.1; 3.1; 0; 0]; | ||
init_set = Box(lb,ub); | ||
init = init_set.partition([1 2],[50 50]); | ||
|
||
% Reachability options | ||
num_steps = 20; | ||
reachOptions.reachMethod = 'approx-star'; | ||
|
||
N = length(init); | ||
disp("Verifying "+string(N)+" samples...") | ||
|
||
mkdir('tmp'); | ||
parpool("Processes"); % initialize parallel process | ||
|
||
% Execute reachabilty analysis | ||
t = tic; | ||
parfor j = 1:length(init) | ||
% Get NNV network | ||
net = matlab2nnv(netonnx); | ||
% Create plant | ||
plant = NonLinearODE(4, 2, @dynamics, reachStep, controlPeriod, eye(4)); | ||
plant.set_tensorOrder(2); | ||
plant.set_taylorTerms(3); | ||
plant.set_zonotopeOrder(100); | ||
plant.set_intermediateOrder(50); | ||
% Get initial conditions | ||
init_set = init(j).toStar; | ||
%reachSub = init_set; | ||
for i = 1:num_steps | ||
% Compute controller output set | ||
input_set = net.reach(init_set,reachOptions); | ||
|
||
% Compute plant reachable set | ||
init_set = plantReach(plant, init_set, input_set,'lin'); | ||
end | ||
toc(t); | ||
parsave("tmp/reachSet"+string(j)+".mat",plant); | ||
end | ||
rT = toc(t); % get reach time | ||
disp("Finished reachability...") | ||
|
||
% Shut Down Current Parallel Pool | ||
poolobj = gcp('nocreate'); | ||
delete(poolobj); | ||
|
||
% Save results | ||
if is_codeocean | ||
save('/results/logs/nav_point.mat', 'rT','-v7.3'); | ||
else | ||
save('nav_point.mat', 'rT','-v7.3'); | ||
end | ||
|
||
|
||
%% Visualize results | ||
setFiles = dir('tmp/*.mat'); | ||
|
||
t = tic; | ||
|
||
f = figure; | ||
rectangle('Position',[-0.5,-0.5,1,1],'FaceColor',[0 0.5 0 0.5],'EdgeColor','y', 'LineWidth', 0.1); % goal region | ||
hold on; | ||
rectangle('Position',[1,1,1,1],'FaceColor',[0.7 0 0 0.8], 'EdgeColor','r', 'LineWidth', 0.1); % obstacle | ||
grid; | ||
for K = 1 : length(setFiles) | ||
if ~mod(K,50) | ||
disp("Plotting partition "+string(K)+" ..."); | ||
toc(t) | ||
pause(0.01); % to ensure it prints | ||
end | ||
res = load("tmp/"+setFiles(K).name); | ||
plant = res.plant; | ||
for k=1:length(plant.cora_set) | ||
plot(plant.cora_set{k}, [1,2], 'b', 'Unify', true); | ||
end | ||
end | ||
hold on; | ||
xlabel('x1'); | ||
ylabel('x2'); | ||
|
||
disp("Finished plotting all reach sets"); | ||
|
||
%% Save figure | ||
if is_codeocean | ||
saveas(f,'/results/logs/nav_point.png'); | ||
% exportgraphics(f,'/results/logs/nav-set.pdf', 'ContentType', 'vector'); | ||
else | ||
saveas(f,'nav_point_21.png'); | ||
% exportgraphics(f,'nav-set.pdf','ContentType', 'vector'); | ||
end | ||
|
||
% Save results | ||
if is_codeocean | ||
save('/results/logs/nav_point.mat','rT','-v7.3'); | ||
else | ||
save('nav_point.mat', 'rT','-v7.3'); | ||
end | ||
|
||
end | ||
|
||
%% Helper function | ||
function init_set = plantReach(plant,init_set,input_set,algoC) | ||
nS = length(init_set); % based on approx-star, number of sets should be equal | ||
ss = []; | ||
for k=1:nS | ||
ss =[ss plant.stepReachStar(init_set(k), input_set(k),algoC)]; | ||
end | ||
init_set = ss; | ||
end | ||
|
||
function parsave(fname, plant) % trick to save while on parpool | ||
save(fname, 'plant') | ||
end |
129 changes: 129 additions & 0 deletions
129
code/nnv/examples/Submission/WiP_3d/other/NAV/reach_set.m
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
function rT = reach_set() | ||
|
||
%% Reachability analysis of NAV Benchmark | ||
|
||
%% Load Components | ||
|
||
% Load the controller | ||
% net = importNetworkFromONNX('networks/nn-nav-set.onnx', "InputDataFormats", "BC"); | ||
netonnx = importONNXNetwork('networks/nn-nav-set.onnx', "InputDataFormats", "BC"); | ||
% Load plant | ||
reachStep = 0.02; | ||
controlPeriod = 0.2; | ||
% plant = NonLinearODE(4, 2, @dynamics, reachStep, controlPeriod, eye(4)); | ||
% plant.set_tensorOrder(2); | ||
% plant.set_taylorTerms(3); | ||
% plant.set_zonotopeOrder(100); | ||
% plant.set_intermediateOrder(50); | ||
|
||
|
||
%% Reachability analysis | ||
|
||
% Initial set | ||
lb = [2.9; 2.9; 0; 0]; | ||
ub = [3.1; 3.1; 0; 0]; | ||
init_set = Box(lb,ub); | ||
init = init_set.partition([1 2],[50 50]); | ||
|
||
% Reachability options | ||
num_steps = 21; | ||
reachOptions.reachMethod = 'approx-star'; | ||
|
||
N = length(init); | ||
disp("Verifying "+string(N)+" samples...") | ||
|
||
mkdir('temp'); | ||
parpool("Processes"); % initialize parallel process | ||
|
||
% Execute reachabilty analysis | ||
t = tic; | ||
parfor j = 1:length(init) | ||
% Get NNV network | ||
net = matlab2nnv(netonnx); | ||
% Create plant | ||
plant = NonLinearODE(4, 2, @dynamics, reachStep, controlPeriod, eye(4)); | ||
plant.set_tensorOrder(2); | ||
plant.set_taylorTerms(3); | ||
plant.set_zonotopeOrder(100); | ||
plant.set_intermediateOrder(50); | ||
% Get initial conditions | ||
init_set = init(j).toStar; | ||
%reachSub = init_set; | ||
for i = 1:num_steps | ||
% Compute controller output set | ||
input_set = net.reach(init_set,reachOptions); | ||
|
||
% Compute plant reachable set | ||
init_set = plantReach(plant, init_set, input_set,'lin'); | ||
end | ||
toc(t); | ||
parsave("temp/reachSet"+string(j)+".mat",plant); | ||
end | ||
rT = toc(t); % get reach time | ||
disp("Finished reachability...") | ||
|
||
% Shut Down Current Parallel Pool | ||
poolobj = gcp('nocreate'); | ||
delete(poolobj); | ||
|
||
%% Visualize results | ||
setFiles = dir('temp/*.mat'); | ||
|
||
f = figure; | ||
rectangle('Position',[-0.5,-0.5,1,1],'FaceColor',[0 0.5 0 0.5],'EdgeColor','y', 'LineWidth', 0.1); % goal region | ||
hold on; | ||
rectangle('Position',[1,1,1,1],'FaceColor',[0.7 0 0 0.8], 'EdgeColor','r', 'LineWidth', 0.1); % obstacle | ||
grid; | ||
t = tic; | ||
for K = 1 : length(setFiles) | ||
if ~mod(K,50) | ||
disp("Plotting partition "+string(K)+" ..."); | ||
toc(t) | ||
pause(0.01); % to ensure it prints | ||
end | ||
res = load("temp/"+setFiles(K).name); | ||
plant = res.plant; | ||
% plant.get_interval_sets; | ||
% Star.plotBoxes_2D_noFill(plant.intermediate_reachSet, 1,2,'b'); | ||
for k=1:(length(plant.cora_set)) | ||
plot(plant.cora_set{k}, [1,2], 'b', 'Unify', true); | ||
end | ||
end | ||
hold on; | ||
xlabel('x1'); | ||
ylabel('x2'); | ||
|
||
disp("Finished plotting all reach sets"); | ||
|
||
|
||
%% Save figure | ||
if is_codeocean | ||
saveas(f,'/results/logs/nav_set.png'); | ||
% exportgraphics(f,'/results/logs/nav-set.pdf', 'ContentType', 'vector'); | ||
else | ||
saveas(f,'nav_set.png'); | ||
% exportgraphics(f,'nav-set.pdf','ContentType', 'vector'); | ||
end | ||
|
||
% Save results | ||
if is_codeocean | ||
save('/results/logs/nav_set.mat','rT','-v7.3'); | ||
else | ||
save('nav_set.mat', 'rT','-v7.3'); | ||
end | ||
|
||
end | ||
|
||
%% Helper function | ||
function init_set = plantReach(plant,init_set,input_set,algoC) | ||
nS = length(init_set); % based on approx-star, number of sets should be equal | ||
ss = []; | ||
for k=1:nS | ||
ss =[ss plant.stepReachStar(init_set(k), input_set(k),algoC)]; | ||
end | ||
init_set = ss; | ||
end | ||
|
||
function parsave(fname, plant) % trick to save while on parpool | ||
save(fname, 'plant') | ||
end |
Binary file added
BIN
+35.3 KB
code/nnv/examples/Submission/WiP_3d/other/Single_Pendulum/controller_single_pendulum.h5
Binary file not shown.
Binary file added
BIN
+6.5 KB
code/nnv/examples/Submission/WiP_3d/other/Single_Pendulum/controller_single_pendulum.mat
Binary file not shown.
Oops, something went wrong.