From 348543eebcc9fea28f573db6edbf042a04a0b4a2 Mon Sep 17 00:00:00 2001 From: John Gerrard Holland Date: Tue, 6 Dec 2022 16:47:52 -0500 Subject: [PATCH 1/2] test: add testcase for models_ --- tests/test_bms_multi_model_output.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 tests/test_bms_multi_model_output.py diff --git a/tests/test_bms_multi_model_output.py b/tests/test_bms_multi_model_output.py new file mode 100644 index 000000000..5e427cfd4 --- /dev/null +++ b/tests/test_bms_multi_model_output.py @@ -0,0 +1,25 @@ +import numpy as np +import pytest +from theorist.bms import Tree + +from autora.skl.bms import BMSRegressor + + +@pytest.fixture +def curve_to_fit(): + x = np.linspace(-10, 10, 100).reshape(-1, 1) + y = (x**3.0) + (2.0 * x**2.0) + (17.0 * x) - 1 + return x, y + + +def test_bms_models(curve_to_fit): + x, y = curve_to_fit + regressor = BMSRegressor(epochs=100) + + regressor.fit(x, y) + + print(regressor.models_) + + assert len(regressor.models_) == 20 # Currently hardcoded + for model in regressor.models_: + assert isinstance(model, Tree) From 5d80834891e34956e02b958387ca23b92c0c0d0a Mon Sep 17 00:00:00 2001 From: TheLemonPig <43144407+TheLemonPig@users.noreply.github.com> Date: Wed, 7 Dec 2022 16:03:58 -0500 Subject: [PATCH 2/2] Update tests/test_bms_multi_model_output.py --- tests/test_bms_multi_model_output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_bms_multi_model_output.py b/tests/test_bms_multi_model_output.py index 5e427cfd4..7ea5d36ac 100644 --- a/tests/test_bms_multi_model_output.py +++ b/tests/test_bms_multi_model_output.py @@ -20,6 +20,6 @@ def test_bms_models(curve_to_fit): print(regressor.models_) - assert len(regressor.models_) == 20 # Currently hardcoded + assert len(regressor.models_) == len(regressor.ts) # Currently hardcoded for model in regressor.models_: assert isinstance(model, Tree)