Skip to content

Commit

Permalink
refactor polytope sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
jduerholt committed Nov 30, 2023
1 parent 599f64a commit eefa4a2
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 58 deletions.
112 changes: 55 additions & 57 deletions bofire/strategies/samplers/polytope.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import warnings
from typing import Dict

import numpy as np
import pandas as pd
import torch
from botorch.optim.initializers import sample_q_batches_from_polytope
from botorch.optim.parameter_constraints import _generate_unfixed_lin_constraints

from bofire.data_models.constraints.api import (
LinearEqualityConstraint,
Expand Down Expand Up @@ -56,34 +58,47 @@ def _ask(self, n: int) -> pd.DataFrame:
unit_scaled=False,
)
cleaned_eqs = []
pseudo_fixed = {}
fixed_features: Dict[str, float] = {
feat.key: feat.fixed_value()[0] # type: ignore
for feat in self.domain.inputs.get(ContinuousInput)
if feat.is_fixed() # type: ignore
}

for eq in eqs:
if (
len(eq[0]) == 1
): # only one coefficient, so this is a pseudo fixed feature
pseudo_fixed[
fixed_features[
self.domain.inputs.get_keys(ContinuousInput)[eq[0][0]]
] = float(eq[2] / eq[1][0])
else:
cleaned_eqs.append(eq)

# we have to map the indices in case of fixed features
# as we remove all fixed feature for the sampler, we have to adjust the
# indices in the constraints, here we get the mapper to map original
# to adjusted indices
feature_map = {}
counter = 0
for i, feat in enumerate(self.domain.get_features(ContinuousInput)):
if (not feat.is_fixed()) and (feat.key not in pseudo_fixed.keys()): # type: ignore
feature_map[i] = counter
counter += 1

# get the bounds
fixed_features_indices: Dict[int, float] = {
self.domain.inputs.get_keys(ContinuousInput).index(key): value
for key, value in fixed_features.items()
}

ineqs = get_linear_constraints(
domain=self.domain,
constraint=LinearInequalityConstraint, # type: ignore
unit_scaled=False,
)

interpoints = get_interpoint_constraints(domain=self.domain, n_candidates=n)

lower = [
feat.lower_bound # type: ignore
for feat in self.domain.get_features(ContinuousInput)
if not feat.is_fixed() and feat.key not in pseudo_fixed.keys() # type: ignore
if feat.key not in fixed_features.keys() # type: ignore
]

upper = [
feat.upper_bound # type: ignore
for feat in self.domain.get_features(ContinuousInput)
if feat.key not in fixed_features.keys() # type: ignore
]

if len(lower) == 0:
warnings.warn(
"Nothing to sample, all is fixed. Just the fixed set is returned.",
Expand All @@ -93,50 +108,37 @@ def _ask(self, n: int) -> pd.DataFrame:
data=np.nan, index=range(n), columns=self.domain.inputs.get_keys()
)
else:
upper = [
feat.upper_bound # type: ignore
for feat in self.domain.get_features(ContinuousInput)
if not feat.is_fixed() and feat.key not in pseudo_fixed.keys() # type: ignore
]
bounds = torch.tensor([lower, upper]).to(**tkwargs)
assert bounds.shape[-1] == len(feature_map) == counter

# get the inequality constraints and map features back
# we also check that only features present in the mapper
# are present in the constraints
ineqs = get_linear_constraints(
domain=self.domain,
constraint=LinearInequalityConstraint, # type: ignore
unit_scaled=False,

unfixed_ineqs = _generate_unfixed_lin_constraints(
constraints=ineqs,
eq=False,
fixed_features=fixed_features_indices,
dimension=len(self.domain.inputs.get(ContinuousInput)),
)
unfixed_eqs = _generate_unfixed_lin_constraints(
constraints=cleaned_eqs,
eq=True,
fixed_features=fixed_features_indices,
dimension=len(self.domain.inputs.get(ContinuousInput)),
)
unfixed_interpoints = _generate_unfixed_lin_constraints(
constraints=interpoints,
eq=True,
fixed_features=fixed_features_indices,
dimension=len(self.domain.inputs.get(ContinuousInput)),
)
for ineq in ineqs:
for key, value in feature_map.items():
if key != value:
ineq[0][ineq[0] == key] = value
assert (
ineq[0].max() <= counter
), "Something went wrong when transforming the linear constraints. Revisit the problem."

# TODO: check for pseudofixed
interpoints = get_interpoint_constraints(domain=self.domain, n_candidates=n)

# map the indice of the equality constraints
for eq in cleaned_eqs:
for key, value in feature_map.items():
if key != value:
eq[0][eq[0] == key] = value
assert (
eq[0].max() <= counter
), "Something went wrong when transforming the linear constraints. Revisit the problem."

combined_eqs = interpoints + cleaned_eqs

combined_eqs = unfixed_eqs + unfixed_interpoints # type: ignore

# now use the hit and run sampler
candidates = sample_q_batches_from_polytope(
n=1,
q=n,
bounds=bounds.to(**tkwargs),
inequality_constraints=ineqs if len(ineqs) > 0 else None,
inequality_constraints=unfixed_ineqs
if len(unfixed_ineqs) > 0 # type: ignore
else None,
equality_constraints=combined_eqs if len(combined_eqs) > 0 else None,
n_burnin=self.n_burnin,
thinning=self.n_thinning,
Expand All @@ -150,7 +152,7 @@ def _ask(self, n: int) -> pd.DataFrame:
free_continuals = [
feat.key
for feat in self.domain.get_features(ContinuousInput)
if not feat.is_fixed() and feat.key not in pseudo_fixed.keys() # type: ignore
if feat.key not in fixed_features.keys() # type: ignore
]

# setup the output
Expand All @@ -165,11 +167,7 @@ def _ask(self, n: int) -> pd.DataFrame:
samples[feat.key] = feat.sample(n) # type: ignore

# setup the fixed continuous ones
for feat in self.domain.inputs.get_fixed():
samples[feat.key] = feat.fixed_value()[0] # type: ignore

# setup the pseudo fixed ones
for key, value in pseudo_fixed.items():
for key, value in fixed_features.items():
samples[key] = value

return samples
Expand Down
3 changes: 2 additions & 1 deletion tests/bofire/data_models/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def test_rejection_sampler_not_converged():
max_count=2,
none_also_valid=False,
)
c7 = LinearEqualityConstraint(features=["if1", "if2"], coefficients=[1.0, 1.0], rhs=1.0)

domains = [
Domain.from_lists(inputs=[if1, if2, if3], constraints=[c2]),
Expand Down Expand Up @@ -167,7 +168,7 @@ def test_PolytopeSampler_all_fixed():
def test_PolytopeSampler_nchoosek():
domain = Domain.from_lists(
inputs=[if1, if2, if3, if4, if6, If7],
constraints=[c6, c2],
constraints=[c6, c2, c7],
)
data_model = data_models.PolytopeSampler(domain=domain)
sampler = strategies.PolytopeSampler(data_model=data_model)
Expand Down

0 comments on commit eefa4a2

Please sign in to comment.