Skip to content

Commit

Permalink
Calculate messages for openai completion once
Browse files Browse the repository at this point in the history
Before the patch, we were calculating them on every retry attempt. The
function is pure, so there is no good reason to repeat the calculation.

This also simplifies the function a bit.

Signed-off-by: Ihar Hrachyshka <[email protected]>
  • Loading branch information
booxter authored and danmcp committed Aug 16, 2024
1 parent ed2807b commit a21e3bb
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 12 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,6 @@ from-first = true
# import-heading-firstparty=First Party
# import-heading-localfolder=Local
known-local-folder = ["tuning"]

[tool.mypy]
ignore_missing_imports = true
46 changes: 34 additions & 12 deletions src/instructlab/eval/mt_bench_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

# Standard
from typing import Optional
from typing import Optional, TypedDict
import ast
import dataclasses
import glob
Expand All @@ -14,6 +14,7 @@
import time

# Third Party
from fastchat import conversation
from fastchat.model.model_adapter import get_conversation_template # type: ignore
import openai

Expand Down Expand Up @@ -254,23 +255,44 @@ def _is_fatal_openai_error(e: openai.OpenAIError) -> bool:
return isinstance(e, openai.APIConnectionError)


# TODO: copied from instructlab (cli) utils module; consolidate somewhere?
class Message(TypedDict):
"""
Represents a message within an AI conversation.
"""

content: str
# one of: "user", "assistant", or "system"
role: str


def _get_messages(
conv: conversation.Conversation, merge_system_user_message: bool
) -> list[Message]:
messages = conv.to_openai_api_messages()
if (
merge_system_user_message
and messages[0]["role"] == "system"
and messages[1]["role"] == "user"
):
messages[1]["content"] = messages[0]["content"] + "\n" + messages[1]["content"]
return messages[1:]
return messages


def chat_completion_openai(
openai_client, model, conv, temperature, max_tokens, merge_system_user_message=False
openai_client,
model,
conv: conversation.Conversation,
temperature,
max_tokens,
merge_system_user_message: bool = False,
) -> str:
output = None
messages = _get_messages(conv, merge_system_user_message)

for i in range(API_MAX_RETRY):
try:
messages = conv.to_openai_api_messages()
if (
merge_system_user_message
and messages[0]["role"] == "system"
and messages[1]["role"] == "user"
):
messages[1]["content"] = (
messages[0]["content"] + "\n" + messages[1]["content"]
)
messages = messages[1:]
response = openai_client.chat.completions.create(
model=model,
messages=messages,
Expand Down

0 comments on commit a21e3bb

Please sign in to comment.