Commit
fix another test
MaxiBoether committed Sep 26, 2024
1 parent 1d2cddb commit 404e0de
Showing 2 changed files with 18 additions and 5 deletions.
@@ -436,7 +436,7 @@ def get_failure_reason(eval_aborted_reason: EvaluationAbortedReason) -> str:
     # Will trigger a retry in case this is not successful
     # Can happen, e.g., if the evaluator is overloaded
     expected_num_metrics = len(request.metrics)
-    for abort_reason, data_dict in eval_results:
+    for result_id, (abort_reason, data_dict) in enumerate(eval_results):
         if abort_reason is not None:
             # If there was any reason to abort, we don't care
             continue
@@ -446,10 +446,12 @@ def get_failure_reason(eval_aborted_reason: EvaluationAbortedReason) -> str:
         ), f"dataset size of 0, but no EMPTY_INTERVAL response: {eval_results}"
         actual_num_metrics = len(data_dict["metrics"])
         assert actual_num_metrics == expected_num_metrics, (
-            f"actual_num_metrics = {actual_num_metrics}"
+            f"result {result_id}: actual_num_metrics = {actual_num_metrics}"
             + f" != expected_num_metrics = {expected_num_metrics}"
             + "\n"
             + str(eval_results)
+            + "\n\n"
+            + str(eval_data)
         )

     # All checks succeeded
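For illustration, a minimal, self-contained sketch of the pattern this hunk introduces — enumerating the results so a failing assertion names the offending interval instead of only the mismatched counts. The sample data and expected count below are made up:

expected_num_metrics = 2
eval_results = [
    (None, {"metrics": [0.9, 0.8]}),  # passes the check
    ("EMPTY_INTERVAL", {}),           # aborted, skipped by the loop
    (None, {"metrics": [0.7]}),       # one metric missing
]

try:
    for result_id, (abort_reason, data_dict) in enumerate(eval_results):
        if abort_reason is not None:
            continue
        actual_num_metrics = len(data_dict["metrics"])
        assert actual_num_metrics == expected_num_metrics, (
            f"result {result_id}: actual_num_metrics = {actual_num_metrics}"
            + f" != expected_num_metrics = {expected_num_metrics}"
        )
except AssertionError as err:
    print(err)  # -> result 2: actual_num_metrics = 1 != expected_num_metrics = 2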
@@ -22,6 +22,7 @@
     EvaluateModelResponse,
     EvaluationAbortedReason,
     EvaluationIntervalData,
+    SingleMetricResult,
 )
 from modyn.supervisor.internal.eval.strategies.abstract import EvalInterval
 from modyn.supervisor.internal.eval.strategies.slicing import SlicingEvalStrategy
@@ -826,7 +827,12 @@ def get_eval_intervals(
                 ],
             )
         ]
-        test_get_evaluation_results.return_value = [EvaluationIntervalData() for _ in range(3)]
+        test_get_evaluation_results.return_value = [
+            EvaluationIntervalData(
+                interval_index=idx, evaluation_data=[SingleMetricResult(metric="Accuracy", result=0.5)]
+            )
+            for idx in [0, 2]
+        ]

     else:
         intervals = [
@@ -851,7 +857,12 @@ def get_eval_intervals(
             evaluation_id=42,
             interval_responses=[success_interval for _ in range(len(intervals))],
         )
-        test_get_evaluation_results.return_value = [EvaluationIntervalData() for _ in range(len(intervals))]
+        test_get_evaluation_results.return_value = [
+            EvaluationIntervalData(
+                interval_index=idx, evaluation_data=[SingleMetricResult(metric="Accuracy", result=0.5)]
+            )
+            for idx in range(len(intervals))
+        ]

     pe.grpc.evaluator = evaluator_stub_mock
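For context, a minimal, dependency-free sketch of the fixture pattern used in the two hunks above. The dataclasses are stand-ins for modyn's generated protobuf messages (only the field names come from the diff), and MagicMock plays the role of the patched supervisor method:

from dataclasses import dataclass, field
from unittest.mock import MagicMock


@dataclass
class SingleMetricResult:  # stand-in for the proto message
    metric: str
    result: float


@dataclass
class EvaluationIntervalData:  # stand-in for the proto message
    interval_index: int = 0
    evaluation_data: list = field(default_factory=list)


test_get_evaluation_results = MagicMock()
test_get_evaluation_results.return_value = [
    EvaluationIntervalData(
        interval_index=idx,
        evaluation_data=[SingleMetricResult(metric="Accuracy", result=0.5)],
    )
    for idx in [0, 2]  # in the failure case, presumably only the non-aborted intervals return data
]

for interval in test_get_evaluation_results():
    print(interval.interval_index, interval.evaluation_data[0].metric)

Populating interval_index and evaluation_data (instead of returning empty EvaluationIntervalData() messages) lets the per-interval metric-count checks added in the first file actually run against the mocked results.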
@@ -869,7 +880,7 @@ def get_eval_intervals(

     assert evaluator_stub_mock.evaluate_model.call_count == 1  # batched
     if test_failure:
-        assert test_cleanup_evaluations.call_count == 1
+        assert test_cleanup_evaluations.call_count == 2
     assert test_wait_for_evaluation_completion.call_count == 1

     stage_info = [
