From d11a178189411f03e32e9bc636647eaa306f6df9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 29 Oct 2024 17:56:26 +0100 Subject: [PATCH] feat: use custom http_client MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit adds the ability to pass a custom HTTP client to the MT-Bench evaluator. This is handy when using custom certificates when interacting with the judge model serving endpoint. Signed-off-by: Sébastien Han --- .spellcheck-en-custom.txt | 1 + CHANGELOG.md | 4 ++++ requirements.txt | 1 + src/instructlab/eval/mt_bench.py | 14 ++++++++++++++ src/instructlab/eval/mt_bench_answers.py | 3 ++- src/instructlab/eval/mt_bench_common.py | 12 ++++++++++-- src/instructlab/eval/mt_bench_judgment.py | 3 ++- ..._branch_gen_answers_with_custom_http_client.py | 15 +++++++++++++++ 8 files changed, 49 insertions(+), 4 deletions(-) create mode 100755 tests/test_branch_gen_answers_with_custom_http_client.py diff --git a/.spellcheck-en-custom.txt b/.spellcheck-en-custom.txt index 769b05b..33a582f 100644 --- a/.spellcheck-en-custom.txt +++ b/.spellcheck-en-custom.txt @@ -10,6 +10,7 @@ dr eval gpt hoc +http instructlab jsonl justfile diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ba63cc..a897297 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1 +1,5 @@ +## v0.4 + +* Added ability to specify a custom http client to MT-Bench + ## v0.2 diff --git a/requirements.txt b/requirements.txt index 9be7cbd..a3e6e7d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ accelerate pandas pandas-stubs lm-eval>=0.4.4 +httpx diff --git a/src/instructlab/eval/mt_bench.py b/src/instructlab/eval/mt_bench.py index 2d9a12a..cf6fd58 100644 --- a/src/instructlab/eval/mt_bench.py +++ b/src/instructlab/eval/mt_bench.py @@ -10,6 +10,8 @@ import multiprocessing import os +import httpx + # First Party from instructlab.eval import ( mt_bench_answers, @@ -110,6 +112,7 @@ def gen_answers( api_key: str | None = None, 
max_workers: int | str | None = None, serving_gpus: int | None = None, + http_client: httpx.Client | None = None, ) -> None: """ Asks questions to model @@ -119,6 +122,7 @@ def gen_answers( api_key API token for authenticating with model server max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor. serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor. + http_client Custom http client to use for requests """ logger.debug(locals()) mt_bench_answers.generate_answers( @@ -127,6 +131,7 @@ def gen_answers( api_key=api_key, output_dir=self.output_dir, max_workers=self._get_effective_max_workers(max_workers, serving_gpus), + http_client=http_client, ) def judge_answers( @@ -135,6 +140,7 @@ def judge_answers( api_key: str | None = None, max_workers: int | str | None = None, serving_gpus: int | None = None, + http_client: httpx.Client | None = None, ) -> tuple: """ Runs MT-Bench judgment @@ -144,6 +150,7 @@ def judge_answers( api_key API token for authenticating with model server max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor. serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor. 
+ http_client Custom http client to use for requests Returns: overall_score MT-Bench score for the overall model evaluation @@ -160,6 +167,7 @@ def judge_answers( max_workers=self._get_effective_max_workers(max_workers, serving_gpus), output_dir=self.output_dir, merge_system_user_message=self.merge_system_user_message, + http_client=http_client, ) @@ -202,6 +210,7 @@ def gen_answers( api_key: str | None = None, max_workers: int | str | None = None, serving_gpus: int | None = None, + http_client: httpx.Client | None = None, ) -> None: """ Asks questions to model @@ -211,6 +220,7 @@ def gen_answers( api_key API token for authenticating with model server max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor. serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor. + http_client Custom http client to use for requests """ logger.debug(locals()) mt_bench_branch_generator.generate( @@ -228,6 +238,7 @@ def gen_answers( data_dir=self.output_dir, max_workers=self._get_effective_max_workers(max_workers, serving_gpus), bench_name="mt_bench_branch", + http_client=http_client, ) def judge_answers( @@ -236,6 +247,7 @@ def judge_answers( api_key: str | None = None, max_workers: int | str | None = None, serving_gpus: int | None = None, + http_client: httpx.Client | None = None, ) -> tuple: """ Runs MT-Bench-Branch judgment. Judgments can be compared across runs with consistent question_id -> qna file name. @@ -245,6 +257,7 @@ def judge_answers( api_key API token for authenticating with model server max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor. serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor. 
+ http_client Custom http client to use for requests Returns: overall_score Overall score from the evaluation @@ -263,5 +276,6 @@ def judge_answers( data_dir=self.output_dir, bench_name="mt_bench_branch", merge_system_user_message=self.merge_system_user_message, + http_client=http_client, ) return overall_score, qa_pairs, error_rate diff --git a/src/instructlab/eval/mt_bench_answers.py b/src/instructlab/eval/mt_bench_answers.py index ac6b98b..f4337b4 100644 --- a/src/instructlab/eval/mt_bench_answers.py +++ b/src/instructlab/eval/mt_bench_answers.py @@ -108,11 +108,12 @@ def generate_answers( max_tokens=1024, max_workers=1, bench_name="mt_bench", + http_client=None, ): """Generate model answers to be judged""" logger.debug(locals()) - openai_client = get_openai_client(model_api_base, api_key) + openai_client = get_openai_client(model_api_base, api_key, http_client) if data_dir is None: data_dir = os.path.join(os.path.dirname(__file__), "data") diff --git a/src/instructlab/eval/mt_bench_common.py b/src/instructlab/eval/mt_bench_common.py index f45bf5f..bfd31ef 100644 --- a/src/instructlab/eval/mt_bench_common.py +++ b/src/instructlab/eval/mt_bench_common.py @@ -12,6 +12,8 @@ import re import time +import httpx + # Third Party import openai @@ -365,8 +367,14 @@ def get_model_list(answer_file): return [os.path.splitext(os.path.basename(answer_file))[0]] -def get_openai_client(model_api_base, api_key): +def get_openai_client( + model_api_base, + api_key, + http_client: httpx.Client | None = None, +): if api_key is None: api_key = "NO_API_KEY" - openai_client = openai.OpenAI(base_url=model_api_base, api_key=api_key) + openai_client = openai.OpenAI( + base_url=model_api_base, api_key=api_key, http_client=http_client + ) return openai_client diff --git a/src/instructlab/eval/mt_bench_judgment.py b/src/instructlab/eval/mt_bench_judgment.py index 53ba315..f853a09 100644 --- a/src/instructlab/eval/mt_bench_judgment.py +++ b/src/instructlab/eval/mt_bench_judgment.py @@ 
-286,11 +286,12 @@ def generate_judgment( max_workers=1, first_n=None, merge_system_user_message=False, + http_client=None, ): """Generate judgment with scores and qa_pairs for a model""" logger.debug(locals()) - openai_client = get_openai_client(model_api_base, api_key) + openai_client = get_openai_client(model_api_base, api_key, http_client) first_n_env = os.environ.get("INSTRUCTLAB_EVAL_FIRST_N_QUESTIONS") if first_n_env is not None and first_n is None: diff --git a/tests/test_branch_gen_answers_with_custom_http_client.py b/tests/test_branch_gen_answers_with_custom_http_client.py new file mode 100755 index 0000000..12c2dd7 --- /dev/null +++ b/tests/test_branch_gen_answers_with_custom_http_client.py @@ -0,0 +1,15 @@ +# Third Party +import httpx + +from instructlab.eval.mt_bench import MTBenchBranchEvaluator + +mt_bench_branch = MTBenchBranchEvaluator( + "instructlab/granite-7b-lab", + "instructlab/granite-7b-lab", + "../taxonomy", + "main", +) +mt_bench_branch.gen_answers( + "http://localhost:8000/v1", + http_client=httpx.Client(verify=False), +)