Skip to content

Commit

Permalink
Include qna_file in mt_bench_branch results
Browse files Browse the repository at this point in the history
Signed-off-by: Dan McPherson <[email protected]>
  • Loading branch information
danmcp committed Jun 27, 2024
1 parent f79ce58 commit 82fefc8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
28 changes: 17 additions & 11 deletions src/instructlab/eval/mt_bench_judgment.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,26 +105,32 @@ def make_judgment(
answer_df = pd.read_json(answer_file, lines=True)

# Join to get questions with answers
join_columns = ["question_id", "choices", "turns", "category"]
if bench_name == "mt_bench_branch":
join_columns.append("qna_file")

joined_df = question_df.join(
answer_df.set_index("question_id"), on="question_id", rsuffix="_answer"
)[["question_id", "choices", "turns", "category"]]
)[join_columns]
# Join to get scores
join_columns.append("score")
joined_df = judgment_df_all.join(
joined_df.set_index("question_id"), on="question_id", lsuffix="_judgment"
)[["question_id", "choices", "turns", "score", "category"]]
)[join_columns]
joined_df = joined_df[joined_df["score"] != -1]

qa_pairs = []
for _, row in joined_df.iterrows():
qa_pairs.append(
{
"question_id": row["question_id"],
"score": row["score"],
"category": row["category"],
"question": row["turns"],
"answer": row["choices"],
}
)
qa_pair = {
"question_id": row["question_id"],
"score": row["score"],
"category": row["category"],
"question": row["turns"],
"answer": row["choices"],
}
if bench_name == "mt_bench_branch":
qa_pair["qna_file"] = row["qna_file"]
qa_pairs.append(qa_pair)
return overall_score, qa_pairs, turn_scores


Expand Down
1 change: 1 addition & 0 deletions tests/test_branch_judge_answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
assert qa_pair.get("category") is not None
assert qa_pair.get("question") is not None
assert qa_pair.get("answer") is not None
assert qa_pair.get("qna_file") is not None

0 comments on commit 82fefc8

Please sign in to comment.