From b58cd78f1d61aa6f0ee6e0f851c4f7faccc853ce Mon Sep 17 00:00:00 2001 From: jeandut Date: Mon, 3 Jun 2024 13:01:47 +0000 Subject: [PATCH] trying to accomodate new API --- benchmark_utils/template_flamby_strategy.py | 6 ++++-- objective.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/benchmark_utils/template_flamby_strategy.py b/benchmark_utils/template_flamby_strategy.py index 4c00209..bc69638 100644 --- a/benchmark_utils/template_flamby_strategy.py +++ b/benchmark_utils/template_flamby_strategy.py @@ -93,8 +93,10 @@ def run(self, callback): # We are reproducing the run method but this time a callback checks # stopping-criterion at each round, which allows to cache computations # and do a single run - while callback(strat.models_list[0].model): + self.final_model = strat.models_list[0].model + while callback(): strat.perform_round() + self.final_model = strat.models_list[0].model self.final_model = strat.models_list[0].model @@ -103,7 +105,7 @@ def get_result(self): # The outputs of this function are the arguments of `Objective.compute` # This defines the benchmark's API for solvers' results. # it is customizable for each benchmark. - return self.final_model + return {"model": self.final_model} # Not used if callback is used @staticmethod diff --git a/objective.py b/objective.py index 580328c..70f6a07 100644 --- a/objective.py +++ b/objective.py @@ -256,7 +256,7 @@ def robust_metric(y_true, y_pred): def get_one_result(self): # Return one solution. The return value should be an object compatible # with `self.compute`. This is mainly for testing purposes. - return self.model_arch() + return dict(model=self.model_arch()) def get_objective(self): # Define the information to pass to each solver to run the benchmark.