Skip to content

Commit

Permalink
Merge pull request #112 from t-bz/fix_EI_acquisition
Browse files Browse the repository at this point in the history
Fix acquisition for ExpectedImprovementGenerator
  • Loading branch information
roussel-ryan authored May 3, 2023
2 parents 9ee1faf + 7fbadde commit f976f4f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 18 deletions.
42 changes: 39 additions & 3 deletions tests/generators/bayesian/test_expected_improvement.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import torch
import pandas as pd
from copy import deepcopy

import pytest
from botorch.acquisition import ExpectedImprovement

from xopt.base import Xopt

from xopt.vocs import VOCS, ObjectiveEnum
from xopt.evaluator import Evaluator
from xopt.generators.bayesian.expected_improvement import ExpectedImprovementGenerator
from xopt.generators.bayesian.upper_confidence_bound import UCBOptions

from xopt.resources.testing import TEST_VOCS_BASE, TEST_VOCS_DATA, xtest_callable


Expand Down Expand Up @@ -82,3 +83,38 @@ def test_in_xopt_w_proximal(self):
# now use bayes opt
for _ in range(1):
xopt.step()

def test_acquisition_accuracy(self):
train_x = torch.tensor([0.01, 0.3, 0.6, 0.99]).double()
train_y = torch.sin(2 * torch.pi * train_x)
train_data = pd.DataFrame(
{"x1": train_x.numpy(), "y1": train_y.numpy()})
test_x = torch.linspace(0.0, 1.0, 1000)

for objective in ObjectiveEnum:
vocs = VOCS(**{"variables": {"x1": [0.0, 1.0]},
"objectives": {"y1": objective}})
generator = ExpectedImprovementGenerator(vocs)
generator.add_data(train_data)
model = generator.train_model().models[0]

# xopt acquisition function
acq = generator.get_acquisition(model)

# analytical acquisition function
if objective == "MAXIMIZE":
an_acq = ExpectedImprovement(model, best_f=train_y.max(),
maximize=True)
else:
an_acq = ExpectedImprovement(model, best_f=train_y.min(),
maximize=False)

# compare candidates (maximum in test data)
with torch.no_grad():
acq_v = acq(test_x.reshape(-1, 1, 1))
candidate = test_x[torch.argmax(acq_v)]
an_acq_v = an_acq(test_x.reshape(-1, 1, 1))
an_candidate = test_x[torch.argmax(an_acq_v)]

# difference should be small
assert torch.abs(an_candidate - candidate) < 0.01
21 changes: 13 additions & 8 deletions xopt/generators/bayesian/custom_botorch/constrained_acqusition.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,17 @@ def __init__(
@concatenate_pending_points
@t_batch_mode_transform()
def forward(self, X: Tensor) -> Tensor:
posterior = self.model.posterior(
X=X, posterior_transform=self.posterior_transform
)
samples = self.get_posterior_samples(posterior)
obj = self.objective(samples, X=X)
if self.objective.constraints:
posterior = self.model.posterior(
X=X, posterior_transform=self.posterior_transform
)
samples = self.get_posterior_samples(posterior)
obj = self.objective(samples, X=X)

# multiply the output of the base acquisition function by the feasibility
base_val = torch.nn.functional.softplus(self.base_acqusition(X), beta=10)
return base_val * obj.max(dim=-1)[0].mean(dim=0)
# multiply the output of the base acquisition function by
# the feasibility
base_val = torch.nn.functional.softplus(
self.base_acqusition(X), beta=10)
return base_val * obj.max(dim=-1)[0].mean(dim=0)
else:
return self.base_acqusition(X)
9 changes: 2 additions & 7 deletions xopt/generators/bayesian/expected_improvement.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pandas as pd
import torch
from botorch.acquisition import qExpectedImprovement

Expand Down Expand Up @@ -45,12 +44,8 @@ def default_options() -> BayesianOptions:
return BayesianOptions()

def _get_acquisition(self, model):
valid_data = self.data[
pd.unique(self.vocs.variable_names + self.vocs.output_names)
].dropna()
objective_data = self.vocs.objective_data(valid_data, "")

best_f = torch.tensor(objective_data.max(), **self._tkwargs)
objective_data = self.vocs.objective_data(self.data, "").dropna()
best_f = -torch.tensor(objective_data.min(), **self._tkwargs)

qEI = qExpectedImprovement(
model,
Expand Down

0 comments on commit f976f4f

Please sign in to comment.