Skip to content

Commit

Permalink
Merge pull request xopt-org#222 from ChristopherMayes/fixed_features_…
Browse files Browse the repository at this point in the history
…improvements

add robustness and additional test to fixed features functionality in Bayesian generators
  • Loading branch information
roussel-ryan authored Apr 23, 2024
2 parents a613b00 + c32a23f commit 0b64028
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 20 deletions.
124 changes: 104 additions & 20 deletions docs/examples/single_objective_bayes_opt/fixed_features.ipynb

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions tests/generators/bayesian/test_bayesian_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,12 @@ def test_fixed_feature(self):
model.models[0].input_transform.bounds,
torch.tensor(((0, 0), (1, 10))).double(),
)

# test bad fixed feature name
gen = BayesianGenerator(vocs=TEST_VOCS_BASE)
gen.fixed_features = {"bad_name": 3.0}
data = deepcopy(TEST_VOCS_DATA)
gen.add_data(data)

with pytest.raises(KeyError):
gen.train_model()
8 changes: 8 additions & 0 deletions tests/test_vocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ def test_empty_objectives(self):
variables={"x": [0, 1]},
)

def test_output_names(self):
test_vocs = VOCS(
variables={"x": [0, 1]},
objectives={"y1": "MINIMIZE"},
constraints={"c1": ["GREATER_THAN", 0], "c2": ["LESS_THAN", 0]},
)
assert test_vocs.output_names == ["y1", "c1", "c2"]

def test_constraint_specification(self):
good_constraint_list = [
["LESS_THAN", 0],
Expand Down
5 changes: 5 additions & 0 deletions xopt/generators/bayesian/bayesian_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,11 @@ def train_model(self, data: pd.DataFrame = None, update_internal=True) -> Module
# get bounds for each fixed_feature (vocs bounds take precedent)
for key in self.fixed_features:
if key not in variable_bounds:
if key not in data:
raise KeyError(
"generator data needs to contain fixed feature "
f"column name `{key}`"
)
f_data = data[key]
bounds = [f_data.min(), f_data.max()]
if bounds[1] - bounds[0] < 1e-8:
Expand Down

0 comments on commit 0b64028

Please sign in to comment.