Refactored LLMs to allow swapping
whitead committed Jan 5, 2024
1 parent 2a32876 commit 37b82f9
Showing 3 changed files with 283 additions and 152 deletions.
110 changes: 71 additions & 39 deletions paperqa/docs.py
@@ -9,9 +9,16 @@
from typing import Any, BinaryIO, cast

from openai import AsyncOpenAI
from pydantic import BaseModel, Field, field_validator, model_validator

from .llms import embed_documents, get_score, guess_model_type, make_chain
from pydantic import BaseModel, Field, model_validator

from .llms import (
EmbeddingModel,
LLMModel,
OpenAIEmbeddingModel,
OpenAILLMModel,
get_score,
is_openai_model,
)
from .paths import PAPERQA_DIR
from .readers import read_doc
from .types import (
@@ -43,47 +50,68 @@
class Docs(BaseModel):
"""A collection of documents to be used for answering questions."""

_client: AsyncOpenAI | None
# ephemeral clients that should not be pickled
_client: Any | None
_embedding_client: Any | None
llm: str = "default"
summary_llm: str | None = None
llm_model: LLMModel = Field(default_factory=OpenAILLMModel)
summary_llm_model: LLMModel | None = Field(default=None, validate_default=True)
embedding: EmbeddingModel = OpenAIEmbeddingModel()
docs: dict[DocKey, Doc] = {}
texts: list[Text] = []
docnames: set[str] = set()
texts_index: VectorStore = NumpyVectorStore()
doc_index: VectorStore = NumpyVectorStore()
llm_config: dict = dict(model="gpt-3.5-turbo", model_type="chat", temperature=0.1)
summary_llm_config: dict | None = Field(default=None, validate_default=True)
name: str = "default"
index_path: Path | None = PAPERQA_DIR / name
embeddings_model: str = "text-embedding-ada-002"
batch_size: int = 1
max_concurrent: int = 5
deleted_dockeys: set[DocKey] = set()
prompts: PromptCollection = PromptCollection()
jit_texts_index: bool = False
# This is used to strip indirect citations that come up from the summary llm
strip_citations: bool = True
verbose: bool = False

def __init__(self, **data):
if "embedding_client" in data:
embedding_client = data.pop("embedding_client")
elif "client" in data:
embedding_client = data["client"]
else:
embedding_client = AsyncOpenAI()
if "client" in data:
client = data.pop("client")
else:
client = AsyncOpenAI()
super().__init__(**data)
self._client = client
self._embedding_client = embedding_client

@field_validator("llm_config", "summary_llm_config")
@model_validator(mode="before")
@classmethod
def llm_guess_model_type(cls, v: dict) -> dict:
if v is not None and "model_type" not in v:
v["model_type"] = guess_model_type(v["model"])
return v
def setup_alias_models(cls, data: Any) -> Any:
if isinstance(data, dict):
if "llm" in data and data["llm"] != "default":
if is_openai_model(data["llm"]):
data["llm_model"] = OpenAILLMModel(config=dict(model=data["llm"]))
else:
raise ValueError(f"Could not guess model type for {data['llm']}. ")
if "summary_llm" in data and data["summary_llm"] is not None:
if is_openai_model(data["summary_llm"]):
data["summary_llm_model"] = OpenAILLMModel(
config=dict(model=data["summary_llm"])
)
else:
raise ValueError(f"Could not guess model type for {data['llm']}. ")
return data

@model_validator(mode="after")
@classmethod
def config_summary_llm_conig(cls, data: Any) -> Any:
def config_summary_llm_config(cls, data: Any) -> Any:
if isinstance(data, Docs):
if data.summary_llm_config is None:
data.summary_llm_config = data.llm_config
if data.summary_llm_model is None:
data.summary_llm_model = data.llm_model
return data

def clear_docs(self):
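
The hunk above replaces the old llm_config / summary_llm_config dicts with swappable model objects (llm_model, summary_llm_model, embedding) plus string aliases (llm, summary_llm) that setup_alias_models resolves to OpenAILLMModel instances. A minimal construction sketch under that API; the temperature config key and an available OPENAI_API_KEY are assumptions, not shown in this diff:

from openai import AsyncOpenAI

from paperqa import Docs
from paperqa.llms import OpenAILLMModel

# String alias: the before-mode validator recognizes an OpenAI model name
# and builds OpenAILLMModel(config=dict(model="gpt-4")) into llm_model.
docs = Docs(llm="gpt-4")

# Explicit model object: swap in a pre-configured LLMModel directly.
# (The temperature key mirrors the old llm_config default; it is an assumption here.)
docs = Docs(
    llm_model=OpenAILLMModel(config=dict(model="gpt-3.5-turbo", temperature=0.1)),
    client=AsyncOpenAI(),            # completion client
    embedding_client=AsyncOpenAI(),  # embedding client; defaults to client when omitted
)
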
@@ -95,16 +123,25 @@ def __getstate__(self):
state = super().__getstate__()
# remove client from private attributes
del state["__pydantic_private__"]["_client"]
del state["__pydantic_private__"]["_embedding_client"]
return state

def __setstate__(self, state):
super().__setstate__(state)
self._client = None
self._embedding_client = None

def set_client(self, client: AsyncOpenAI | None = None):
def set_client(
self,
client: AsyncOpenAI | None = None,
embedding_client: AsyncOpenAI | None = None,
):
if client is None:
client = AsyncOpenAI()
self._client = client
if embedding_client is None:
embedding_client = client
self._embedding_client = embedding_client

def _get_unique_name(self, docname: str) -> str:
"""Create a unique name given proposed name"""
@@ -181,10 +218,9 @@ def add(
dockey = md5sum(path)
if citation is None:
# skip system because it's too hesitant to answer
cite_chain = make_chain(
cite_chain = self.llm_model.make_chain(
client=self._client,
prompt=self.prompts.cite,
llm_config=cast(dict, self.summary_llm_config),
skip_system=True,
)
# peek at first chunk
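
Chains are now built from the model object instead of the old module-level make_chain helper, so swapping llm_model redirects every prompt in the class. A standalone sketch inferred from the call sites in this diff; the literal prompt string, its {text} template variable, and the empty callback list are assumptions:

import asyncio

from openai import AsyncOpenAI
from paperqa.llms import OpenAILLMModel

model = OpenAILLMModel(config=dict(model="gpt-3.5-turbo"))
cite_chain = model.make_chain(
    client=AsyncOpenAI(),
    prompt="Provide a citation for the following text:\n\n{text}",  # hypothetical template
    skip_system=True,
)
# second argument is the callback list seen at the aquery call sites below
citation = asyncio.run(cite_chain(dict(text="...first chunk of the paper..."), []))
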
@@ -251,15 +287,15 @@ def add_texts(
doc.docname = new_docname
if texts[0].embedding is None:
text_embeddings = asyncio.run(
embed_documents(
self._client, [t.text for t in texts], self.embeddings_model
self.embedding.embed_documents(
self._embedding_client, [t.text for t in texts]
)
)
for i, t in enumerate(texts):
t.embedding = text_embeddings[i]
if doc.embedding is None:
doc.embedding = asyncio.run(
embed_documents(self._client, [doc.citation], self.embeddings_model)
self.embedding.embed_documents(self._embedding_client, [doc.citation])
)[0]
if not self.jit_texts_index:
self.texts_index.add_texts_and_embeddings(texts)
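
Embeddings likewise flow through the embedding model object and the dedicated _embedding_client, which makes the backend replaceable. A hypothetical subclass sketch, assuming EmbeddingModel only requires the async embed_documents(client, texts) hook used above and has no required fields of its own:

from openai import AsyncOpenAI

from paperqa import Docs
from paperqa.llms import EmbeddingModel

class ZeroEmbedding(EmbeddingModel):
    """Stand-in backend for offline runs or tests, purely for illustration."""

    async def embed_documents(self, client, texts: list[str]) -> list[list[float]]:
        # fixed-size zero vectors; a real backend would call its embedding API here
        return [[0.0] * 1536 for _ in texts]

docs = Docs(embedding=ZeroEmbedding(), client=AsyncOpenAI())
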
@@ -289,7 +325,7 @@ async def adoc_match(
) -> set[DocKey]:
"""Return a list of dockeys that match the query."""
query_vector = (
await embed_documents(self._client, [query], self.embeddings_model)
await self.embedding.embed_documents(self._embedding_client, [query])
)[0]
matches, _ = self.doc_index.max_marginal_relevance_search(
query_vector,
@@ -304,13 +340,17 @@
try:
if (
rerank is None
and self.llm_config["model"].startswith("gpt-4")
and (
isinstance(self.llm_model, OpenAILLMModel)
and cast(OpenAILLMModel, self.llm_model)
.config["model"]
.startswith("gpt-4")
)
or rerank is True
):
chain = make_chain(
chain = self.llm_model.make_chain(
client=self._client,
prompt=self.prompts.select,
llm_config=self.llm_config,
skip_system=True,
)
papers = [f"{d.docname}: {d.citation}" for d in matched_docs]
@@ -387,9 +427,7 @@ async def aget_evidence(
matches = self.texts
else:
query_vector = (
await embed_documents(
self._client, [answer.question], self.embeddings_model
)
await self.embedding.embed_documents(self._embedding_client, [answer.question])
)[0]
if marginal_relevance:
matches, _ = self.texts_index.max_marginal_relevance_search(
Expand Down Expand Up @@ -421,10 +459,9 @@ async def process(match):
context = match.text
score = 5
else:
summary_chain = make_chain(
summary_chain = self.summary_llm_model.make_chain(
client=self._client,
prompt=self.prompts.summary,
llm_config=cast(dict, self.summary_llm_config),
system_prompt=self.prompts.system,
)
# This is dangerous because it
@@ -548,10 +585,9 @@ async def aquery(
get_callbacks=get_callbacks,
)
if self.prompts.pre is not None:
chain = make_chain(
chain = self.llm_model.make_chain(
client=self._client,
prompt=self.prompts.pre,
llm_config=self.llm_config,
system_prompt=self.prompts.system,
)
pre = await chain(dict(question=answer.question), get_callbacks("pre"))
@@ -562,13 +598,11 @@
"I cannot answer this question due to insufficient information."
)
else:
qa_chain = make_chain(
qa_chain = self.llm_model.make_chain(
client=self._client,
prompt=self.prompts.qa,
llm_config=self.llm_config,
system_prompt=self.prompts.system,
)
print(answer.context)
answer_text = await qa_chain(
dict(
context=answer.context,
@@ -577,7 +611,6 @@
),
get_callbacks("answer"),
)
print(answer_text)
# it still happens
if "(Example2012)" in answer_text:
answer_text = answer_text.replace("(Example2012)", "")
@@ -598,10 +631,9 @@ async def aquery(
answer.references = bib_str

if self.prompts.post is not None:
chain = make_chain(
chain = self.llm_model.make_chain(
client=self._client,
prompt=self.prompts.post,
llm_config=self.llm_config,
system_prompt=self.prompts.system,
)
post = await chain(answer.model_dump(), get_callbacks("post"))
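
With the pre, qa, and post chains all built from self.llm_model, a different backend can be dropped in without touching the prompts. A rough end-to-end sketch; the add signature, the aquery argument, and the formatted_answer field are assumed from paper-qa's public API rather than shown in this diff:

import asyncio

from paperqa import Docs
from paperqa.llms import OpenAILLMModel

docs = Docs(llm_model=OpenAILLMModel(config=dict(model="gpt-4")))  # default AsyncOpenAI clients
docs.add("example.pdf")  # synchronous, path-based add as in prior releases (assumed)
answer = asyncio.run(docs.aquery("What does the paper conclude?"))
print(answer.formatted_answer)  # field assumed from the Answer type, not shown in this diff
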