Skip to content

Commit

Permalink
Initial run for new archcomp benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
mldiego committed May 31, 2024
1 parent 53fec8b commit 09834f6
Show file tree
Hide file tree
Showing 15 changed files with 318 additions and 2 deletions.
12 changes: 10 additions & 2 deletions code/nnv/engine/nn/layers/ElementwiseAffineLayer.m
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@
end
% Offset (bias)
if obj.DoOffset
image = image.affineMap(diag(ones(1,a.dim)), obj.Offset); % x + b
image = image.affineMap(diag(ones(1,image.dim)), obj.Offset); % x + b
end
end

Expand All @@ -150,7 +150,15 @@
% @S: output ImageStar

n = length(inputs);
S(n) = ImageStar;

% Initialize output variables
if isa(inputs, "ImageStar")
S(n) = ImageStar;
else
S(n) = Star;
end

% Begin computing reachability one set at a time
if strcmp(option, 'parallel')
parfor i=1:n
S(i) = obj.reach_star_single_input(inputs(i));
Expand Down
5 changes: 5 additions & 0 deletions code/nnv/engine/set/Star.m
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,13 @@

if ~isempty(obj.state_lb) && ~isempty(obj.state_ub)
B = Box(obj.state_lb, obj.state_ub);

else

if isa(obj.V, 'single') || isa(obj.C, 'single') || isa(obj.d, 'single')
obj = obj.changeVarsPrecision('double');
end

lb = zeros(obj.dim, 1);
ub = zeros(obj.dim, 1);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Proposed ARCH Benchmark - CartPole

This benchmark is proposed for the ARCH Friendly Competition 2024.

## Benchmark

We consider a pendulum (pole) mounted on a movable cart (= CartPole). The cart can be moved by a controller. The carts postition x1, its velocity x2, the angle of the pole x3 (with 0, 2*pi being the upright postion) and its angular velocity x4 define the state vector of the 4-dimensional system. The system starts in a middle postition of the cart, with the pendulum in the upright position. The controllers goal is to counteract slight deviations in the starting values and balance the pendulum in the middle of the track.

## Specifications and Dynamics

The system dynamics can be found in ```dynamics.m```. The safe states are defined as a stable upright position, which has to be reached after 8 seconds and has to be hold for at least 2 seconds. The controllers step size is 0.02 seconds.
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
function dx = dynamics(x, f)

% Cartpole Swingup, 4 states x and the input f (action of the controller)
% The controller takes (x(1), x(2), x(3), x(4)) as input, its output can
% be used directly.

dx(1,1) = x(2);
dx(2,1) = 2*f;
dx(3,1) = x(4);
dx(4,1) = (0.08*0.41*(9.8*sin(x(3))-2*f*cos(x(3)))-0.0021*x(4))/0.0105;

end
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
% function t = reach()

%% Reachability analysis of Cartpole Benchmark

%% Load Components

% Load the controller
net = importNetworkFromONNX('model.onnx', "InputDataFormats", "BC");
net = matlab2nnv(net);
% Load plant
reachStep = 0.002;
controlPeriod = 0.02;
plant = NonLinearODE(4,1,@dynamics, reachStep, controlPeriod, eye(4));
plant.set_tensorOrder(2);


%% Reachability analysis

% Initial set
lb = [-0.1; -0.05; -0.1; -0.05];
ub = [0.1; 0.05; 0.1; 0.05];
InitialSet = Box(lb,ub);
init_sets = InitialSet.partition([1,2,3,4],[20,10,20,10]);
for k=1:length(init_sets)
init_set = init_sets(k);
init_set = init_set.toStar;
% Store all reachable sets
reachAll = init_set;
num_steps = 500;
reachOptions.reachMethod = 'approx-star';
t = tic;
for i=1:num_steps
disp(i);
% Compute controller output set
input_set = net.reach(init_set,reachOptions);
% Compute plant reachable set
init_set = plant.stepReachStar(init_set, input_set,'lin');
reachAll = [reachAll init_set];
toc(t);
end
t = toc(t);
end

%% Visualize results
plant.get_interval_sets;

f = figure;
hold on;
% rectangle('Position',[0.5,1,1,1],'FaceColor',[1 0 0 0.5],'EdgeColor','r', 'LineWidth',0.1)
Star.plotBoxes_2D_noFill(plant.intermediate_reachSet,1,2,'b');
Star.plotBoxes_2D_noFill(plant.intermediate_reachSet,3,4,'r');
% Plot only falsifying trace
% plot(squeeze(sims(3,k,:)), squeeze(sims(1,k,:)), 'Color', [0 0 1 1]);
grid;
% xlabel('Time (s)');
% ylabel('\theta');
% xlim([0 0.6])
% ylim([0.95 1.25])
% Save figure
% if is_codeocean
% exportgraphics(f,'/results/logs/cartpole.pdf', 'ContentType', 'vector');
% else
% exportgraphics(f,'cartpole.pdf','ContentType', 'vector');
% end

% end
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
Initial states:

x1 = [-0.1, 0.1]
x2 = [-0.05, 0.05]
x3 = [-0.1, 0.1]
x4 = [-0.05, 0.05]

t = 10 seconds
control period 0.02 s

