From 7d660c4e075de4156d4b58faf4aeecab0d313df0 Mon Sep 17 00:00:00 2001 From: Jeremy Sadler <53983960+jezsadler@users.noreply.github.com> Date: Mon, 6 Nov 2023 20:28:26 +0000 Subject: [PATCH] Hopefully final linting --- src/omlt/neuralnet/layer.py | 2 +- tests/io/test_onnx_parser.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/omlt/neuralnet/layer.py b/src/omlt/neuralnet/layer.py index fa9951f5..d7f7fa89 100644 --- a/src/omlt/neuralnet/layer.py +++ b/src/omlt/neuralnet/layer.py @@ -377,7 +377,7 @@ def __init__( activation=activation, input_index_mapper=input_index_mapper, ) - if pool_func_name not in PoolingLayer2D._POOL_FUNCTIONS: + if pool_func_name not in PoolingLayer2D._POOL_FUNCTIONS: raise ValueError( f"Allowable pool functions are {PoolingLayer2D._POOL_FUNCTIONS}, {pool_func_name} was provided." ) diff --git a/tests/io/test_onnx_parser.py b/tests/io/test_onnx_parser.py index 4dec3f90..763b282c 100644 --- a/tests/io/test_onnx_parser.py +++ b/tests/io/test_onnx_parser.py @@ -195,7 +195,9 @@ def test_consume_dense_wrong_dims(datadir): parser = NetworkParser() parser.parse_network(model.graph, None, None) - parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/MatMul"][1].input.append("abcd") + parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/MatMul"][ + 1 + ].input.append("abcd") with pytest.raises(ValueError) as excinfo: parser._consume_dense_nodes( parser._nodes["StatefulPartitionedCall/keras_linear_131/dense/MatMul"][1], @@ -254,8 +256,6 @@ def test_consume_maxpool_wrong_dims(datadir): parser.parse_network(model.graph, None, None) parser._nodes["node1"][1].input.append("abcd") with pytest.raises(ValueError) as excinfo: - parser._consume_pool_nodes( - parser._nodes["node1"][1], parser._nodes["node1"][2] - ) + parser._consume_pool_nodes(parser._nodes["node1"][1], parser._nodes["node1"][2]) expected_msg_maxpool = """node1 input has 2 dimensions, only nodes with 1 input dimension can be used as starting points for consumption.""" assert str(excinfo.value) == expected_msg_maxpool