diff --git a/src/instructlab/eval/exceptions.py b/src/instructlab/eval/exceptions.py
index 65a82f3..c6c71fe 100644
--- a/src/instructlab/eval/exceptions.py
+++ b/src/instructlab/eval/exceptions.py
@@ -90,3 +90,15 @@ def __init__(self, tasks_dir) -> None:
         super().__init__()
         self.tasks_dir = tasks_dir
         self.message = f"Invalid Tasks Dir: {tasks_dir}"
+
+
+class OpenAIError(EvalError):
+    """
+    Error raised when reply retrieval from OpenAI API fails.
+    Attributes
+        message     error message to be printed on raise
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.message = "Failed to receive a reply from API."
diff --git a/src/instructlab/eval/mt_bench_common.py b/src/instructlab/eval/mt_bench_common.py
index d346999..149c9d4 100644
--- a/src/instructlab/eval/mt_bench_common.py
+++ b/src/instructlab/eval/mt_bench_common.py
@@ -17,6 +17,9 @@
 from fastchat.model.model_adapter import get_conversation_template  # type: ignore
 import openai

+# First Party
+from instructlab.eval import exceptions
+
 # Local
 from .logger_config import setup_logger

@@ -247,10 +250,27 @@ def play_a_match_single(
     return result


+def _is_fatal_openai_error(e: openai.OpenAIError) -> bool:
+    """Return True if the error is an openai.APIConnectionError."""
+    return isinstance(e, openai.APIConnectionError)
+
+
 def chat_completion_openai(
     openai_client, model, conv, temperature, max_tokens, merge_system_user_message=False
 ) -> str:
-    output = API_ERROR_OUTPUT
+    """Request a chat completion, retrying up to API_MAX_RETRY times.
+
+    Returns the reply text on success. Returns API_ERROR_OUTPUT (soft fail)
+    when a request is rejected as invalid, or when retries are exhausted and
+    at least one attempt failed with something other than a connection error.
+
+    Raises:
+        exceptions.OpenAIError: no attempt succeeded and every failure was a
+            connection error (hard fail).
+    """
+    output = None
+    last_error = None  # most recent API error, chained as cause on hard fail
+
     for i in range(API_MAX_RETRY):
         try:
             messages = conv.to_openai_api_messages()
@@ -272,12 +292,37 @@ def chat_completion_openai(
             )
             output = response.choices[0].message.content
             break
-        except openai.OpenAIError as e:
+        except (
+            # retry may help with these errors
+            openai.APIConnectionError,
+            openai.RateLimitError,  # 429
+            openai.InternalServerError,  # >=500
+            # NOTE: Errors listed below may need a revisit: we are not sure if
+            # it's ever helpful to retry them. Leaving them intact for now.
+            openai.AuthenticationError,  # 401
+            openai.PermissionDeniedError,  # 403
+            openai.NotFoundError,  # 404
+        ) as e:
+            last_error = e
+            if not _is_fatal_openai_error(e):
+                output = API_ERROR_OUTPUT  # disable hard fail (never raise!)
+            # still, retry in the hope we'll get a successful reply
             if i == API_MAX_RETRY - 1:
                 # Print error on last try
                 print(type(e), e)
             else:
                 logger.debug(e)
             time.sleep(API_RETRY_SLEEP)
-
+        except (
+            # retry won't fix these errors
+            openai.BadRequestError,  # 400
+            openai.UnprocessableEntityError,  # 422
+        ) as e:
+            logger.debug(e)
+            return API_ERROR_OUTPUT  # immediately soft fail
+
+    if output is None:
+        # not a single attempt was non-fatal; this is indicative of
+        # basic connectivity or server issue -> hard fail
+        raise exceptions.OpenAIError from last_error
     return output