Skip to content

Commit

Permalink
Add discrete cardinality constraint (#270)
Browse files Browse the repository at this point in the history
This PR is about adding discrete cardinality constraints: 

- Add a `DiscreteCardinalityConstraint` class, which can be used in the
same way as the other existing discrete constraints. There are now two
ways to create a discrete subspace with cardinality constraints: via
`from_product` and, as before, via `from_simplex`.

- For imposing discrete cardinality constraints, combining
`DiscreteCardinalityConstraint` with `from_product` gives us more
flexibility, since we can now impose multiple cardinality constraints
and specify the parameter space of each of them.
However, `from_simplex` can be computationally more efficient.
  • Loading branch information
AdrianSosic authored Jun 27, 2024
2 parents 49350a3 + 2a6e5c6 commit 765dc4d
Show file tree
Hide file tree
Showing 16 changed files with 434 additions and 204 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `_optional` subpackage for managing optional dependencies
- Acquisition function for active learning: `qNIPV`
- Abstract `ContinuousNonlinearConstraint` class
- `ContinuousCardinalityConstraint` class and corresponding uniform sampling mechanism
- Abstract `CardinalityConstraint` class and
`DiscreteCardinalityConstraint`/`ContinuousCardinalityConstraint` subclasses
- Uniform sampling mechanism for continuous spaces with cardinality constraints
- `register_hooks` utility enabling user-defined augmentation of arbitrary callables

### Changed
Expand Down
2 changes: 2 additions & 0 deletions baybe/constraints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from baybe.constraints.discrete import (
DISCRETE_CONSTRAINTS_FILTERING_ORDER,
DiscreteCardinalityConstraint,
DiscreteCustomConstraint,
DiscreteDependenciesConstraint,
DiscreteExcludeConstraint,
Expand All @@ -28,6 +29,7 @@
"ContinuousLinearEqualityConstraint",
"ContinuousLinearInequalityConstraint",
# --- Discrete constraints ---#
"DiscreteCardinalityConstraint",
"DiscreteCustomConstraint",
"DiscreteDependenciesConstraint",
"DiscreteExcludeConstraint",
Expand Down
74 changes: 67 additions & 7 deletions baybe/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import pandas as pd
from attr import define, field
from attr.validators import min_len
from attr.validators import ge, instance_of, min_len

from baybe.parameters import NumericalContinuousParameter
from baybe.serialization import (
Expand All @@ -26,11 +26,7 @@

@define
class Constraint(ABC, SerialMixin):
"""Abstract base class for all constraints.
Constraints use conditions and chain them together to filter unwanted entries from
the search space.
"""
"""Abstract base class for all constraints."""

# class variables
# TODO: it might turn out these are not needed at a later development stage
Expand All @@ -40,6 +36,10 @@ class Constraint(ABC, SerialMixin):
eval_during_modeling: ClassVar[bool]
"""Class variable encoding whether the condition is evaluated during modeling."""

numerical_only: ClassVar[bool] = False
"""Class variable encoding whether the constraint is valid only for numerical
parameters."""

# Object variables
parameters: list[str] = field(validator=min_len(1))
"""The list of parameters used for the constraint."""
Expand Down Expand Up @@ -116,6 +116,66 @@ class ContinuousConstraint(Constraint, ABC):
eval_during_modeling: ClassVar[bool] = True
# See base class.

numerical_only: ClassVar[bool] = True
# See base class.


@define
class CardinalityConstraint(Constraint, ABC):
    """Abstract base class for cardinality constraints.

    Places a constraint on the set of nonzero (i.e. "active") values among the
    specified parameters, bounding it between the two given integers,
        ``min_cardinality`` <= |{p_i : p_i != 0}| <= ``max_cardinality``
    where ``{p_i}`` are the parameters specified for the constraint.

    Note that this can be equivalently regarded as an L0-constraint on the vector
    containing the specified parameters.
    """

    # class variable
    numerical_only: ClassVar[bool] = True
    # See base class.

    # object variables
    min_cardinality: int = field(default=0, validator=[instance_of(int), ge(0)])
    "The minimum required cardinality."

    max_cardinality: int = field(validator=instance_of(int))
    "The maximum allowed cardinality."

    @max_cardinality.default
    def _default_max_cardinality(self) -> int:
        """Use the number of involved parameters as the upper limit by default."""
        return len(self.parameters)

    def __attrs_post_init__(self) -> None:
        """Validate the cardinality bounds.

        Raises:
            ValueError: If the provided cardinality bounds are invalid.
            ValueError: If the provided cardinality bounds impose no constraint.
        """
        # The bounds must describe a non-empty interval ...
        if self.min_cardinality > self.max_cardinality:
            raise ValueError(
                f"The lower cardinality bound cannot be larger than the upper bound. "
                f"Provided values: {self.max_cardinality=}, {self.min_cardinality=}."
            )

        # ... that can actually be reached with the given number of parameters.
        if self.max_cardinality > len(self.parameters):
            raise ValueError(
                f"The cardinality bound cannot exceed the number of parameters. "
                f"Provided values: {self.max_cardinality=}, {len(self.parameters)=}."
            )

        # Bounds spanning the full range [0, len(parameters)] filter nothing, so
        # such a constraint is rejected as pointless rather than silently accepted.
        if self.min_cardinality == 0 and self.max_cardinality == len(self.parameters):
            raise ValueError(
                f"No constraint of type '{self.__class__.__name__}' is required "
                f"when the lower cardinality bound is zero and the upper bound equals "
                f"the number of parameters. Provided values: {self.min_cardinality=}, "
                f"{self.max_cardinality=}, {len(self.parameters)=}."
            )


@define
class ContinuousLinearConstraint(ContinuousConstraint, ABC):
Expand Down Expand Up @@ -206,7 +266,7 @@ def to_botorch(


class ContinuousNonlinearConstraint(ContinuousConstraint, ABC):
    """Abstract base class for continuous nonlinear constraints."""


# Register (un-)structure hooks
Expand Down
55 changes: 6 additions & 49 deletions baybe/constraints/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import math

import numpy as np
from attrs import define, field
from attrs.validators import ge, instance_of
from attrs import define

from baybe.constraints.base import (
CardinalityConstraint,
ContinuousLinearConstraint,
ContinuousNonlinearConstraint,
)
Expand Down Expand Up @@ -41,53 +41,10 @@ class ContinuousLinearInequalityConstraint(ContinuousLinearConstraint):


@define
class ContinuousCardinalityConstraint(ContinuousNonlinearConstraint):
"""Class for continuous cardinality constraints.
Places a constraint on the set of nonzero (i.e. "active") values among the
specified parameters, bounding it between the two given integers,
``min_cardinality`` <= |{p_i : p_i != 0}| <= ``max_cardinality``
where ``{p_i}`` are the parameters specified for the constraint.
Note that this can be equivalently regarded as L0-constraint on the vector
containing the specified parameters.
"""

min_cardinality: int = field(default=0, validator=[instance_of(int), ge(0)])
"The minimum required cardinality."

max_cardinality: int = field(validator=instance_of(int))
"The maximum allowed cardinality."

@max_cardinality.default
def _default_max_cardinality(self):
"""Use the number of involved parameters as the upper limit by default."""
return len(self.parameters)

def __attrs_post_init__(self):
"""Validate the cardinality bounds.
Raises:
ValueError: If the provided cardinality bounds are invalid.
ValueError: If the provided cardinality bounds impose no constraint.
"""
if self.min_cardinality > self.max_cardinality:
raise ValueError(
f"The lower cardinality bound cannot be larger than the upper bound. "
f"Provided values: {self.max_cardinality=}, {self.min_cardinality=}."
)

if self.max_cardinality > len(self.parameters):
raise ValueError(
f"The cardinality bound cannot exceed the number of parameters. "
f"Provided values: {self.max_cardinality=}, {len(self.parameters)=}."
)

if self.min_cardinality == 0 and self.max_cardinality == len(self.parameters):
raise ValueError(
f"No constraint of type `{self.__class__.__name__}' is required "
f"when 0 <= cardinality <= len(parameters)."
)
class ContinuousCardinalityConstraint(
CardinalityConstraint, ContinuousNonlinearConstraint
):
"""Class for continuous cardinality constraints."""

def sample_inactive_parameters(self, batch_size: int = 1) -> list[set[str]]:
"""Sample sets of inactive parameters according to the cardinality constraints.
Expand Down
36 changes: 30 additions & 6 deletions baybe/constraints/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from collections.abc import Callable
from functools import reduce
from typing import Any, cast
from typing import Any, ClassVar, cast

import pandas as pd
from attr import define, field
from attr.validators import in_, min_len

from baybe.constraints.base import DiscreteConstraint
from baybe.constraints.base import CardinalityConstraint, DiscreteConstraint
from baybe.constraints.conditions import (
Condition,
ThresholdCondition,
Expand Down Expand Up @@ -49,6 +49,10 @@ class DiscreteSumConstraint(DiscreteConstraint):

# IMPROVE: refactor `SumConstraint` and `ProdConstraint` to avoid code copying

# class variables
numerical_only: ClassVar[bool] = True
# see base class.

# object variables
condition: ThresholdCondition = field()
"""The condition modeled by this constraint."""
Expand All @@ -67,6 +71,10 @@ class DiscreteProductConstraint(DiscreteConstraint):

# IMPROVE: refactor `SumConstraint` and `ProdConstraint` to avoid code copying

# class variables
numerical_only: ClassVar[bool] = True
# see base class.

# object variables
condition: ThresholdCondition = field()
"""The condition that is used for this constraint."""
Expand Down Expand Up @@ -278,20 +286,36 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: # noqa: D102
return data.index[mask_bad]


# the order in which the constraint types need to be applied during discrete subspace
# filtering
@define
class DiscreteCardinalityConstraint(CardinalityConstraint, DiscreteConstraint):
    """Cardinality constraint operating on discrete parameters."""

    # Class variables
    numerical_only: ClassVar[bool] = True
    # See base class.

    def get_invalid(self, data: pd.DataFrame) -> pd.Index:  # noqa: D102
        # See base class.
        # Per row, count how many of the constrained parameters are nonzero.
        active_counts = data[self.parameters].ne(0.0).sum(axis=1)

        # A row is valid iff its count lies within the closed cardinality
        # interval; everything else is reported as invalid.
        is_valid = active_counts.between(self.min_cardinality, self.max_cardinality)
        return data.index[~is_valid]


# Constraints are approximately ordered according to increasing computational effort
# to minimize total time in their sequential application. Each constraint type is
# listed exactly once so that no filter is applied twice.
DISCRETE_CONSTRAINTS_FILTERING_ORDER = (
    DiscreteExcludeConstraint,
    DiscreteNoLabelDuplicatesConstraint,
    DiscreteLinkedParametersConstraint,
    DiscreteSumConstraint,
    DiscreteProductConstraint,
    DiscreteCardinalityConstraint,
    DiscreteCustomConstraint,
    DiscretePermutationInvarianceConstraint,
    DiscreteDependenciesConstraint,
)


# Prevent (de-)serialization of custom constraints
converter.register_unstructure_hook(DiscreteCustomConstraint, block_serialization_hook)
converter.register_structure_hook(DiscreteCustomConstraint, block_deserialization_hook)
18 changes: 17 additions & 1 deletion baybe/constraints/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

from baybe.constraints.base import Constraint
from baybe.constraints.continuous import ContinuousCardinalityConstraint
from baybe.constraints.discrete import DiscreteDependenciesConstraint
from baybe.constraints.discrete import (
DiscreteDependenciesConstraint,
)
from baybe.parameters.base import Parameter


Expand All @@ -22,6 +24,8 @@ def validate_constraints( # noqa: DOC101, DOC103
ValueError: If any constraint contains an invalid parameter name.
ValueError: If any continuous constraint includes a discrete parameter.
ValueError: If any discrete constraint includes a continuous parameter.
ValueError: If any discrete constraint that is valid only for numerical
discrete parameters includes non-numerical discrete parameters.
"""
if sum(isinstance(itm, DiscreteDependenciesConstraint) for itm in constraints) > 1:
raise ValueError(
Expand All @@ -36,6 +40,8 @@ def validate_constraints( # noqa: DOC101, DOC103
param_names_all = [p.name for p in parameters]
param_names_discrete = [p.name for p in parameters if p.is_discrete]
param_names_continuous = [p.name for p in parameters if p.is_continuous]
param_names_non_numerical = [p.name for p in parameters if not p.is_numerical]

for constraint in constraints:
if not all(p in param_names_all for p in constraint.parameters):
raise ValueError(
Expand All @@ -62,6 +68,16 @@ def validate_constraints( # noqa: DOC101, DOC103
f"{constraint.parameters}"
)

if constraint.numerical_only and any(
p in param_names_non_numerical for p in constraint.parameters
):
raise ValueError(
f"You are trying to initialize a constraint of type "
f"'{constraint.__class__.__name__}', which is valid only for numerical "
f"discrete parameters, over a non-numerical parameter. "
f"Parameter list of the affected constraint: {constraint.parameters}."
)


def validate_cardinality_constraints_are_nonoverlapping(
constraints: Collection[ContinuousCardinalityConstraint]
Expand Down
6 changes: 3 additions & 3 deletions baybe/searchspace/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __str__(self) -> str:
param_df = pd.DataFrame(param_list)
lin_eq_constr_df = pd.DataFrame(eq_constraints_list)
lin_ineq_constr_df = pd.DataFrame(ineq_constraints_list)
cardinality_constr_df = pd.DataFrame(nonlin_constraints_list)
nonlinear_constr_df = pd.DataFrame(nonlin_constraints_list)

# Put all attributes of the continuous class in one string
continuous_str = f"""{start_bold}Continuous Search Space{end_bold}
Expand All @@ -92,8 +92,8 @@ def __str__(self) -> str:
\r{pretty_print_df(lin_eq_constr_df)}
\n{start_bold}List of Linear Inequality Constraints{end_bold}
\r{pretty_print_df(lin_ineq_constr_df)}
\n{start_bold}List of Cardinality Constraints{end_bold}
\r{pretty_print_df(cardinality_constr_df)}"""
\n{start_bold}List of Nonlinear Constraints{end_bold}
\r{pretty_print_df(nonlinear_constr_df)}"""

return continuous_str.replace("\n", "\n ").replace("\r", "\r ")

Expand Down
12 changes: 12 additions & 0 deletions baybe/searchspace/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ def from_simplex(
ValueError: If the passed simplex parameters are not suitable for a simplex
construction.
ValueError: If the passed product parameters are not discrete.
ValueError: If the passed simplex parameters and product parameters are
not disjoint.
Returns:
The created simplex subspace.
Expand Down Expand Up @@ -388,6 +390,16 @@ def from_simplex(
f"must be of subclasses of '{DiscreteParameter.__name__}'."
)

# Validate no overlap between simplex parameters and product parameters
simplex_parameters_names = {p.name for p in simplex_parameters}
product_parameters_names = {p.name for p in product_parameters}
if overlap := simplex_parameters_names.intersection(product_parameters_names):
raise ValueError(
f"Parameter sets passed via 'simplex_parameters' and "
f"'product_parameters' must be disjoint but share the following "
f"parameters: {overlap}."
)

# Construct the product part of the space
product_space = parameter_cartesian_prod_to_df(product_parameters)
if not simplex_parameters:
Expand Down
Loading

0 comments on commit 765dc4d

Please sign in to comment.