From eefa4a2c066c748e2aee56edf4813d06f2eef113 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20P=2E=20D=C3=BCrholt?= Date: Thu, 30 Nov 2023 10:34:32 +0100 Subject: [PATCH] refactor polytope sampler --- bofire/strategies/samplers/polytope.py | 112 +++++++++++----------- tests/bofire/data_models/test_samplers.py | 3 +- 2 files changed, 57 insertions(+), 58 deletions(-) diff --git a/bofire/strategies/samplers/polytope.py b/bofire/strategies/samplers/polytope.py index 1a6a0d252..d2cdc1719 100644 --- a/bofire/strategies/samplers/polytope.py +++ b/bofire/strategies/samplers/polytope.py @@ -1,9 +1,11 @@ import warnings +from typing import Dict import numpy as np import pandas as pd import torch from botorch.optim.initializers import sample_q_batches_from_polytope +from botorch.optim.parameter_constraints import _generate_unfixed_lin_constraints from bofire.data_models.constraints.api import ( LinearEqualityConstraint, @@ -56,34 +58,47 @@ def _ask(self, n: int) -> pd.DataFrame: unit_scaled=False, ) cleaned_eqs = [] - pseudo_fixed = {} + fixed_features: Dict[str, float] = { + feat.key: feat.fixed_value()[0] # type: ignore + for feat in self.domain.inputs.get(ContinuousInput) + if feat.is_fixed() # type: ignore + } + for eq in eqs: if ( len(eq[0]) == 1 ): # only one coefficient, so this is a pseudo fixed feature - pseudo_fixed[ + fixed_features[ self.domain.inputs.get_keys(ContinuousInput)[eq[0][0]] ] = float(eq[2] / eq[1][0]) else: cleaned_eqs.append(eq) - # we have to map the indices in case of fixed features - # as we remove all fixed feature for the sampler, we have to adjust the - # indices in the constraints, here we get the mapper to map original - # to adjusted indices - feature_map = {} - counter = 0 - for i, feat in enumerate(self.domain.get_features(ContinuousInput)): - if (not feat.is_fixed()) and (feat.key not in pseudo_fixed.keys()): # type: ignore - feature_map[i] = counter - counter += 1 - - # get the bounds + fixed_features_indices: Dict[int, float] = { + self.domain.inputs.get_keys(ContinuousInput).index(key): value + for key, value in fixed_features.items() + } + + ineqs = get_linear_constraints( + domain=self.domain, + constraint=LinearInequalityConstraint, # type: ignore + unit_scaled=False, + ) + + interpoints = get_interpoint_constraints(domain=self.domain, n_candidates=n) + lower = [ feat.lower_bound # type: ignore for feat in self.domain.get_features(ContinuousInput) - if not feat.is_fixed() and feat.key not in pseudo_fixed.keys() # type: ignore + if feat.key not in fixed_features.keys() # type: ignore + ] + + upper = [ + feat.upper_bound # type: ignore + for feat in self.domain.get_features(ContinuousInput) + if feat.key not in fixed_features.keys() # type: ignore ] + if len(lower) == 0: warnings.warn( "Nothing to sample, all is fixed. Just the fixed set is returned.", @@ -93,50 +108,37 @@ def _ask(self, n: int) -> pd.DataFrame: data=np.nan, index=range(n), columns=self.domain.inputs.get_keys() ) else: - upper = [ - feat.upper_bound # type: ignore - for feat in self.domain.get_features(ContinuousInput) - if not feat.is_fixed() and feat.key not in pseudo_fixed.keys() # type: ignore - ] bounds = torch.tensor([lower, upper]).to(**tkwargs) - assert bounds.shape[-1] == len(feature_map) == counter - - # get the inequality constraints and map features back - # we also check that only features present in the mapper - # are present in the constraints - ineqs = get_linear_constraints( - domain=self.domain, - constraint=LinearInequalityConstraint, # type: ignore - unit_scaled=False, + + unfixed_ineqs = _generate_unfixed_lin_constraints( + constraints=ineqs, + eq=False, + fixed_features=fixed_features_indices, + dimension=len(self.domain.inputs.get(ContinuousInput)), + ) + unfixed_eqs = _generate_unfixed_lin_constraints( + constraints=cleaned_eqs, + eq=True, + fixed_features=fixed_features_indices, + dimension=len(self.domain.inputs.get(ContinuousInput)), + ) + unfixed_interpoints = _generate_unfixed_lin_constraints( + constraints=interpoints, + eq=True, + fixed_features=fixed_features_indices, + dimension=len(self.domain.inputs.get(ContinuousInput)), ) - for ineq in ineqs: - for key, value in feature_map.items(): - if key != value: - ineq[0][ineq[0] == key] = value - assert ( - ineq[0].max() <= counter - ), "Something went wrong when transforming the linear constraints. Revisit the problem." - - # TODO: check for pseudofixed - interpoints = get_interpoint_constraints(domain=self.domain, n_candidates=n) - - # map the indice of the equality constraints - for eq in cleaned_eqs: - for key, value in feature_map.items(): - if key != value: - eq[0][eq[0] == key] = value - assert ( - eq[0].max() <= counter - ), "Something went wrong when transforming the linear constraints. Revisit the problem." - - combined_eqs = interpoints + cleaned_eqs + + combined_eqs = unfixed_eqs + unfixed_interpoints # type: ignore # now use the hit and run sampler candidates = sample_q_batches_from_polytope( n=1, q=n, bounds=bounds.to(**tkwargs), - inequality_constraints=ineqs if len(ineqs) > 0 else None, + inequality_constraints=unfixed_ineqs + if len(unfixed_ineqs) > 0 # type: ignore + else None, equality_constraints=combined_eqs if len(combined_eqs) > 0 else None, n_burnin=self.n_burnin, thinning=self.n_thinning, @@ -150,7 +152,7 @@ def _ask(self, n: int) -> pd.DataFrame: free_continuals = [ feat.key for feat in self.domain.get_features(ContinuousInput) - if not feat.is_fixed() and feat.key not in pseudo_fixed.keys() # type: ignore + if feat.key not in fixed_features.keys() # type: ignore ] # setup the output @@ -165,11 +167,7 @@ def _ask(self, n: int) -> pd.DataFrame: samples[feat.key] = feat.sample(n) # type: ignore # setup the fixed continuous ones - for feat in self.domain.inputs.get_fixed(): - samples[feat.key] = feat.fixed_value()[0] # type: ignore - - # setup the pseudo fixed ones - for key, value in pseudo_fixed.items(): + for key, value in fixed_features.items(): samples[key] = value return samples diff --git a/tests/bofire/data_models/test_samplers.py b/tests/bofire/data_models/test_samplers.py index 5f5f56e85..51de9ee8d 100644 --- a/tests/bofire/data_models/test_samplers.py +++ b/tests/bofire/data_models/test_samplers.py @@ -103,6 +103,7 @@ def test_rejection_sampler_not_converged(): max_count=2, none_also_valid=False, ) +c7 = LinearEqualityConstraint(features=["if1", "if2"], coefficients=[1.0, 1.0], rhs=1.0) domains = [ Domain.from_lists(inputs=[if1, if2, if3], constraints=[c2]), @@ -167,7 +168,7 @@ def test_PolytopeSampler_all_fixed(): def test_PolytopeSampler_nchoosek(): domain = Domain.from_lists( inputs=[if1, if2, if3, if4, if6, If7], - constraints=[c6, c2], + constraints=[c6, c2, c7], ) data_model = data_models.PolytopeSampler(domain=domain) sampler = strategies.PolytopeSampler(data_model=data_model)