Merge pull request #128 from sallyom/add-api-key
add option to pass 'api_key' to gen_answers, judge_answers
mergify[bot] authored Sep 13, 2024
2 parents 6b3495b + 3445ce0 commit 83f9d95
Showing 4 changed files with 27 additions and 8 deletions.
16 changes: 12 additions & 4 deletions src/instructlab/eval/mt_bench.py
@@ -94,27 +94,30 @@ class MTBenchEvaluator(AbstractMTBenchEvaluator):
 
     name = "mt_bench"
 
-    def gen_answers(self, server_url) -> None:
+    def gen_answers(self, server_url, api_key: str | None = None) -> None:
         """
         Asks questions to model
 
         Attributes
             server_url  Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
+            api_key     API token for authenticating with model server
         """
         logger.debug(locals())
         mt_bench_answers.generate_answers(
             self.model_name,
             server_url,
+            api_key=api_key,
             output_dir=self.output_dir,
             max_workers=self.max_workers,
         )
 
-    def judge_answers(self, server_url) -> tuple:
+    def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
         """
         Runs MT-Bench judgment
 
         Attributes
             server_url  Model server endpoint (Ex: http://localhost:8000/v1) for the judge model
+            api_key     API token for authenticating with model server
 
         Returns:
             overall_score  MT-Bench score for the overall model evaluation
@@ -126,6 +129,7 @@ def judge_answers(self, server_url) -> tuple:
             self.model_name,
             self.judge_model_name,
             server_url,
+            api_key=api_key,
             max_workers=self.max_workers,
             output_dir=self.output_dir,
             merge_system_user_message=self.merge_system_user_message,
@@ -171,12 +175,13 @@ def __init__(
         self.taxonomy_git_repo_path = taxonomy_git_repo_path
         self.branch = branch
 
-    def gen_answers(self, server_url) -> None:
+    def gen_answers(self, server_url, api_key: str | None = None) -> None:
         """
         Asks questions to model
 
         Attributes
             server_url  Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
+            api_key     API token for authenticating with model server
         """
         logger.debug(locals())
         mt_bench_branch_generator.generate(
@@ -188,19 +193,21 @@ def gen_answers(self, server_url) -> None:
         mt_bench_answers.generate_answers(
             self.model_name,
             server_url,
+            api_key=api_key,
             branch=self.branch,
             output_dir=self.output_dir,
             data_dir=self.output_dir,
             max_workers=self.max_workers,
             bench_name="mt_bench_branch",
         )
 
-    def judge_answers(self, server_url) -> tuple:
+    def judge_answers(self, server_url, api_key: str | None = None) -> tuple:
         """
         Runs MT-Bench-Branch judgment. Judgments can be compared across runs with consistent question_id -> qna file name.
 
         Attributes
             server_url  Model server endpoint (Ex: http://localhost:8000/v1) for the judge model
+            api_key     API token for authenticating with model server
 
         Returns:
             qa_pairs    Question and answer pairs (with scores) from the evaluation
@@ -210,6 +217,7 @@
             self.model_name,
             self.judge_model_name,
             server_url,
+            api_key=api_key,
             branch=self.branch,
             max_workers=self.max_workers,
             output_dir=self.output_dir,
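Note: taken together, the mt_bench.py changes let callers forward a token through the evaluator API. A minimal usage sketch under assumptions — the constructor arguments and environment variable name below are illustrative, not taken from this diff; only the gen_answers/judge_answers signatures come from the change above.

    # Hypothetical usage sketch -- constructor args and env var name are assumed.
    import os

    from instructlab.eval.mt_bench import MTBenchEvaluator

    evaluator = MTBenchEvaluator(
        model_name="granite-7b-lab",         # assumed constructor argument
        judge_model_name="prometheus-8x7b",  # assumed constructor argument
    )

    # api_key defaults to None, so existing callers are unaffected.
    evaluator.gen_answers(
        "http://localhost:8000/v1",
        api_key=os.environ.get("MODEL_SERVER_API_KEY"),  # assumed env var
    )
    results = evaluator.judge_answers(
        "http://localhost:8000/v1",
        api_key=os.environ.get("MODEL_SERVER_API_KEY"),  # assumed env var
    )
    # Per the docstring above, the returned tuple starts with overall_score.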
6 changes: 4 additions & 2 deletions src/instructlab/eval/mt_bench_answers.py
@@ -8,7 +8,6 @@
 # Third Party
 # TODO need to look into this dependency
 from fastchat.model.model_adapter import get_conversation_template  # type: ignore
-import openai
 import shortuuid
 import tqdm
 
@@ -17,6 +16,7 @@
 from .mt_bench_common import (
     bench_dir,
     chat_completion_openai,
+    get_openai_client,
     load_questions,
     temperature_config,
 )
@@ -98,6 +98,7 @@ def get_answer(
 def generate_answers(
     model_name,
     model_api_base,
+    api_key=None,
     branch=None,
     output_dir="eval_output",
     data_dir=None,
@@ -111,7 +112,8 @@
 ):
     """Generate model answers to be judged"""
     logger.debug(locals())
-    openai_client = openai.OpenAI(base_url=model_api_base, api_key="NO_API_KEY")
+
+    openai_client = get_openai_client(model_api_base, api_key)
 
     if data_dir is None:
         data_dir = os.path.join(os.path.dirname(__file__), "data")
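Note: a sketch of calling the lower-level generate_answers directly with the new keyword; the model name and URL are placeholders, and the remaining parameters keep the defaults shown in the signature above.

    from instructlab.eval.mt_bench_answers import generate_answers

    # api_key defaults to None, preserving the old unauthenticated behavior;
    # when set, the token flows through to get_openai_client (next file).
    generate_answers(
        "granite-7b-lab",            # model_name (placeholder)
        "http://localhost:8000/v1",  # model_api_base (placeholder)
        api_key="sk-example-token",  # new in this commit
    )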
7 changes: 7 additions & 0 deletions src/instructlab/eval/mt_bench_common.py
@@ -365,3 +365,10 @@ def check_data(questions, model_answers, ref_answers, models, judges):
 def get_model_list(answer_file):
     logger.debug(locals())
     return [os.path.splitext(os.path.basename(answer_file))[0]]
+
+
+def get_openai_client(model_api_base, api_key):
+    if api_key is None:
+        api_key = "NO_API_KEY"
+    openai_client = openai.OpenAI(base_url=model_api_base, api_key=api_key)
+    return openai_client
6 changes: 4 additions & 2 deletions src/instructlab/eval/mt_bench_judgment.py
@@ -6,7 +6,6 @@
 # Third Party
 from tqdm import tqdm
 import numpy as np
-import openai
 import pandas as pd
 
 # Local
@@ -18,6 +17,7 @@
     bench_dir,
     check_data,
     get_model_list,
+    get_openai_client,
     load_judge_prompts,
     load_model_answers,
     load_questions,
@@ -278,6 +278,7 @@ def generate_judgment(
     model_name,
     judge_model_name,
     model_api_base,
+    api_key=None,
     bench_name="mt_bench",
     output_dir="eval_output",
     data_dir=None,
@@ -288,7 +289,8 @@
 ):
     """Generate judgment with scores and qa_pairs for a model"""
     logger.debug(locals())
-    openai_client = openai.OpenAI(base_url=model_api_base, api_key="NO_API_KEY")
+
+    openai_client = get_openai_client(model_api_base, api_key)
 
     first_n_env = os.environ.get("INSTRUCTLAB_EVAL_FIRST_N_QUESTIONS")
     if first_n_env is not None and first_n is None:
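Note: the judgment side mirrors the answer-generation change. A sketch with placeholder values, assuming the elided parameters after model_api_base keep their defaults:

    from instructlab.eval.mt_bench_judgment import generate_judgment

    # api_key is optional here too; leaving it as None reproduces the
    # old unauthenticated behavior.
    generate_judgment(
        "granite-7b-lab",            # model_name (placeholder)
        "prometheus-8x7b",           # judge_model_name (placeholder)
        "http://localhost:8000/v1",  # model_api_base (placeholder)
        api_key="sk-example-token",  # new in this commit
    )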
