Skip to content

Commit

Permalink
Makes sure the names can computed from benchmark object
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgondu committed Aug 7, 2024
1 parent 2db0107 commit d55f0a8
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/poli/benchmarks/guacamol.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
evaluation_budget=evaluation_budget,
)

self.problem_factory_names = [
self.problem_factories = [
AlbuterolSimilarityProblemFactory(),
AmlodipineMPOProblemFactory(),
CelecoxibRediscoveryProblemFactory(),
Expand All @@ -122,7 +122,7 @@ def __init__(
self.string_representation = string_representation

def _initialize_problem(self, index: int) -> Problem:
problem_factory = self.problem_factory_names[index]
problem_factory = self.problem_factories[index]

problem = problem_factory.create(
string_representation=self.string_representation,
Expand Down
2 changes: 1 addition & 1 deletion src/poli/benchmarks/pmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
evaluation_budget=evaluation_budget,
)

self.problem_factory_names.extend(
self.problem_factories.extend(
[
JNK3ProblemFactory(),
GSK3BetaProblemFactory(),
Expand Down
16 changes: 6 additions & 10 deletions src/poli/benchmarks/toy_continuous_functions_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ def __init__(
- set(SIX_DIMENSIONAL_PROBLEMS)
)
)
self.problem_factory_names = [ToyContinuousProblemFactory()] * len(
self.problem_factories = [ToyContinuousProblemFactory()] * len(
self.function_names
)

def _initialize_problem(self, index: int) -> Problem:
problem_factory: ToyContinuousProblemFactory = self.problem_factory_names[index]
problem_factory: ToyContinuousProblemFactory = self.problem_factories[index]

problem = problem_factory.create(
function_name=self.function_names[index],
Expand Down Expand Up @@ -131,12 +131,10 @@ def __init__(
evaluation_budget,
)
self.embed_in = [5, 10, 25, 50, 100]
self.problem_factory_names = [ToyContinuousProblemFactory()] * len(
self.embed_in
)
self.problem_factories = [ToyContinuousProblemFactory()] * len(self.embed_in)

def _initialize_problem(self, index: int) -> Problem:
problem_factory: ToyContinuousProblemFactory = self.problem_factory_names[index]
problem_factory: ToyContinuousProblemFactory = self.problem_factories[index]

problem = problem_factory.create(
function_name="branin_2d",
Expand Down Expand Up @@ -186,12 +184,10 @@ def __init__(
evaluation_budget,
)
self.embed_in = [None, 10, 25, 50, 100]
self.problem_factory_names = [ToyContinuousProblemFactory()] * len(
self.embed_in
)
self.problem_factories = [ToyContinuousProblemFactory()] * len(self.embed_in)

def _initialize_problem(self, index: int) -> Problem:
problem_factory: ToyContinuousProblemFactory = self.problem_factory_names[index]
problem_factory: ToyContinuousProblemFactory = self.problem_factories[index]

if index == 0:
problem = problem_factory.create(
Expand Down
12 changes: 7 additions & 5 deletions src/poli/core/abstract_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class AbstractBenchmark:
problem_factory_names: List[str] | List[AbstractProblemFactory]
problem_factories: List[AbstractProblemFactory]
index: int = 0

def __init__(
Expand All @@ -25,13 +25,13 @@ def __init__(
self.evaluation_budget = evaluation_budget

def __len__(self) -> int:
return len(self.problem_factory_names)
return len(self.problem_factories)

def __getitem__(self, index: int) -> Problem:
return self._initialize_problem(index)

def __next__(self) -> Problem:
if self.index < len(self.problem_factory_names):
if self.index < len(self.problem_factories):
self.index += 1
return self._initialize_problem(self.index - 1)
else:
Expand All @@ -48,6 +48,8 @@ def info(self) -> str:
@property
def problem_names(self) -> List[str]:
return [
problem_factory.get_problem_name()
for problem_factory in self.problem_factory_names
problem_factory.__module__.replace(
"poli.objective_repository.", ""
).replace(".register", "")
for problem_factory in self.problem_factories
]
31 changes: 31 additions & 0 deletions src/poli/tests/benchmarks/test_benchmark_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,37 @@ def test_creating_embedded_toy_continuous_functions_benchmark():
f(x0)


def test_names_from_guacamol_benchmark():
from poli.benchmarks import GuacaMolGoalDirectedBenchmark

assert GuacaMolGoalDirectedBenchmark(
string_representation="SMILES"
).problem_names == [
"albuterol_similarity",
"amlodipine_mpo",
"celecoxib_rediscovery",
"deco_hop",
"fexofenadine_mpo",
"isomer_c7h8n2o2",
"isomer_c9h10n2o2pf2cl",
"median_1",
"median_2",
"mestranol_similarity",
"osimetrinib_mpo",
"perindopril_mpo",
"ranolazine_mpo",
"rdkit_logp",
"rdkit_qed",
"sa_tdc",
"scaffold_hop",
"sitagliptin_mpo",
"thiothixene_rediscovery",
"troglitazone_rediscovery",
"valsartan_smarts",
"zaleplon_mpo",
]


@pytest.mark.poli__tdc
def test_creating_guacamol_benchmark():
from poli.benchmarks import GuacaMolGoalDirectedBenchmark
Expand Down

0 comments on commit d55f0a8

Please sign in to comment.