diff --git a/paperqa/docs.py b/paperqa/docs.py index 1128ab0d3..39115db13 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -31,6 +31,7 @@ VectorStore, VoyageAIEmbeddingModel, get_score, + is_anyscale_model, llm_model_factory, vector_store_factory, ) @@ -205,12 +206,10 @@ def set_client( embedding_client: Any | None = None, ): if client is None and isinstance(self.llm_model, OpenAILLMModel): - if (api_key := os.environ.get("ANYSCALE_API_KEY")) and ( - base_url := os.environ.get("ANYSCALE_BASE_URL") - ): + if is_anyscale_model(self.llm_model.name): client = AsyncOpenAI( - api_key=api_key, - base_url=base_url, + api_key=os.environ["ANYSCALE_API_KEY"], + base_url=os.environ["ANYSCALE_BASE_URL"], ) else: client = AsyncOpenAI() diff --git a/paperqa/llms.py b/paperqa/llms.py index b81ade5b1..76773ae1b 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -49,6 +49,12 @@ # return model_name in model_arr or model_name in complete_model_arr +ANYSCALE_MODEL_PREFIXES: tuple[str, ...] = ( + "meta-llama/Meta-Llama-3-", + "mistralai/Mistral-", + "mistralai/Mixtral-", +) + def guess_model_type(model_name: str) -> str: # noqa: PLR0911 if model_name.startswith("babbage"): @@ -58,13 +64,7 @@ def guess_model_type(model_name: str) -> str: # noqa: PLR0911 if ( os.environ.get("ANYSCALE_API_KEY") and os.environ.get("ANYSCALE_BASE_URL") - and model_name.startswith("meta-llama/Meta-Llama-3-") - ): - return "chat" - if ( - os.environ.get("ANYSCALE_API_KEY") - and os.environ.get("ANYSCALE_BASE_URL") - and (model_name.startswith(("mistralai/Mistral-", "mistralai/Mixtral-"))) + and (model_name.startswith(ANYSCALE_MODEL_PREFIXES)) ): return "chat" if "instruct" in model_name: @@ -78,17 +78,22 @@ def guess_model_type(model_name: str) -> str: # noqa: PLR0911 return "completion" -def is_openai_model(model_name) -> bool: - open_ai_model_prefixes = {"gpt-", "babbage", "davinci", "ft:gpt-"} - # add special prefixes if the user has anyscale models +def is_anyscale_model(model_name: str) -> bool: + # compares prefixes with anyscale models # https://docs.anyscale.com/endpoints/text-generation/query-a-model/ - if os.environ.get("ANYSCALE_API_KEY") and os.environ.get("ANYSCALE_BASE_URL"): - open_ai_model_prefixes |= { - "meta-llama/Meta-Llama-3-", - "mistralai/Mistral-", - "mistralai/Mixtral-", - } - return model_name.startswith(tuple(open_ai_model_prefixes)) + if ( + os.environ.get("ANYSCALE_API_KEY") + and os.environ.get("ANYSCALE_BASE_URL") + and model_name.startswith(ANYSCALE_MODEL_PREFIXES) + ): + return True + return False + + +def is_openai_model(model_name: str) -> bool: + return is_anyscale_model(model_name) or model_name.startswith( + ("gpt-", "babbage", "davinci", "ft:gpt-") + ) def process_llm_config(