diff --git a/pyproject.toml b/pyproject.toml
index b11c7bd..03faef9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,3 +96,6 @@ from-first = true
 # import-heading-firstparty=First Party
 # import-heading-localfolder=Local
 known-local-folder = ["tuning"]
+
+[tool.mypy]
+ignore_missing_imports = true
diff --git a/src/instructlab/eval/mt_bench_common.py b/src/instructlab/eval/mt_bench_common.py
index aa0b901..2b80efb 100644
--- a/src/instructlab/eval/mt_bench_common.py
+++ b/src/instructlab/eval/mt_bench_common.py
@@ -4,7 +4,7 @@
 """
 
 # Standard
-from typing import Optional
+from typing import Optional, TypedDict
 import ast
 import dataclasses
 import glob
@@ -14,6 +14,7 @@ import time
 
 # Third Party
+from fastchat import conversation
 from fastchat.model.model_adapter import get_conversation_template  # type: ignore
 import openai
 
@@ -254,23 +255,44 @@ def _is_fatal_openai_error(e: openai.OpenAIError) -> bool:
     return isinstance(e, openai.APIConnectionError)
 
 
+# TODO: copied from instructlab (cli) utils module; consolidate somewhere?
+class Message(TypedDict):
+    """
+    Represents a message within an AI conversation.
+    """
+
+    content: str
+    # one of: "user", "assistant", or "system"
+    role: str
+
+
+def _get_messages(
+    conv: conversation.Conversation, merge_system_user_message: bool
+) -> list[Message]:
+    messages = conv.to_openai_api_messages()
+    if (
+        merge_system_user_message
+        and messages[0]["role"] == "system"
+        and messages[1]["role"] == "user"
+    ):
+        messages[1]["content"] = messages[0]["content"] + "\n" + messages[1]["content"]
+        return messages[1:]
+    return messages
+
+
 def chat_completion_openai(
-    openai_client, model, conv, temperature, max_tokens, merge_system_user_message=False
+    openai_client,
+    model,
+    conv: conversation.Conversation,
+    temperature,
+    max_tokens,
+    merge_system_user_message: bool = False,
 ) -> str:
     output = None
+    messages = _get_messages(conv, merge_system_user_message)
     for i in range(API_MAX_RETRY):
         try:
-            messages = conv.to_openai_api_messages()
-            if (
-                merge_system_user_message
-                and messages[0]["role"] == "system"
-                and messages[1]["role"] == "user"
-            ):
-                messages[1]["content"] = (
-                    messages[0]["content"] + "\n" + messages[1]["content"]
-                )
-                messages = messages[1:]
             response = openai_client.chat.completions.create(
                 model=model,
                 messages=messages,