Skip to content

Commit

Permalink
Merge pull request #103 from danmcp/dont-retry-on-connection-failure
Browse files Browse the repository at this point in the history
Don't retry on connection failure
  • Loading branch information
nathan-weinberg authored Aug 21, 2024
2 parents 515263a + 7fbd87e commit ba6fe0e
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ scripts/infra/cloud-instance.sh ec2 ssh
# Regardless of how you setup your instance
git clone https://github.com/instructlab/taxonomy.git && pushd taxonomy && git branch rc && popd
git clone --bare https://github.com/instructlab/eval.git && git clone eval.git/ && cd eval && git remote add syncrepo ../eval.git
python -m venv venv
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
pip install -r requirements-dev.txt
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,6 @@ from-first = true
# import-heading-firstparty=First Party
# import-heading-localfolder=Local
known-local-folder = ["tuning"]

[tool.mypy]
ignore_missing_imports = true
12 changes: 12 additions & 0 deletions src/instructlab/eval/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,15 @@ def __init__(self, tasks_dir) -> None:
super().__init__()
self.tasks_dir = tasks_dir
self.message = f"Invalid Tasks Dir: {tasks_dir}"


class ModelServingAPIError(EvalError):
    """Raised when the model serving API never returns a usable reply.

    Attributes
        message error message to be printed on raise
    """

    def __init__(self) -> None:
        super().__init__()
        # Fixed message: this error carries no per-instance context.
        self.message = "Failed to receive a reply from model serving API."
96 changes: 81 additions & 15 deletions src/instructlab/eval/mt_bench_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

# Standard
from typing import Optional
from typing import Optional, TypedDict
import ast
import dataclasses
import glob
Expand All @@ -14,9 +14,13 @@
import time

# Third Party
from fastchat import conversation
from fastchat.model.model_adapter import get_conversation_template # type: ignore
import openai

# First Party
from instructlab.eval import exceptions

# Local
from .logger_config import setup_logger

Expand Down Expand Up @@ -247,37 +251,99 @@ def play_a_match_single(
return result


def _is_fatal_openai_error(e: openai.OpenAIError) -> bool:
    """Return True when *e* is an error that retrying cannot fix.

    Connection, authentication (401), permission (403), and not-found (404)
    failures indicate a basic connectivity or configuration problem rather
    than a transient server hiccup.
    """
    fatal_error_types = (
        openai.APIConnectionError,
        openai.AuthenticationError,
        openai.PermissionDeniedError,
        openai.NotFoundError,
    )
    return isinstance(e, fatal_error_types)


# TODO: copied from instructlab (cli) utils module; consolidate somewhere?
class Message(TypedDict):
    """One message exchanged in an AI conversation."""

    # Textual body of the message.
    content: str
    # Who produced the message: "user", "assistant", or "system".
    role: str


def _get_messages(
    conv: conversation.Conversation, merge_system_user_message: bool
) -> list[Message]:
    """Convert *conv* to an OpenAI-format message list.

    When merging is requested — or the conversation uses the "mistral"
    template — a leading system message is folded into the first user
    message and dropped from the list.
    """
    messages = conv.to_openai_api_messages()
    should_merge = merge_system_user_message or conv.name == "mistral"
    if should_merge and messages[0]["role"] == "system" and messages[1]["role"] == "user":
        combined = messages[0]["content"] + "\n" + messages[1]["content"]
        messages[1]["content"] = combined
        return messages[1:]
    return messages


def chat_completion_openai(
    openai_client,
    model,
    conv: conversation.Conversation,
    temperature,
    max_tokens,
    merge_system_user_message: bool = False,
) -> str:
    """Request a chat completion for *conv*, retrying transient failures.

    Args:
        openai_client: client used to issue the completions request
        model: model name passed to the completions API
        conv: conversation whose messages are sent to the model
        temperature: sampling temperature for the completion
        max_tokens: completion token limit
        merge_system_user_message: fold the system prompt into the first
            user message (always done for "mistral" conversations)

    Returns:
        The model's reply text, or API_ERROR_OUTPUT when the request failed
        in a way a retry cannot fix but a hard failure is not warranted
        (soft failure).

    Raises:
        exceptions.ModelServingAPIError: when every attempt failed fatally
            (connection/auth/permission/not-found), indicating a basic
            connectivity or server problem.
    """
    output = None
    messages = _get_messages(conv, merge_system_user_message)

    for i in range(API_MAX_RETRY):
        try:
            response = openai_client.chat.completions.create(
                model=model,
                messages=messages,
                n=1,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            output = response.choices[0].message.content
            break
        except (
            # retry won't fix these errors
            openai.BadRequestError,  # 400
            openai.UnprocessableEntityError,  # 422
        ) as e:
            logger.debug(e)
            return API_ERROR_OUTPUT  # immediately soft fail
        except (
            # retry may help with these errors
            openai.APIConnectionError,
            openai.RateLimitError,  # 429
            openai.InternalServerError,  # >=500
            # NOTE: Errors listed below may need a revisit: we are not sure if
            # it's ever helpful to retry them. Leaving them intact for now.
            openai.AuthenticationError,  # 401
            openai.PermissionDeniedError,  # 403
            openai.NotFoundError,  # 404
            # General catch-all
            openai.OpenAIError,
        ) as e:
            if not _is_fatal_openai_error(e):
                output = API_ERROR_OUTPUT  # disable hard fail (never raise!)
            # still, retry in the hope we'll get a successful reply
            if i == API_MAX_RETRY - 1:
                logger.error(e)
                break
            logger.debug(e)
            time.sleep(API_RETRY_SLEEP)

    if output is None:
        # not a single attempt was non-fatal; this is indicative of
        # basic connectivity or server issue -> hard fail
        raise exceptions.ModelServingAPIError
    return output


def check_data(questions, model_answers, ref_answers, models, judges):
Expand Down

0 comments on commit ba6fe0e

Please sign in to comment.