diff --git a/bofire/data_models/domain/features.py b/bofire/data_models/domain/features.py index ad57aaf0d..d4b21c4b8 100644 --- a/bofire/data_models/domain/features.py +++ b/bofire/data_models/domain/features.py @@ -125,17 +125,23 @@ def get_by_key(self, key: str) -> F: """ return {f.key: f for f in self.features}[key] - def get_by_keys(self, keys: Sequence[str]) -> Self: + def get_by_keys(self, keys: Sequence[str], include: bool = True) -> Self: """Get features of the domain specified by its keys. Args: keys: List of the keys of the features that should be returned. + include: Boolean to distinguish if the features with the keys in the + list should be included or excluded. Returns: Features: Features object with the requested features. """ - return self.__class__(features=sorted([self.get_by_key(key) for key in keys])) + if include: + features = [self.get_by_key(key) for key in keys] + else: + features = [f for f in self.features if f.key not in keys] + return self.__class__(features=sorted(features)) def get( self, diff --git a/tests/bofire/data_models/domain/test_features.py b/tests/bofire/data_models/domain/test_features.py index 1cd72c824..c37747d65 100644 --- a/tests/bofire/data_models/domain/test_features.py +++ b/tests/bofire/data_models/domain/test_features.py @@ -144,13 +144,20 @@ def test_features_get_by_key(features, key, expected): assert id(returned) == id(expected) -def test_features_get_by_keys(): +def test_features_get_by_keys_include(): keys = ["of2", "if1"] feats = features.get_by_keys(keys) assert feats[0].key == "if1" assert feats[1].key == "of2" +def test_features_get_by_keys_exclude(): + keys = ["of2", "if1"] + feats = features.get_by_keys(keys, include=False) + assert feats[0].key == "if2" + assert feats[1].key == "of1" + + @pytest.mark.parametrize( "features, key", [