Skip to content

Commit

Permalink
Fixed sys prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead committed Mar 5, 2024
1 parent 6182b32 commit 398b0f1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
37 changes: 31 additions & 6 deletions paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,18 +373,43 @@ def set_model_name(cls, data: Any) -> Any:

async def achat(self, client: Any, messages: list[dict[str, str]]) -> str:
aclient = self._check_client(client)
completion = await aclient.messages.create(
messages=messages, **process_llm_config(self.config, "max_tokens")
# filter out system
sys_message = next(
(m["content"] for m in messages if m["role"] == "system"), None
)
# BECAUISE THEY DO NOT USE NONE TO INDICATE SENTINEL
# LIKE ANY SANE PERSON
if sys_message:
completion = await aclient.messages.create(
system=sys_message,
messages=[m for m in messages if m["role"] != "system"],
**process_llm_config(self.config, "max_tokens"),
)
else:
completion = await aclient.messages.create(
messages=[m for m in messages if m["role"] != "system"],
**process_llm_config(self.config, "max_tokens"),
)
return completion.content or ""

async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any:
aclient = self._check_client(client)
completion = await aclient.messages.create(
messages=messages,
**process_llm_config(self.config, "max_tokens"),
stream=True,
sys_message = next(
(m["content"] for m in messages if m["role"] == "system"), None
)
if sys_message:
completion = await aclient.messages.create(
stream=True,
system=sys_message,
messages=[m for m in messages if m["role"] != "system"],
**process_llm_config(self.config, "max_tokens"),
)
else:
completion = await aclient.messages.create(
stream=True,
messages=[m for m in messages if m["role"] != "system"],
**process_llm_config(self.config, "max_tokens"),
)
async for event in completion:
if isinstance(event, ContentBlockDeltaEvent):
yield event.delta.text
Expand Down
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "4.0.0-pre.10"
__version__ = "4.0.0-pre.11"

0 comments on commit 398b0f1

Please sign in to comment.