diff --git a/src/instructlab/eval/mt_bench_branch_generator.py b/src/instructlab/eval/mt_bench_branch_generator.py index 837f5c0..bff984e 100644 --- a/src/instructlab/eval/mt_bench_branch_generator.py +++ b/src/instructlab/eval/mt_bench_branch_generator.py @@ -66,7 +66,7 @@ def generate(judge_model_name, branch, taxonomy_dir, output_dir): t_1 = q # Generate a consistent hash to have consistent question_id across qna_files from different runs - str_bytes = bytes(qna_file, "UTF-8") + str_bytes = bytes(q, "UTF-8") m = hashlib.md5(str_bytes) question_id = str(int(m.hexdigest(), base=16)) question_lst.append( diff --git a/src/instructlab/eval/mt_bench_judgment.py b/src/instructlab/eval/mt_bench_judgment.py index 71c7351..c4e6494 100644 --- a/src/instructlab/eval/mt_bench_judgment.py +++ b/src/instructlab/eval/mt_bench_judgment.py @@ -105,26 +105,32 @@ def make_judgment( answer_df = pd.read_json(answer_file, lines=True) # Join to get questions with answers + join_columns = ["question_id", "choices", "turns", "category"] + if bench_name == "mt_bench_branch": + join_columns.append("qna_file") + joined_df = question_df.join( answer_df.set_index("question_id"), on="question_id", rsuffix="_answer" - )[["question_id", "choices", "turns", "category"]] + )[join_columns] # Join to get scores + join_columns.append("score") joined_df = judgment_df_all.join( joined_df.set_index("question_id"), on="question_id", lsuffix="_judgment" - )[["question_id", "choices", "turns", "score", "category"]] + )[join_columns] joined_df = joined_df[joined_df["score"] != -1] qa_pairs = [] for _, row in joined_df.iterrows(): - qa_pairs.append( - { - "question_id": row["question_id"], - "score": row["score"], - "category": row["category"], - "question": row["turns"], - "answer": row["choices"], - } - ) + qa_pair = { + "question_id": row["question_id"], + "score": row["score"], + "category": row["category"], + "question": row["turns"], + "answer": row["choices"], + } + if bench_name == "mt_bench_branch": + qa_pair["qna_file"] = row["qna_file"] + qa_pairs.append(qa_pair) return overall_score, qa_pairs, turn_scores diff --git a/tests/test_branch_judge_answers.py b/tests/test_branch_judge_answers.py index ada55f0..6120c00 100644 --- a/tests/test_branch_judge_answers.py +++ b/tests/test_branch_judge_answers.py @@ -23,3 +23,4 @@ assert qa_pair.get("category") is not None assert qa_pair.get("question") is not None assert qa_pair.get("answer") is not None + assert qa_pair.get("qna_file") is not None