diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 323066a..40a21b0 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -94,7 +94,7 @@ jobs: - name: Run e2e test run: | . venv/bin/activate - ./instructlab/scripts/basic-workflow-tests.sh -cm + ./instructlab/scripts/basic-workflow-tests.sh -m - name: Remove llama-cpp-python from cache if: always() diff --git a/src/instructlab/eval/mt_bench_common.py b/src/instructlab/eval/mt_bench_common.py index d346999..8e5dc9a 100644 --- a/src/instructlab/eval/mt_bench_common.py +++ b/src/instructlab/eval/mt_bench_common.py @@ -7,7 +7,6 @@ from typing import Optional import ast import dataclasses -import glob import json import os import re @@ -84,7 +83,7 @@ def load_questions(question_file: str, begin: Optional[int], end: Optional[int]) return questions -def load_model_answers(answer_dir: str, model_name=None) -> dict: +def load_model_answers(answer_dir: str, model_name=None, answer_file=None) -> dict: """Load model answers. The return value is a python dict of type: @@ -92,24 +91,36 @@ def load_model_answers(answer_dir: str, model_name=None) -> dict: """ logger.debug(locals()) model_answers = {} - for root, _, files in os.walk(answer_dir): - for filename in files: - if filename.endswith(".jsonl"): - # Removing ".jsonl" - file_model_name = filename[:-6] - answer = {} - file_path = os.path.join(root, filename) - with open(file_path, encoding="utf-8") as fin: - for line in fin: - l = json.loads(line) - answer[l["question_id"]] = l - model_answers[model_name or file_model_name] = answer - if model_name == file_model_name: - logger.debug("Found answer file matching: %s", model_name) - break + if answer_file is not None: + filename = os.path.basename(answer_file) + # Removing ".jsonl" + file_model_name = filename[:-6] + model_answers[file_model_name] = _load_answers(answer_file) + else: + for root, _, files in os.walk(answer_dir): + for filename in files: + if filename.endswith(".jsonl"): + # Removing ".jsonl" + file_model_name = filename[:-6] + file_path = os.path.join(root, filename) + model_answers[model_name or file_model_name] = _load_answers( + file_path + ) + if model_name == file_model_name: + logger.debug("Found answer file matching: %s", model_name) + break return model_answers +def _load_answers(answer_file): + answers = {} + with open(answer_file, encoding="utf-8") as fin: + for line in fin: + l = json.loads(line) + answers[l["question_id"]] = l + return answers + + def load_judge_prompts(prompt_file: str) -> dict: """Load judge prompts. @@ -304,8 +315,6 @@ def check_data(questions, model_answers, ref_answers, models, judges): ), f"Missing reference answer to Question {q['question_id']} for judge {jg.model_name}" -def get_model_list(answer_dir): +def get_model_list(answer_file): logger.debug(locals()) - file_paths = glob.glob(f"{answer_dir}/*.jsonl") - file_names = [os.path.splitext(os.path.basename(f))[0] for f in file_paths] - return file_names + return [os.path.splitext(os.path.basename(answer_file))[0]] diff --git a/src/instructlab/eval/mt_bench_judgment.py b/src/instructlab/eval/mt_bench_judgment.py index f7b40b8..29ef0e9 100644 --- a/src/instructlab/eval/mt_bench_judgment.py +++ b/src/instructlab/eval/mt_bench_judgment.py @@ -155,7 +155,6 @@ def judge_model( bench_name="mt_bench", output_dir="eval_output", data_dir=None, - model_list=None, max_workers=1, first_n=None, merge_system_user_message=False, @@ -180,7 +179,7 @@ def judge_model( questions = load_questions(question_file, None, None) # Load answers - model_answers = load_model_answers(answer_dir) + model_answers = load_model_answers(answer_dir, answer_file=answer_file) ref_answers = load_model_answers(ref_answer_dir, judge_model_name) # Load judge @@ -189,10 +188,7 @@ def judge_model( if first_n: questions = questions[:first_n] - if model_list is None: - models = get_model_list(answer_dir) - else: - models = model_list + models = get_model_list(answer_file) judges = make_judge_single(judge_model_name, judge_prompts) output_file = f"{output_base_dir}/model_judgment/{judge_model_name}_single.jsonl" @@ -280,7 +276,6 @@ def generate_judgment( output_dir="eval_output", data_dir=None, branch=None, - model_list=None, max_workers=1, first_n=None, merge_system_user_message=False, @@ -302,7 +297,6 @@ def generate_judgment( output_dir=output_dir, data_dir=data_dir, branch=branch, - model_list=model_list, max_workers=max_workers, first_n=first_n, merge_system_user_message=merge_system_user_message,