Skip to content

Commit

Permalink
Change ret value from single dict to multiple ret values
Browse files Browse the repository at this point in the history
Signed-off-by: Nathan Weinberg <[email protected]>
  • Loading branch information
nathan-weinberg committed Jun 17, 2024
1 parent 4e72d07 commit 11ad758
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 21 deletions.
21 changes: 6 additions & 15 deletions src/instructlab/eval/mmlu.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,10 @@ def __init__(
self.few_shots = few_shots
self.batch_size = batch_size

def run(self) -> dict:
def run(self) -> tuple:
individual_scores: dict[str, float] = {}
overall_score: float = 0.0
payload = {
"individual_scores": individual_scores,
"overall_score": overall_score,
}
return payload
return overall_score, individual_scores


class PR_MMLU_Evaluator(Evaluator):
Expand All @@ -39,8 +35,8 @@ class PR_MMLU_Evaluator(Evaluator):
Attributes:
sdg_path path where all the PR MMLU tasks are stored
task group name that is shared by all the PR MMLU tasks
few_shots number of examples
batch_size number of GPUs
few_shots number of examples
batch_size number of GPUs
"""

def __init__(
Expand All @@ -57,13 +53,8 @@ def __init__(
self.few_shots = few_shots
self.batch_size = batch_size

def run(self) -> dict:
def run(self) -> tuple:
individual_scores: dict[str, float] = {}
overall_score: float = 0.0
qa_pairs: list[tuple] = []
payload = {
"individual_scores": individual_scores,
"overall_score": overall_score,
"qa_pairs": qa_pairs,
}
return payload
return overall_score, individual_scores, qa_pairs
10 changes: 4 additions & 6 deletions src/instructlab/eval/mtbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@ def __init__(self, model_path, server_url: str) -> None:
super().__init__(model_path)
self.server_url = server_url

def run(self) -> dict:
def run(self) -> tuple:
overall_score: float = 0.0
qa_pairs: list[tuple] = []
payload = {"overall_score": overall_score, "qa_pairs": qa_pairs}
return payload
return overall_score, qa_pairs


class PR_Bench_Evaluator(Evaluator):
Expand All @@ -37,8 +36,7 @@ def __init__(self, model_path, server_url: str, questions: str) -> None:
self.server_url = server_url
self.questions = questions

def run(self) -> dict:
def run(self) -> tuple:
overall_score = 0.0
qa_pairs: list[tuple] = []
payload = {"overall_score": overall_score, "qa_pairs": qa_pairs}
return payload
return overall_score, qa_pairs

0 comments on commit 11ad758

Please sign in to comment.