diff --git a/tests/test_bms_multi_model_output.py b/tests/test_bms_multi_model_output.py new file mode 100644 index 000000000..7ea5d36ac --- /dev/null +++ b/tests/test_bms_multi_model_output.py @@ -0,0 +1,25 @@ +import numpy as np +import pytest +from theorist.bms import Tree + +from autora.skl.bms import BMSRegressor + + +@pytest.fixture +def curve_to_fit(): + x = np.linspace(-10, 10, 100).reshape(-1, 1) + y = (x**3.0) + (2.0 * x**2.0) + (17.0 * x) - 1 + return x, y + + +def test_bms_models(curve_to_fit): + x, y = curve_to_fit + regressor = BMSRegressor(epochs=100) + + regressor.fit(x, y) + + print(regressor.models_) + + assert len(regressor.models_) == len(regressor.ts) # Currently hardcoded + for model in regressor.models_: + assert isinstance(model, Tree)