From 6182b3260868032830709ed181b2dad57480cc54 Mon Sep 17 00:00:00 2001
From: Andrew White
Date: Mon, 4 Mar 2024 16:44:08 -0800
Subject: [PATCH 1/2] Adjusted anthropic import

---
 paperqa/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/paperqa/__init__.py b/paperqa/__init__.py
index f6a4a268a..44a458ad1 100644
--- a/paperqa/__init__.py
+++ b/paperqa/__init__.py
@@ -7,6 +7,7 @@
     OpenAIEmbeddingModel,
     LangchainLLMModel,
     OpenAILLMModel,
+    AnthropicLLMModel,
     LlamaEmbeddingModel,
     NumpyVectorStore,
     LangchainVectorStore,
@@ -26,6 +27,7 @@
     "EmbeddingModel",
     "OpenAIEmbeddingModel",
     "OpenAILLMModel",
+    "AnthropicLLMModel",
     "LangchainLLMModel",
     "LlamaEmbeddingModel",
     "SentenceTransformerEmbeddingModel",

From 398b0f1ae238a4840addb776615a8b23979a8793 Mon Sep 17 00:00:00 2001
From: Andrew White
Date: Mon, 4 Mar 2024 17:00:00 -0800
Subject: [PATCH 2/2] Fixed sys prompts

---
 paperqa/llms.py    | 37 +++++++++++++++++++++++++++++++------
 paperqa/version.py |  2 +-
 2 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/paperqa/llms.py b/paperqa/llms.py
index db14da8be..ea31091d4 100644
--- a/paperqa/llms.py
+++ b/paperqa/llms.py
@@ -373,18 +373,43 @@ def set_model_name(cls, data: Any) -> Any:
 
     async def achat(self, client: Any, messages: list[dict[str, str]]) -> str:
         aclient = self._check_client(client)
-        completion = await aclient.messages.create(
-            messages=messages, **process_llm_config(self.config, "max_tokens")
+        # filter out system
+        sys_message = next(
+            (m["content"] for m in messages if m["role"] == "system"), None
         )
+        # BECAUSE THEY DO NOT USE NONE TO INDICATE A SENTINEL
+        # LIKE ANY SANE PERSON
+        if sys_message:
+            completion = await aclient.messages.create(
+                system=sys_message,
+                messages=[m for m in messages if m["role"] != "system"],
+                **process_llm_config(self.config, "max_tokens"),
+            )
+        else:
+            completion = await aclient.messages.create(
+                messages=[m for m in messages if m["role"] != "system"],
+                **process_llm_config(self.config, "max_tokens"),
+            )
         return completion.content or ""
 
     async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any:
         aclient = self._check_client(client)
-        completion = await aclient.messages.create(
-            messages=messages,
-            **process_llm_config(self.config, "max_tokens"),
-            stream=True,
+        sys_message = next(
+            (m["content"] for m in messages if m["role"] == "system"), None
         )
+        if sys_message:
+            completion = await aclient.messages.create(
+                stream=True,
+                system=sys_message,
+                messages=[m for m in messages if m["role"] != "system"],
+                **process_llm_config(self.config, "max_tokens"),
+            )
+        else:
+            completion = await aclient.messages.create(
+                stream=True,
+                messages=[m for m in messages if m["role"] != "system"],
+                **process_llm_config(self.config, "max_tokens"),
+            )
         async for event in completion:
             if isinstance(event, ContentBlockDeltaEvent):
                 yield event.delta.text

diff --git a/paperqa/version.py b/paperqa/version.py
index 216eb9671..f5ee359b8 100644
--- a/paperqa/version.py
+++ b/paperqa/version.py
@@ -1 +1 @@
-__version__ = "4.0.0-pre.10"
+__version__ = "4.0.0-pre.11"
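
Note on the branching in PATCH 2/2: the Anthropic messages API takes the system prompt as a top-level system= parameter rather than as a message with role "system", and the SDK (as of this patch) uses its own NOT_GIVEN sentinel rather than accepting system=None, so the value cannot be passed unconditionally. Below is a minimal standalone sketch of the same pattern outside the paperqa wrappers; the model name, prompts, and max_tokens value are illustrative assumptions, not part of the patch.

    import asyncio
    from anthropic import AsyncAnthropic

    async def main() -> None:
        client = AsyncAnthropic()  # reads ANTHROPIC_API_KEY from the environment
        # OpenAI-style message list, with the system prompt inline
        # (prompts here are illustrative)
        messages = [
            {"role": "system", "content": "Answer tersely."},
            {"role": "user", "content": "What does paper-qa do?"},
        ]
        # Pull the system prompt out of the message list, as the patch does
        sys_message = next(
            (m["content"] for m in messages if m["role"] == "system"), None
        )
        # Only pass system= when a system prompt exists; the SDK rejects
        # system=None because it uses its own NOT_GIVEN sentinel
        extra = {"system": sys_message} if sys_message else {}
        completion = await client.messages.create(
            model="claude-3-opus-20240229",  # illustrative model choice
            max_tokens=256,
            messages=[m for m in messages if m["role"] != "system"],
            **extra,
        )
        # completion.content is a list of content blocks, not a plain string
        print(completion.content[0].text)

    asyncio.run(main())

Building a kwargs dict conditionally is equivalent to the if/else duplication in the patch; the patch spells out both calls explicitly, which keeps each call site readable at the cost of repetition.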