From 404e0de2e9f3c9fcff0d4a548162bc6bfcc4f485 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20B=C3=B6ther?= Date: Thu, 26 Sep 2024 15:04:06 +0400 Subject: [PATCH] fix another test --- .../pipeline_executor/evaluation_executor.py | 6 ++++-- .../pipeline_executor/test_pipeline_executor.py | 17 ++++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py b/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py index cb87fac43..75333d46f 100644 --- a/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py +++ b/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py @@ -436,7 +436,7 @@ def get_failure_reason(eval_aborted_reason: EvaluationAbortedReason) -> str: # Will trigger a retry in case this is not successful # Can happen, e.g., if the evaluator is overloaded expected_num_metrics = len(request.metrics) - for abort_reason, data_dict in eval_results: + for result_id, (abort_reason, data_dict) in enumerate(eval_results): if abort_reason is not None: # If there was any reason to abort, we don't care continue @@ -446,10 +446,12 @@ def get_failure_reason(eval_aborted_reason: EvaluationAbortedReason) -> str: ), f"dataset size of 0, but no EMPTY_INTERVAL response: {eval_results}" actual_num_metrics = len(data_dict["metrics"]) assert actual_num_metrics == expected_num_metrics, ( - f"actual_num_metrics = {actual_num_metrics}" + f"result {result_id}: actual_num_metrics = {actual_num_metrics}" + f" != expected_num_metrics = {expected_num_metrics}" + "\n" + str(eval_results) + + "\n\n" + + str(eval_data) ) # All checks succeeded diff --git a/modyn/tests/supervisor/internal/pipeline_executor/test_pipeline_executor.py b/modyn/tests/supervisor/internal/pipeline_executor/test_pipeline_executor.py index d05a39d73..a4f870ee3 100644 --- a/modyn/tests/supervisor/internal/pipeline_executor/test_pipeline_executor.py +++ b/modyn/tests/supervisor/internal/pipeline_executor/test_pipeline_executor.py @@ -22,6 +22,7 @@ EvaluateModelResponse, EvaluationAbortedReason, EvaluationIntervalData, + SingleMetricResult, ) from modyn.supervisor.internal.eval.strategies.abstract import EvalInterval from modyn.supervisor.internal.eval.strategies.slicing import SlicingEvalStrategy @@ -826,7 +827,12 @@ def get_eval_intervals( ], ) ] - test_get_evaluation_results.return_value = [EvaluationIntervalData() for _ in range(3)] + test_get_evaluation_results.return_value = [ + EvaluationIntervalData( + interval_index=idx, evaluation_data=[SingleMetricResult(metric="Accuracy", result=0.5)] + ) + for idx in [0, 2] + ] else: intervals = [ @@ -851,7 +857,12 @@ def get_eval_intervals( evaluation_id=42, interval_responses=[success_interval for _ in range(len(intervals))], ) - test_get_evaluation_results.return_value = [EvaluationIntervalData() for _ in range(len(intervals))] + test_get_evaluation_results.return_value = [ + EvaluationIntervalData( + interval_index=idx, evaluation_data=[SingleMetricResult(metric="Accuracy", result=0.5)] + ) + for idx in range(len(intervals)) + ] pe.grpc.evaluator = evaluator_stub_mock @@ -869,7 +880,7 @@ def get_eval_intervals( assert evaluator_stub_mock.evaluate_model.call_count == 1 # batched if test_failure: - assert test_cleanup_evaluations.call_count == 1 + assert test_cleanup_evaluations.call_count == 2 assert test_wait_for_evaluation_completion.call_count == 1 stage_info = [