diff --git a/src/poli/benchmarks/guacamol.py b/src/poli/benchmarks/guacamol.py index 4bf0a198..0dd45f03 100644 --- a/src/poli/benchmarks/guacamol.py +++ b/src/poli/benchmarks/guacamol.py @@ -95,7 +95,7 @@ def __init__( evaluation_budget=evaluation_budget, ) - self.problem_factory_names = [ + self.problem_factories = [ AlbuterolSimilarityProblemFactory(), AmlodipineMPOProblemFactory(), CelecoxibRediscoveryProblemFactory(), @@ -122,7 +122,7 @@ def __init__( self.string_representation = string_representation def _initialize_problem(self, index: int) -> Problem: - problem_factory = self.problem_factory_names[index] + problem_factory = self.problem_factories[index] problem = problem_factory.create( string_representation=self.string_representation, diff --git a/src/poli/benchmarks/pmo.py b/src/poli/benchmarks/pmo.py index cf6f1e18..6652141a 100644 --- a/src/poli/benchmarks/pmo.py +++ b/src/poli/benchmarks/pmo.py @@ -86,7 +86,7 @@ def __init__( evaluation_budget=evaluation_budget, ) - self.problem_factory_names.extend( + self.problem_factories.extend( [ JNK3ProblemFactory(), GSK3BetaProblemFactory(), diff --git a/src/poli/benchmarks/toy_continuous_functions_benchmark.py b/src/poli/benchmarks/toy_continuous_functions_benchmark.py index 24052887..6cf2fd78 100644 --- a/src/poli/benchmarks/toy_continuous_functions_benchmark.py +++ b/src/poli/benchmarks/toy_continuous_functions_benchmark.py @@ -73,12 +73,12 @@ def __init__( - set(SIX_DIMENSIONAL_PROBLEMS) ) ) - self.problem_factory_names = [ToyContinuousProblemFactory()] * len( + self.problem_factories = [ToyContinuousProblemFactory()] * len( self.function_names ) def _initialize_problem(self, index: int) -> Problem: - problem_factory: ToyContinuousProblemFactory = self.problem_factory_names[index] + problem_factory: ToyContinuousProblemFactory = self.problem_factories[index] problem = problem_factory.create( function_name=self.function_names[index], @@ -131,12 +131,10 @@ def __init__( evaluation_budget, ) self.embed_in = [5, 10, 25, 50, 100] - self.problem_factory_names = [ToyContinuousProblemFactory()] * len( - self.embed_in - ) + self.problem_factories = [ToyContinuousProblemFactory()] * len(self.embed_in) def _initialize_problem(self, index: int) -> Problem: - problem_factory: ToyContinuousProblemFactory = self.problem_factory_names[index] + problem_factory: ToyContinuousProblemFactory = self.problem_factories[index] problem = problem_factory.create( function_name="branin_2d", @@ -186,12 +184,10 @@ def __init__( evaluation_budget, ) self.embed_in = [None, 10, 25, 50, 100] - self.problem_factory_names = [ToyContinuousProblemFactory()] * len( - self.embed_in - ) + self.problem_factories = [ToyContinuousProblemFactory()] * len(self.embed_in) def _initialize_problem(self, index: int) -> Problem: - problem_factory: ToyContinuousProblemFactory = self.problem_factory_names[index] + problem_factory: ToyContinuousProblemFactory = self.problem_factories[index] if index == 0: problem = problem_factory.create( diff --git a/src/poli/core/abstract_benchmark.py b/src/poli/core/abstract_benchmark.py index 4656ed79..5204a3c2 100644 --- a/src/poli/core/abstract_benchmark.py +++ b/src/poli/core/abstract_benchmark.py @@ -7,7 +7,7 @@ class AbstractBenchmark: - problem_factory_names: List[str] | List[AbstractProblemFactory] + problem_factories: List[AbstractProblemFactory] index: int = 0 def __init__( @@ -25,13 +25,13 @@ def __init__( self.evaluation_budget = evaluation_budget def __len__(self) -> int: - return len(self.problem_factory_names) + return len(self.problem_factories) def __getitem__(self, index: int) -> Problem: return self._initialize_problem(index) def __next__(self) -> Problem: - if self.index < len(self.problem_factory_names): + if self.index < len(self.problem_factories): self.index += 1 return self._initialize_problem(self.index - 1) else: @@ -48,6 +48,8 @@ def info(self) -> str: @property def problem_names(self) -> List[str]: return [ - problem_factory.get_problem_name() - for problem_factory in self.problem_factory_names + problem_factory.__module__.replace( + "poli.objective_repository.", "" + ).replace(".register", "") + for problem_factory in self.problem_factories ] diff --git a/src/poli/tests/benchmarks/test_benchmark_creation.py b/src/poli/tests/benchmarks/test_benchmark_creation.py index 3e67ef5f..37c8ee91 100644 --- a/src/poli/tests/benchmarks/test_benchmark_creation.py +++ b/src/poli/tests/benchmarks/test_benchmark_creation.py @@ -29,6 +29,37 @@ def test_creating_embedded_toy_continuous_functions_benchmark(): f(x0) +def test_names_from_guacamol_benchmark(): + from poli.benchmarks import GuacaMolGoalDirectedBenchmark + + assert GuacaMolGoalDirectedBenchmark( + string_representation="SMILES" + ).problem_names == [ + "albuterol_similarity", + "amlodipine_mpo", + "celecoxib_rediscovery", + "deco_hop", + "fexofenadine_mpo", + "isomer_c7h8n2o2", + "isomer_c9h10n2o2pf2cl", + "median_1", + "median_2", + "mestranol_similarity", + "osimetrinib_mpo", + "perindopril_mpo", + "ranolazine_mpo", + "rdkit_logp", + "rdkit_qed", + "sa_tdc", + "scaffold_hop", + "sitagliptin_mpo", + "thiothixene_rediscovery", + "troglitazone_rediscovery", + "valsartan_smarts", + "zaleplon_mpo", + ] + + @pytest.mark.poli__tdc def test_creating_guacamol_benchmark(): from poli.benchmarks import GuacaMolGoalDirectedBenchmark