Skip to content

Commit

Permalink
Fix minor wrt 2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
goodsong81 committed Apr 12, 2024
1 parent 84e4d5a commit 41e5cca
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
12 changes: 9 additions & 3 deletions tests/perf/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,18 @@ def __call__(self, result_entry: pd.Series, target_entry: pd.Series) -> None:
return
if self.compare == "==":
print(
f"[Check] abs({result_entry[self.name]=} - {target_entry[self.name]=}) < {target_entry[self.name]=} * {self.margin=}",
f"[Check] abs({self.name}:{result_entry[self.name]} - {self.name}:{target_entry[self.name]}) < {self.name}:{target_entry[self.name]} * {self.margin}",
)
assert abs(result_entry[self.name] - target_entry[self.name]) < target_entry[self.name] * self.margin
elif self.compare == "<":
print(f"[Check] {result_entry[self.name]=} < {target_entry[self.name]=} * (1.0 + {self.margin=})")
print(
f"[Check] {self.name}:{result_entry[self.name]} < {self.name}:{target_entry[self.name]} * (1.0 + {self.margin})",
)
assert result_entry[self.name] < target_entry[self.name] * (1.0 + self.margin)
elif self.compare == ">":
print(f"[Check] {result_entry[self.name]=} > {target_entry[self.name]=} * (1.0 - {self.margin=})")
print(
f"[Check] {self.name}:{result_entry[self.name]} > {self.name}:{target_entry[self.name]} * (1.0 - {self.margin})",
)
assert result_entry[self.name] > target_entry[self.name] * (1.0 - self.margin)

def __init__(
Expand Down Expand Up @@ -287,6 +291,8 @@ def run(
gc.collect()

result = self.load_result(work_dir)
if result is None:
return None
result = summary.average(result, keys=["task", "model", "data_group", "data"]) # Average out seeds
return result.set_index(["task", "model", "data_group", "data"])

Expand Down
5 changes: 4 additions & 1 deletion tests/perf/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,10 @@ def fxt_model(request: pytest.FixtureRequest, fxt_model_category) -> Benchmark.M
model: Benchmark.Model = request.param
if fxt_model_category == "all":
return model
if (fxt_model_category == "default" and model.category == "other") or fxt_model_category != model.category:
if fxt_model_category == "default":
if model.category == "other":
pytest.skip(f"{model.category} category model")
elif fxt_model_category != model.category:
pytest.skip(f"{model.category} category model")
return model

Expand Down

0 comments on commit 41e5cca

Please sign in to comment.