From 3445ce0bdffdce7246a86bbb47e2a618f78d72c1 Mon Sep 17 00:00:00 2001 From: sallyom Date: Wed, 11 Sep 2024 12:44:37 -0400 Subject: [PATCH] Add option to pass 'api_key' to mt_bench gen_answers, judge_answers `api_key` is optional and this PR remains backwards compatible. This allows for externally served models that require authentication. A helper function is added in mt_bench_common for creating the openai_client necessary for model requests. Signed-off-by: sallyom --- src/instructlab/eval/mt_bench.py | 16 ++++++++++++---- src/instructlab/eval/mt_bench_answers.py | 6 ++++-- src/instructlab/eval/mt_bench_common.py | 7 +++++++ src/instructlab/eval/mt_bench_judgment.py | 6 ++++-- 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/instructlab/eval/mt_bench.py b/src/instructlab/eval/mt_bench.py index 240234b..ea2ff23 100644 --- a/src/instructlab/eval/mt_bench.py +++ b/src/instructlab/eval/mt_bench.py @@ -94,27 +94,30 @@ class MTBenchEvaluator(AbstractMTBenchEvaluator): name = "mt_bench" - def gen_answers(self, server_url) -> None: + def gen_answers(self, server_url, api_key: str | None = None) -> None: """ Asks questions to model Attributes server_url Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated + api_key API token for authenticating with model server """ logger.debug(locals()) mt_bench_answers.generate_answers( self.model_name, server_url, + api_key=api_key, output_dir=self.output_dir, max_workers=self.max_workers, ) - def judge_answers(self, server_url) -> tuple: + def judge_answers(self, server_url, api_key: str | None = None) -> tuple: """ Runs MT-Bench judgment Attributes server_url Model server endpoint (Ex: http://localhost:8000/v1) for the judge model + api_key API token for authenticating with model server Returns: overall_score MT-Bench score for the overall model evaluation @@ -126,6 +129,7 @@ def judge_answers(self, server_url) -> tuple: self.model_name, self.judge_model_name, server_url, + api_key=api_key, max_workers=self.max_workers, output_dir=self.output_dir, merge_system_user_message=self.merge_system_user_message, @@ -171,12 +175,13 @@ def __init__( self.taxonomy_git_repo_path = taxonomy_git_repo_path self.branch = branch - def gen_answers(self, server_url) -> None: + def gen_answers(self, server_url, api_key: str | None = None) -> None: """ Asks questions to model Attributes server_url Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated + api_key API token for authenticating with model server """ logger.debug(locals()) mt_bench_branch_generator.generate( @@ -188,6 +193,7 @@ def gen_answers(self, server_url) -> None: mt_bench_answers.generate_answers( self.model_name, server_url, + api_key=api_key, branch=self.branch, output_dir=self.output_dir, data_dir=self.output_dir, @@ -195,12 +201,13 @@ def gen_answers(self, server_url) -> None: bench_name="mt_bench_branch", ) - def judge_answers(self, server_url) -> tuple: + def judge_answers(self, server_url, api_key: str | None = None) -> tuple: """ Runs MT-Bench-Branch judgment. Judgments can be compared across runs with consistent question_id -> qna file name. Attributes server_url Model server endpoint (Ex: http://localhost:8000/v1) for the judge model + api_key API token for authenticating with model server Returns: qa_pairs Question and answer pairs (with scores) from the evaluation @@ -210,6 +217,7 @@ def judge_answers(self, server_url) -> tuple: self.model_name, self.judge_model_name, server_url, + api_key=api_key, branch=self.branch, max_workers=self.max_workers, output_dir=self.output_dir, diff --git a/src/instructlab/eval/mt_bench_answers.py b/src/instructlab/eval/mt_bench_answers.py index 79fc128..e80f996 100644 --- a/src/instructlab/eval/mt_bench_answers.py +++ b/src/instructlab/eval/mt_bench_answers.py @@ -8,7 +8,6 @@ # Third Party # TODO need to look into this dependency from fastchat.model.model_adapter import get_conversation_template # type: ignore -import openai import shortuuid import tqdm @@ -17,6 +16,7 @@ from .mt_bench_common import ( bench_dir, chat_completion_openai, + get_openai_client, load_questions, temperature_config, ) @@ -98,6 +98,7 @@ def get_answer( def generate_answers( model_name, model_api_base, + api_key=None, branch=None, output_dir="eval_output", data_dir=None, @@ -111,7 +112,8 @@ def generate_answers( ): """Generate model answers to be judged""" logger.debug(locals()) - openai_client = openai.OpenAI(base_url=model_api_base, api_key="NO_API_KEY") + + openai_client = get_openai_client(model_api_base, api_key) if data_dir is None: data_dir = os.path.join(os.path.dirname(__file__), "data") diff --git a/src/instructlab/eval/mt_bench_common.py b/src/instructlab/eval/mt_bench_common.py index a9c52c6..248cfac 100644 --- a/src/instructlab/eval/mt_bench_common.py +++ b/src/instructlab/eval/mt_bench_common.py @@ -365,3 +365,10 @@ def check_data(questions, model_answers, ref_answers, models, judges): def get_model_list(answer_file): logger.debug(locals()) return [os.path.splitext(os.path.basename(answer_file))[0]] + + +def get_openai_client(model_api_base, api_key): + if api_key is None: + api_key = "NO_API_KEY" + openai_client = openai.OpenAI(base_url=model_api_base, api_key=api_key) + return openai_client diff --git a/src/instructlab/eval/mt_bench_judgment.py b/src/instructlab/eval/mt_bench_judgment.py index 0d24012..53ba315 100644 --- a/src/instructlab/eval/mt_bench_judgment.py +++ b/src/instructlab/eval/mt_bench_judgment.py @@ -6,7 +6,6 @@ # Third Party from tqdm import tqdm import numpy as np -import openai import pandas as pd # Local @@ -18,6 +17,7 @@ bench_dir, check_data, get_model_list, + get_openai_client, load_judge_prompts, load_model_answers, load_questions, @@ -278,6 +278,7 @@ def generate_judgment( model_name, judge_model_name, model_api_base, + api_key=None, bench_name="mt_bench", output_dir="eval_output", data_dir=None, @@ -288,7 +289,8 @@ def generate_judgment( ): """Generate judgment with scores and qa_pairs for a model""" logger.debug(locals()) - openai_client = openai.OpenAI(base_url=model_api_base, api_key="NO_API_KEY") + + openai_client = get_openai_client(model_api_base, api_key) first_n_env = os.environ.get("INSTRUCTLAB_EVAL_FIRST_N_QUESTIONS") if first_n_env is not None and first_n is None: