From a21e3bb56ee1d8e2e8e115b979ca529f713c3dd0 Mon Sep 17 00:00:00 2001
From: Ihar Hrachyshka
Date: Wed, 31 Jul 2024 21:04:05 +0000
Subject: [PATCH] Calculate messages for openai completion once

Before the patch, we were calculating them on every retry attempt. The
function is pure, so there is no good reason to repeat the calculation.

This also simplifies the function a bit.

Signed-off-by: Ihar Hrachyshka
---
 pyproject.toml                          |  3 ++
 src/instructlab/eval/mt_bench_common.py | 46 ++++++++++++++++++-------
 2 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index b11c7bd..03faef9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,3 +96,6 @@ from-first = true
 # import-heading-firstparty=First Party
 # import-heading-localfolder=Local
 known-local-folder = ["tuning"]
+
+[tool.mypy]
+ignore_missing_imports = true
diff --git a/src/instructlab/eval/mt_bench_common.py b/src/instructlab/eval/mt_bench_common.py
index aa0b901..2b80efb 100644
--- a/src/instructlab/eval/mt_bench_common.py
+++ b/src/instructlab/eval/mt_bench_common.py
@@ -4,7 +4,7 @@
 """
 
 # Standard
-from typing import Optional
+from typing import Optional, TypedDict
 import ast
 import dataclasses
 import glob
@@ -14,6 +14,7 @@
 import time
 
 # Third Party
+from fastchat import conversation
 from fastchat.model.model_adapter import get_conversation_template  # type: ignore
 import openai
 
@@ -254,23 +255,44 @@ def _is_fatal_openai_error(e: openai.OpenAIError) -> bool:
     return isinstance(e, openai.APIConnectionError)
 
 
+# TODO: copied from instructlab (cli) utils module; consolidate somewhere?
+class Message(TypedDict):
+    """
+    Represents a message within an AI conversation.
+    """
+
+    content: str
+    # one of: "user", "assistant", or "system"
+    role: str
+
+
+def _get_messages(
+    conv: conversation.Conversation, merge_system_user_message: bool
+) -> list[Message]:
+    messages = conv.to_openai_api_messages()
+    if (
+        merge_system_user_message
+        and messages[0]["role"] == "system"
+        and messages[1]["role"] == "user"
+    ):
+        messages[1]["content"] = messages[0]["content"] + "\n" + messages[1]["content"]
+        return messages[1:]
+    return messages
+
+
 def chat_completion_openai(
-    openai_client, model, conv, temperature, max_tokens, merge_system_user_message=False
+    openai_client,
+    model,
+    conv: conversation.Conversation,
+    temperature,
+    max_tokens,
+    merge_system_user_message: bool = False,
 ) -> str:
     output = None
+    messages = _get_messages(conv, merge_system_user_message)
     for i in range(API_MAX_RETRY):
         try:
-            messages = conv.to_openai_api_messages()
-            if (
-                merge_system_user_message
-                and messages[0]["role"] == "system"
-                and messages[1]["role"] == "user"
-            ):
-                messages[1]["content"] = (
-                    messages[0]["content"] + "\n" + messages[1]["content"]
-                )
-                messages = messages[1:]
             response = openai_client.chat.completions.create(
                 model=model,
                 messages=messages,
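
Reviewer note (not part of the patch): a minimal sketch of how the extracted helper behaves, assuming the patch is applied. It imports the private _get_messages helper directly and uses a hypothetical FakeConv stub in place of fastchat's Conversation, since only to_openai_api_messages() is needed at runtime.

# Illustrative only, not part of the patch: exercises the hoisted
# _get_messages helper through a duck-typed stand-in for fastchat's
# Conversation object.
from instructlab.eval.mt_bench_common import _get_messages


class FakeConv:
    """Hypothetical stub returning a fixed system + user exchange."""

    def to_openai_api_messages(self):
        return [
            {"role": "system", "content": "You are a judge."},
            {"role": "user", "content": "Rate this answer."},
        ]


conv = FakeConv()

# Without merging, the message list is passed through unchanged.
assert _get_messages(conv, merge_system_user_message=False) == [
    {"role": "system", "content": "You are a judge."},
    {"role": "user", "content": "Rate this answer."},
]

# With merging, the system prompt is folded into the first user turn and
# the standalone system message is dropped.
assert _get_messages(conv, merge_system_user_message=True) == [
    {"role": "user", "content": "You are a judge.\nRate this answer."},
]

# Because the helper is pure, chat_completion_openai can safely compute
# the messages once, before entering its retry loop.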