Skip to content

Commit

Permalink
Merge pull request #161 from AutoResearch/159-feat-bms-return-variabl…
Browse files Browse the repository at this point in the history
…e-number-of-models

159 feat bms return variable number of models
  • Loading branch information
TheLemonPig authored Dec 15, 2022
2 parents b18e8af + 387005e commit 8ac65a6
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ dmypy.json
site/

# Jupyter Notebook load data
.ipynb_checkpoints
.ipynb_checkpoints
2 changes: 2 additions & 0 deletions autora/skl/bms.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
self.X_: Optional[np.ndarray] = None
self.y_: Optional[np.ndarray] = None
self.model_: Tree = Tree()
self.models_: List[Tree] = [Tree()]
self.loss_: float = np.inf
self.cache_: List = []
self.variables: List = []
Expand Down Expand Up @@ -110,6 +111,7 @@ def fit(self, X: np.ndarray, y: np.ndarray, num_param: int = 1) -> BMSRegressor:
prior_par=self.prior_par,
)
self.model_, self.loss_, self.cache_ = utils.run(self.pms, self.epochs)
self.models_ = list(self.pms.trees.values())

_logger.info("BMS fitting finished")
self.X_, self.y_ = X, y
Expand Down
13 changes: 7 additions & 6 deletions autora/theorist/bms/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
Ts.sort()
self.Ts = [str(T) for T in Ts]
self.trees = {
"1": Tree(
"1.0": Tree(
ops=ops,
variables=deepcopy(variables),
parameters=deepcopy(parameters),
Expand All @@ -60,7 +60,7 @@ def __init__(
BT=1,
)
}
self.t1 = self.trees["1"]
self.t1 = self.trees["1.0"]
for BT in [T for T in self.Ts if T != 1]:
treetmp = Tree(
ops=ops,
Expand All @@ -87,7 +87,7 @@ def mcmc_step(self, verbose=False, p_rr=0.05, p_long=0.45) -> None:
for T, tree in list(self.trees.items()):
# MCMC step
tree.mcmc_step(verbose=verbose, p_rr=p_rr, p_long=p_long)
self.t1 = self.trees["1"]
self.t1 = self.trees["1.0"]

# -------------------------------------------------------------------------
def tree_swap(self) -> Tuple[Optional[str], Optional[str]]:
Expand Down Expand Up @@ -119,7 +119,7 @@ def tree_swap(self) -> Tuple[Optional[str], Optional[str]]:
self.trees[self.Ts[nT2]] = t1
t1.BT = BT2
t2.BT = BT1
self.t1 = self.trees["1"]
self.t1 = self.trees["1.0"]
return self.Ts[nT1], self.Ts[nT2]
else:
return None, None
Expand All @@ -140,7 +140,7 @@ def anneal(self, n=1000, factor=5) -> None:
t.BT *= factor
for kk in range(n):
print(
"# Annealing heating at %g: %d / %d" % (self.trees["1"].BT, kk, n),
"# Annealing heating at %g: %d / %d" % (self.trees["1.0"].BT, kk, n),
file=sys.stderr,
)
self.mcmc_step()
Expand All @@ -150,7 +150,8 @@ def anneal(self, n=1000, factor=5) -> None:
t.BT = float(BT)
for kk in range(2 * n):
print(
"# Annealing cooling at %g: %d / %d" % (self.trees["1"].BT, kk, 2 * n),
"# Annealing cooling at %g: %d / %d"
% (self.trees["1.0"].BT, kk, 2 * n),
file=sys.stderr,
)
self.mcmc_step()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_bms_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_tree_mcmc_stepping(
)

# MCMC
t.mcmc(burnin=200, thin=10, samples=samples, verbose=True)
t.mcmc(burnin=200, thin=10, samples=samples, verbose=False)

# Predict
print(t.predict(x))
Expand Down
25 changes: 25 additions & 0 deletions tests/test_bms_multi_model_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np
import pytest

from autora.skl.bms import BMSRegressor
from autora.theorist.bms import Tree


@pytest.fixture
def curve_to_fit():
x = np.linspace(-10, 10, 100).reshape(-1, 1)
y = (x**3.0) + (2.0 * x**2.0) + (17.0 * x) - 1
return x, y


def test_bms_models(curve_to_fit):
x, y = curve_to_fit
regressor = BMSRegressor(epochs=100)

regressor.fit(x, y)

print(regressor.models_)

assert len(regressor.models_) == len(regressor.ts) # Currently hardcoded
for model in regressor.models_:
assert isinstance(model, Tree)

0 comments on commit 8ac65a6

Please sign in to comment.