Skip to content

Commit

Permalink
Don't retry on connection error
Browse files Browse the repository at this point in the history
Note: TimeoutError is a subclass of the generic connection error, and
we'd like to retry for timeouts.

This patch also rearranges the code a bit, incl. making sure that it
won't sleep at the very last iteration when we know there won't be
another retry attempt.

Closes: #77

Signed-off-by: Ihar Hrachyshka <[email protected]>
  • Loading branch information
booxter committed Jul 23, 2024
1 parent dbf4db2 commit baae1c3
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ scripts/infra/cloud-instance.sh ec2 ssh
# Regardless of how you setup your instance
git clone https://github.com/instructlab/taxonomy.git && pushd taxonomy && git branch rc && popd
git clone --bare https://github.com/instructlab/eval.git && git clone eval.git/ && cd eval && git remote add syncrepo ../eval.git
python -m venv venv
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
pip install -r requirements-dev.txt
Expand Down Expand Up @@ -104,4 +104,4 @@ eval_output/
└── reference_answer
└── instructlab
└── granite-7b-lab.jsonl
```
```
12 changes: 12 additions & 0 deletions src/instructlab/eval/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,15 @@ def __init__(self, tasks_dir) -> None:
super().__init__()
self.tasks_dir = tasks_dir
self.message = f"Invalid Tasks Dir: {tasks_dir}"


class CompletionError(EvalError):
"""
Error raised when the model failed to provide a completion.
Attributes
message error message to be printed on raise
"""

def __init__(self) -> None:
super().__init__()
self.message = "Failed to retrieve completion."
28 changes: 15 additions & 13 deletions src/instructlab/eval/mt_bench_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from fastchat.model.model_adapter import get_conversation_template # type: ignore
import openai

# First Party
from instructlab.eval import exceptions

# Local
from .logger_config import setup_logger

Expand All @@ -25,7 +28,6 @@
# API setting constants
API_MAX_RETRY = 4
API_RETRY_SLEEP = 4
API_ERROR_OUTPUT = "$ERROR$"

# Categories that need reference answers
NEED_REF_CATS = ["math", "reasoning", "coding", "arena-hard-200", "taxonomy"]
Expand Down Expand Up @@ -250,9 +252,10 @@ def play_a_match_single(
def chat_completion_openai(
openai_client, model, conv, temperature, max_tokens, merge_system_user_message=False
) -> str:
output = API_ERROR_OUTPUT
for i in range(API_MAX_RETRY):
try:
if i > 0:
time.sleep(API_RETRY_SLEEP)
messages = conv.to_openai_api_messages()
if (
merge_system_user_message
Expand All @@ -270,17 +273,16 @@ def chat_completion_openai(
temperature=temperature,
max_tokens=max_tokens,
)
output = response.choices[0].message.content
break
except openai.OpenAIError as e:
if i == API_MAX_RETRY - 1:
# Print error on last try
print(type(e), e)
else:
logger.debug(e)
time.sleep(API_RETRY_SLEEP)

return output
return response.choices[0].message.content
except openai.APITimeoutError as ex:
logger.debug(ex)
except openai.APIConnectionError as ex:
logger.debug(ex)
raise exceptions.EvalError("Failed to connect to API.")
except openai.OpenAIError as ex:
logger.debug(ex)

raise exceptions.EvalError("Failed to get completion from API.")


def check_data(questions, model_answers, ref_answers, models, judges):
Expand Down

0 comments on commit baae1c3

Please sign in to comment.