Property:

For t > 8.0 s
x1, x3, x4 should be in [-0.001, 0.001]
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

# NAV Benchmark

## Property:
The control goal is to navigate a robot to a goal region while avoiding an obstacle.
Time horizon: `t = 6s`. Control period: `0.2s`.

Initial states:

x1 = [2.9, 3.1]
x2 = [2.9, 3.1]
x3 = [0, 0]
x4 = [0, 0]

Dynamic system: [dynamics.m](./dynamics.m)

Goal region ( t=6 ):

x1 = [-0.5, 0.5]
x2 = [-0.5, 0.5]
x3 = [-Inf, Inf]
x4 = [-Inf, Inf]

Obstacle ( always ):

x1 = [1, 2]
x2 = [1, 2]
x3 = [-Inf, Inf]
x4 = [-Inf, Inf]

## Networks:

We provide two networks:
- The first network is trained with standard (point-based) reinforcement learning: `nn-nav-point.onnx`
- The second network is trained set-based to improve its verifiable robustness by integrating reachability analysis into the training process: `nn-nav-set.onnx`

Reference set-based training: https://arxiv.org/abs/2401.14961


Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function dx = dynamics(x,u)

dx = [
x(3)*cos(x(4));
x(3)*sin(x(4));
u(1);
u(2)
];

end

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
function t = reach_point()

%% Reachability analysis of NAV Benchmark

%% Load Components

% Load the controller
net = importNetworkFromONNX('networks/nn-nav-point.onnx', "InputDataFormats", "BC");
net = matlab2nnv(net);
% Load plant
reachStep = 0.002;
controlPeriod = 0.02;
plant = NonLinearODE(4, 2, @dynamics, reachStep, controlPeriod, eye(4));
plant.set_tensorOrder(2);


%% Reachability analysis

% Initial set
lb = [2.9; 2.9; 0; 0];
ub = [3.1; 3.1; 0; 0];
init_set = Star(lb,ub);

% Store all reachable sets
reachAll = init_set;

% Reachability options
num_steps = 30;
reachOptions.reachMethod = 'approx-star';

% Begin computation
t = tic;
for i=1:num_steps
disp(i);

% Compute controller output set
input_set = net.reach(init_set,reachOptions);

% Compute plant reachable set
init_set = plant.stepReachStar(init_set, input_set,'lin');
reachAll = [reachAll init_set];
toc(t);

end

t = toc(t);


%% Visualize results
plant.get_interval_sets;

f2 = figure;
% rectangle('Position',[-1,-1,2,2],'FaceColor',[0 0.5 0 0.5],'EdgeColor','y', 'LineWidth',0.1)
hold on;
Star.plotBoxes_2D_noFill(plant.intermediate_reachSet,1,2,'b');
grid;
xlabel('x_1');
ylabel('x_2');

f5 = figure;
% rectangle('Position',[-1,-1,2,2],'FaceColor',[0 0.5 0 0.5],'EdgeColor','y', 'LineWidth',0.1)
hold on;
Star.plotBoxes_2D_noFill(plant.intermediate_reachSet,3,4,'b');
grid;
xlabel('x_3');
ylabel('x_4');

% Save figure
% if is_codeocean
%
% else
%
% end

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
function t = reach_set()

%% Reachability analysis of NAV Benchmark

%% Load Components

% Load the controller
net = importNetworkFromONNX('networks/nn-nav-set.onnx', "InputDataFormats", "BC");
net = matlab2nnv(net);
% Load plant
reachStep = 0.002;
controlPeriod = 0.02;
plant = NonLinearODE(4, 2, @dynamics, reachStep, controlPeriod, eye(4));
plant.set_tensorOrder(2);


%% Reachability analysis

% Initial set
lb = [2.9; 2.9; 0; 0];
ub = [3.1; 3.1; 0; 0];
init_set = Star(lb,ub);

% Store all reachable sets
reachAll = init_set;

% Reachability options
num_steps = 30;
reachOptions.reachMethod = 'approx-star';

% Begin computation
t = tic;
for i=1:num_steps
disp(i);

% Compute controller output set
input_set = net.reach(init_set,reachOptions);

% Compute plant reachable set
init_set = plant.stepReachStar(init_set, input_set,'lin');
reachAll = [reachAll init_set];
toc(t);

end

t = toc(t);


%% Visualize results
plant.get_interval_sets;

f2 = figure;
% rectangle('Position',[-1,-1,2,2],'FaceColor',[0 0.5 0 0.5],'EdgeColor','y', 'LineWidth',0.1)
hold on;
Star.plotBoxes_2D_noFill(plant.intermediate_reachSet,1,2,'b');
grid;
xlabel('x_1');
ylabel('x_2');

f5 = figure;
% rectangle('Position',[-1,-1,2,2],'FaceColor',[0 0.5 0 0.5],'EdgeColor','y', 'LineWidth',0.1)
hold on;
Star.plotBoxes_2D_noFill(plant.intermediate_reachSet,3,4,'b');
grid;
xlabel('x_3');
ylabel('x_4');

% Save figure
% if is_codeocean
%
% else
%
% end

end

0 comments on commit 09834f6

Please sign in to comment.