Skip to content

Commit

Permalink
Merge pull request #51 from cog-imperial/add-cnn-tests
Browse files Browse the repository at this point in the history
Test ConvLayer
  • Loading branch information
rmisener authored Feb 4, 2022
2 parents b15fb7b + ffbf211 commit 4fd2476
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion tests/neuralnet/test_relu.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pyomo.environ as pyo
import numpy as np

from omlt.block import OmltBlock
from omlt.io.onnx import load_onnx_neural_network_with_bounds
from omlt.neuralnet import FullSpaceNNFormulation, ReluBigMFormulation, ReluComplementarityFormulation, ReluPartitionFormulation
from omlt.neuralnet.activations import ComplementarityReLUActivation

Expand Down Expand Up @@ -98,4 +100,33 @@ def test_two_node_ReluPartitionFormulation(two_node_network_relu):
m.neural_net_block.inputs[0].fix(1)
status = pyo.SolverFactory("cbc").solve(m, tee=False)
assert abs(pyo.value(m.neural_net_block.outputs[0,0]) - 1) < 1e-3
assert abs(pyo.value(m.neural_net_block.outputs[0,1]) - 0) < 1e-3
assert abs(pyo.value(m.neural_net_block.outputs[0,1]) - 0) < 1e-3

def test_conv_ReluBigMFormulation(datadir):
net = load_onnx_neural_network_with_bounds(datadir.file('keras_conv_7x7_relu.onnx'))
m = pyo.ConcreteModel()

m.neural_net_block = OmltBlock()
formulation = ReluBigMFormulation(net)
m.neural_net_block.build_formulation(formulation)
m.obj1 = pyo.Objective(expr=0)

# compute expected output for this input
input = np.eye(7, 7).reshape(1, 7, 7)
x = input
for layer in net.layers:
x = layer.eval(x)
output = x

for i in range(7):
for j in range(7):
m.neural_net_block.inputs[0, i, j].fix(input[0, i, j])
status = pyo.SolverFactory("cbc").solve(m, tee=False)

d, r, c = output.shape
for i in range(d):
for j in range(r):
for k in range(c):
expected = output[i, j, k]
actual = pyo.value(m.neural_net_block.outputs[i, j, k])
assert abs(actual - expected) < 1e-3

0 comments on commit 4fd2476

Please sign in to comment.