diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fd8fc9585..56deb5131 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,15 +2,23 @@ default_language_version: python: python3 repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - - id: trailing-whitespace + - id: check-added-large-files + - id: check-byte-order-marker + - id: check-case-conflict + - id: check-merge-conflict + - id: check-shebang-scripts-are-executable + - id: check-symlinks + - id: check-toml - id: check-yaml + - id: debug-statements + - id: detect-private-key - id: end-of-file-fixer - id: mixed-line-ending - - id: check-added-large-files + - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.0.270" + rev: v0.3.1 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -24,14 +32,18 @@ repos: - id: mypy args: [--pretty, --ignore-missing-imports] additional_dependencies: [types-requests, types-setuptools] - - repo: https://github.com/PyCQA/isort - rev: "5.12.0" - hooks: - - id: isort - args: [--profile=black, "--skip=__init__.py", "--filter-files"] - repo: https://github.com/pre-commit/mirrors-prettier rev: v3.1.0 hooks: - id: prettier additional_dependencies: - prettier@3.2.5 # SEE: https://github.com/pre-commit/pre-commit/issues/3133 + - repo: https://github.com/pappasam/toml-sort + rev: v0.23.1 + hooks: + - id: toml-sort-fix + - repo: https://github.com/codespell-project/codespell + rev: v2.2.6 + hooks: + - id: codespell + additional_dependencies: [".[toml]"] diff --git a/.ruff.toml b/.ruff.toml deleted file mode 100644 index 8f23e932b..000000000 --- a/.ruff.toml +++ /dev/null @@ -1,2 +0,0 @@ -# Allow lines to be as longer. -line-length = 180 diff --git a/README.md b/README.md index 5f1ab4e7a..56ba48446 100644 --- a/README.md +++ b/README.md @@ -378,7 +378,7 @@ It's not that different! This is similar to the tree response method in LlamaInd ### How is this different from LangChain? -There has been some great work on retrievers in langchain and you could say this is an example of a retreiver. +There has been some great work on retrievers in langchain and you could say this is an example of a retriever. ### Can I save or load? 
diff --git a/paperqa/__init__.py b/paperqa/__init__.py index 875837f39..4c55d5e0c 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -1,21 +1,21 @@ -from .docs import Answer, Docs, PromptCollection, Doc, Text, Context, print_callback -from .version import __version__ +from .docs import Answer, Context, Doc, Docs, PromptCollection, Text, print_callback from .llms import ( - LLMModel, + AnthropicLLMModel, EmbeddingModel, + HybridEmbeddingModel, LangchainEmbeddingModel, - OpenAIEmbeddingModel, LangchainLLMModel, - OpenAILLMModel, - AnthropicLLMModel, + LangchainVectorStore, LlamaEmbeddingModel, - HybridEmbeddingModel, - SparseEmbeddingModel, + LLMModel, + LLMResult, NumpyVectorStore, - LangchainVectorStore, + OpenAIEmbeddingModel, + OpenAILLMModel, SentenceTransformerEmbeddingModel, - LLMResult, + SparseEmbeddingModel, ) +from .version import __version__ __all__ = [ "Docs", diff --git a/paperqa/contrib/zotero.py b/paperqa/contrib/zotero.py index 7640e499f..7bc0f346e 100644 --- a/paperqa/contrib/zotero.py +++ b/paperqa/contrib/zotero.py @@ -9,7 +9,7 @@ try: from pyzotero import zotero except ImportError: - raise ImportError("Please install pyzotero: `pip install pyzotero`") + raise ImportError("Please install pyzotero: `pip install pyzotero`") # noqa: B904 from ..paths import PAPERQA_DIR from ..utils import StrPath, count_pdf_pages @@ -17,7 +17,7 @@ class ZoteroPaper(BaseModel): """A paper from Zotero. - Attributes + Attributes: ---------- key : str The citation key. @@ -65,9 +65,9 @@ def __init__( self, *, library_type: str = "user", - library_id: Optional[str] = None, - api_key: Optional[str] = None, - storage: Optional[StrPath] = None, + library_id: Optional[str] = None, # noqa: FA100 + api_key: Optional[str] = None, # noqa: FA100 + storage: Optional[StrPath] = None, # noqa: FA100 **kwargs, ): self.logger = logging.getLogger("ZoteroDB") @@ -81,7 +81,7 @@ def __init__( " from the text 'Your userID for use in API calls is [XXXXXX]'." " Then, set the environment variable ZOTERO_USER_ID to this value." ) - else: + else: # noqa: RET506 library_id = os.environ["ZOTERO_USER_ID"] if api_key is None: @@ -93,7 +93,7 @@ def __init__( " with access to your library." " Then, set the environment variable ZOTERO_API_KEY to this value." ) - else: + else: # noqa: RET506 api_key = os.environ["ZOTERO_API_KEY"] self.logger.info(f"Using library ID: {library_id} with type: {library_type}.") @@ -108,7 +108,7 @@ def __init__( library_type=library_type, library_id=library_id, api_key=api_key, **kwargs ) - def get_pdf(self, item: dict) -> Union[Path, None]: + def get_pdf(self, item: dict) -> Union[Path, None]: # noqa: FA100 """Gets a filename for a given Zotero key for a PDF. If the PDF is not found locally, the PDF will be downloaded to a local file at the correct key. @@ -120,7 +120,7 @@ def get_pdf(self, item: dict) -> Union[Path, None]: An item from `pyzotero`. Should have a `key` field, and also have an entry `links->attachment->attachmentType == application/pdf`. """ - if type(item) != dict: + if type(item) != dict: # noqa: E721 raise TypeError("Pass the full item of the paper. 
The item must be a dict.") pdf_key = _extract_pdf_key(item) @@ -137,17 +137,17 @@ def get_pdf(self, item: dict) -> Union[Path, None]: return pdf_path - def iterate( + def iterate( # noqa: C901, PLR0912 self, limit: int = 25, start: int = 0, - q: Optional[str] = None, - qmode: Optional[str] = None, - since: Optional[str] = None, - tag: Optional[str] = None, - sort: Optional[str] = None, - direction: Optional[str] = None, - collection_name: Optional[str] = None, + q: Optional[str] = None, # noqa: FA100 + qmode: Optional[str] = None, # noqa: FA100 + since: Optional[str] = None, # noqa: FA100 + tag: Optional[str] = None, # noqa: FA100 + sort: Optional[str] = None, # noqa: FA100 + direction: Optional[str] = None, # noqa: FA100 + collection_name: Optional[str] = None, # noqa: FA100 ): """Given a search query, this will lazily iterate over papers in a Zotero library, downloading PDFs as needed. @@ -210,8 +210,8 @@ def iterate( max_limit = 100 - items: List = [] - pdfs: List[Path] = [] + items: List = [] # noqa: FA100 + pdfs: List[Path] = [] # noqa: FA100 i = 0 actual_i = 0 num_remaining = limit @@ -247,7 +247,7 @@ def iterate( if no_pdf or is_duplicate: continue pdf = cast(Path, pdf) - title = item["data"]["title"] if "title" in item["data"] else "" + title = item["data"].get("title", "") if len(items) >= start: yield ZoteroPaper( key=_get_citation_key(item), @@ -277,12 +277,12 @@ def _get_collection_id(self, collection_name: str) -> str: """Get the collection id for a given collection name Raises ValueError if collection not found Args: - collection_name (str): The name of the collection + collection_name (str): The name of the collection. Returns: str: collection id - """ - # specfic collection + """ # noqa: D205 + # specific collection collections = self.collections() collection_id = "" @@ -326,9 +326,8 @@ def _get_citation_key(item: dict) -> str: return f"{last_name}_{short_title}_{date}_{item['key']}".replace(" ", "") -def _extract_pdf_key(item: dict) -> Union[str, None]: +def _extract_pdf_key(item: dict) -> Union[str, None]: # noqa: FA100 """Extract the PDF key from a Zotero item.""" - if "links" not in item: return None @@ -337,7 +336,7 @@ def _extract_pdf_key(item: dict) -> Union[str, None]: attachments = item["links"]["attachment"] - if type(attachments) != dict: + if type(attachments) != dict: # noqa: E721 # Find first attachment with attachmentType == application/pdf: for attachment in attachments: # TODO: This assumes there's only one PDF attachment. 
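For orientation, the reworked `iterate` generator above is typically driven like the following minimal sketch. It assumes ZOTERO_USER_ID and ZOTERO_API_KEY are set in the environment (as the error messages above describe); the query string and limit are purely illustrative, and the `key` and `pdf` fields follow the `ZoteroPaper` model shown earlier.

from paperqa.contrib.zotero import ZoteroDB

zotero = ZoteroDB(library_type="user")  # credentials are read from the environment
for paper in zotero.iterate(q="machine learning", limit=5):
    # Each ZoteroPaper carries its citation key and the locally cached PDF path.
    print(paper.key, paper.pdf)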
diff --git a/paperqa/docs.py b/paperqa/docs.py index 00dea7b56..9a5433dbb 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import json import os -import pprint +import pprint # noqa: F401 import re import tempfile from datetime import datetime @@ -50,12 +52,12 @@ # this is just to reduce None checks/type checks -async def empty_callback(result: LLMResult): +async def empty_callback(result: LLMResult): # noqa: ARG001 pass -async def print_callback(result: LLMResult): - pprint.pprint(result.model_dump()) +async def print_callback(result: LLMResult): # noqa: ARG001 + pass class Docs(BaseModel): @@ -68,7 +70,9 @@ class Docs(BaseModel): llm: str = "default" summary_llm: str | None = None llm_model: LLMModel = Field( - default=OpenAILLMModel(config=dict(model="gpt-4-0125-preview", temperature=0.1)) + default=OpenAILLMModel( + config={"model": "gpt-4-0125-preview", "temperature": 0.1} + ) ) summary_llm_model: LLMModel | None = Field(default=None, validate_default=True) embedding: str | None = "default" @@ -103,23 +107,19 @@ def __init__(self, **data): ): # convenience embedding_client = data["client"] + elif "embedding" in data and data["embedding"] != "default": + embedding_client = None else: - if "embedding" in data and data["embedding"] != "default": - embedding_client = None - else: - embedding_client = AsyncOpenAI() + embedding_client = AsyncOpenAI() if "client" in data: client = data.pop("client") + elif "llm_model" in data and data["llm_model"] is not None: + # except if it is an OpenAILLMModel + client = ( + AsyncOpenAI() if type(data["llm_model"]) == OpenAILLMModel else None + ) else: - # if llm_model is explicitly set, but not client then make it None - if "llm_model" in data and data["llm_model"] is not None: - # except if it is an OpenAILLMModel - if type(data["llm_model"]) == OpenAILLMModel: - client = AsyncOpenAI() - else: - client = None - else: - client = AsyncOpenAI() + client = AsyncOpenAI() # backwards compatibility if "doc_index" in data: data["docs_index"] = data.pop("doc_index") @@ -140,11 +140,11 @@ def __init__(self, **data): @model_validator(mode="before") @classmethod - def setup_alias_models(cls, data: Any) -> Any: + def setup_alias_models(cls, data: Any) -> Any: # noqa: C901, PLR0912 if isinstance(data, dict): if "llm" in data and data["llm"] != "default": if is_openai_model(data["llm"]): - data["llm_model"] = OpenAILLMModel(config=dict(model=data["llm"])) + data["llm_model"] = OpenAILLMModel(config={"model": data["llm"]}) elif data["llm"] == "langchain": data["llm_model"] = LangchainLLMModel() else: @@ -152,7 +152,7 @@ def setup_alias_models(cls, data: Any) -> Any: if "summary_llm" in data and data["summary_llm"] is not None: if is_openai_model(data["summary_llm"]): data["summary_llm_model"] = OpenAILLMModel( - config=dict(model=data["summary_llm"]) + config={"model": data["summary_llm"]} ) else: raise ValueError(f"Could not guess model type for {data['llm']}. 
") @@ -199,7 +199,7 @@ def config_summary_llm_config(cls, data: Any) -> Any: and type(data.llm_model) == OpenAILLMModel ): data.summary_llm_model = OpenAILLMModel( - config=dict(model="gpt-3.5-turbo", temperature=0.1) + config={"model": "gpt-3.5-turbo", "temperature": 0.1} ) elif data.summary_llm_model is None: data.summary_llm_model = data.llm_model @@ -262,22 +262,16 @@ def set_client( client = AsyncOpenAI() self._client = client if embedding_client is None: - if type(client) == AsyncOpenAI: - embedding_client = client - else: - embedding_client = AsyncOpenAI() + embedding_client = client if type(client) == AsyncOpenAI else AsyncOpenAI() self._embedding_client = embedding_client Docs.make_llm_names_consistent(self) def _get_unique_name(self, docname: str) -> str: - """Create a unique name given proposed name""" + """Create a unique name given proposed name.""" suffix = "" while docname + suffix in self.docnames: # move suffix to next letter - if suffix == "": - suffix = "a" - else: - suffix = chr(ord(suffix) + 1) + suffix = "a" if suffix == "" else chr(ord(suffix) + 1) docname += suffix return docname @@ -357,7 +351,7 @@ async def aadd_url( """Add a document to the collection.""" import urllib.request - with urllib.request.urlopen(url) as f: + with urllib.request.urlopen(url) as f: # noqa: ASYNC100, S310 # need to wrap to enable seek file = BytesIO(f.read()) return await self.aadd_file( @@ -413,9 +407,13 @@ async def aadd( texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100) if len(texts) == 0: raise ValueError(f"Could not read document {path}. Is it empty?") - chain_result = await cite_chain(dict(text=texts[0].text), None) + chain_result = await cite_chain({"text": texts[0].text}, None) citation = chain_result.text - if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation: + if ( + len(citation) < 3 # noqa: PLR2004 + or "Unknown" in citation + or "insufficient" in citation + ): citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}" if docname is None: @@ -441,7 +439,7 @@ async def aadd( # loose check to see if document was loaded if ( len(texts) == 0 - or len(texts[0].text) < 10 + or len(texts[0].text) < 10 # noqa: PLR2004 or (not disable_check and not maybe_is_text(texts[0].text)) ): raise ValueError( @@ -522,7 +520,7 @@ async def adoc_match( query: str, k: int = 25, rerank: bool | None = None, - get_callbacks: CallbackFactory = lambda x: None, + get_callbacks: CallbackFactory = lambda x: None, # noqa: ARG005 answer: Answer | None = None, # used for tracking tokens ) -> set[DocKey]: """Return a list of dockeys that match the query.""" @@ -556,7 +554,7 @@ async def adoc_match( ) papers = [f"{d.docname}: {d.citation}" for d in matched_docs] result = await chain( - dict(question=query, papers="\n".join(papers)), + {"question": query, "papers": "\n".join(papers)}, get_callbacks("filter"), ) if answer: @@ -565,10 +563,10 @@ async def adoc_match( await self.llm_result_callback(result) if answer: answer.add_tokens(result) - return set([d.dockey for d in matched_docs if d.docname in str(result)]) + return {d.dockey for d in matched_docs if d.docname in str(result)} except AttributeError: pass - return set([d.dockey for d in matched_docs]) + return {d.dockey for d in matched_docs} def _build_texts_index(self, keys: set[DocKey] | None = None): texts = self.texts @@ -594,7 +592,7 @@ def get_evidence( answer: Answer, k: int = 10, max_sources: int = 5, - get_callbacks: CallbackFactory = lambda x: None, + get_callbacks: CallbackFactory = lambda 
x: None, # noqa: ARG005 detailed_citations: bool = False, disable_vector_search: bool = False, ) -> Answer: @@ -609,12 +607,12 @@ def get_evidence( ) ) - async def aget_evidence( + async def aget_evidence( # noqa: C901, PLR0915 self, answer: Answer, k: int = 10, # Number of evidence pieces to retrieve max_sources: int = 5, # Number of scored contexts to use - get_callbacks: CallbackFactory = lambda x: None, + get_callbacks: CallbackFactory = lambda x: None, # noqa: ARG005 detailed_citations: bool = False, disable_vector_search: bool = False, ) -> Answer: @@ -650,7 +648,7 @@ async def aget_evidence( # now finally cut down matches = matches[:k] - async def process(match): + async def process(match): # noqa: C901, PLR0912 callbacks = get_callbacks("evidence:" + match.name) citation = match.doc.citation # needed empties for failures/skips @@ -683,12 +681,12 @@ async def process(match): # http code in the exception try: llm_result = await summary_chain( - dict( - question=answer.question, - citation=citation, - summary_length=answer.summary_length, - text=match.text, - ), + { + "question": answer.question, + "citation": citation, + "summary_length": answer.summary_length, + "text": match.text, + }, callbacks, ) llm_result.answer_id = answer.id @@ -698,7 +696,7 @@ async def process(match): except Exception as e: if guess_is_4xx(str(e)): return None, llm_result - raise e + raise success = True if self.prompts.summary_json: try: @@ -779,7 +777,7 @@ def query( length_prompt="about 100 words", answer: Answer | None = None, key_filter: bool | None = None, - get_callbacks: CallbackFactory = lambda x: None, + get_callbacks: CallbackFactory = lambda x: None, # noqa: ARG005 ) -> Answer: return get_loop().run_until_complete( self.aquery( @@ -793,7 +791,7 @@ def query( ) ) - async def aquery( + async def aquery( # noqa: C901, PLR0912, PLR0915 self, query: str, k: int = 10, @@ -801,7 +799,7 @@ async def aquery( length_prompt: str = "about 100 words", answer: Answer | None = None, key_filter: bool | None = None, - get_callbacks: CallbackFactory = lambda x: None, + get_callbacks: CallbackFactory = lambda x: None, # noqa: ARG005 ) -> Answer: if k < max_sources: raise ValueError("k should be greater than max_sources") @@ -830,7 +828,7 @@ async def aquery( prompt=self.prompts.pre, system_prompt=self.prompts.system, ) - pre = await chain(dict(question=answer.question), get_callbacks("pre")) + pre = await chain({"question": answer.question}, get_callbacks("pre")) pre.name = "pre" pre.answer_id = answer.id await self.llm_result_callback(pre) @@ -838,8 +836,8 @@ async def aquery( answer.context = ( answer.context + "\n\nExtra background information:" + str(pre) ) - bib = dict() - if len(answer.context) < 10: # and not self.memory: + bib = {} + if len(answer.context) < 10: # and not self.memory: # noqa: PLR2004 answer_text = ( "I cannot answer this question due to insufficient information." 
) @@ -850,11 +848,11 @@ async def aquery( system_prompt=self.prompts.system, ) answer_result = await qa_chain( - dict( - context=answer.context, - answer_length=answer.answer_length, - question=answer.question, - ), + { + "context": answer.context, + "answer_length": answer.answer_length, + "question": answer.question, + }, get_callbacks("answer"), ) answer_result.name = "answer" diff --git a/paperqa/llms.py b/paperqa/llms.py index a66ded30d..a7ed67da0 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import asyncio import datetime import re from abc import ABC, abstractmethod from inspect import signature -from typing import Any, AsyncGenerator, Callable, Coroutine, Sequence, Type, cast +from typing import Any, AsyncGenerator, Callable, Coroutine, Sequence, cast import numpy as np import tiktoken @@ -46,7 +48,7 @@ # return model_name in model_arr or model_name in complete_model_arr -def guess_model_type(model_name: str) -> str: +def guess_model_type(model_name: str) -> str: # noqa: PLR0911 if model_name.startswith("babbage"): return "completion" if model_name.startswith("davinci"): @@ -63,20 +65,18 @@ def guess_model_type(model_name: str) -> str: def is_openai_model(model_name) -> bool: - return ( - model_name.startswith("gpt-") - or model_name.startswith("babbage") - or model_name.startswith("davinci") - ) + return model_name.startswith(("gpt-", "babbage", "davinci")) -def process_llm_config(llm_config: dict, max_token_name: str = "max_tokens") -> dict: - """Remove model_type and try to set max_tokens""" +def process_llm_config( + llm_config: dict, max_token_name: str = "max_tokens" # noqa: S107 +) -> dict: + """Remove model_type and try to set max_tokens.""" result = {k: v for k, v in llm_config.items() if k != "model_type"} if max_token_name not in result or result[max_token_name] == -1: model = llm_config["model"] # now we guess - we could use tiktoken to count, - # but do have the initative right now + # but do have the initiative right now if model.startswith("gpt-4") or ( model.startswith("gpt-3.5") and "1106" in model ): @@ -89,7 +89,7 @@ def process_llm_config(llm_config: dict, max_token_name: str = "max_tokens") -> async def embed_documents( client: AsyncOpenAI, texts: list[str], embedding_model: str ) -> list[list[float]]: - """Embed a list of documents with batching""" + """Embed a list of documents with batching.""" if client is None: raise ValueError( "Your client is None - did you forget to set it after pickling?" 
@@ -116,20 +116,19 @@ async def embed_documents(self, client: Any, texts: list[str]) -> list[list[floa class SparseEmbeddingModel(EmbeddingModel): - """This is a very simple keyword search model - probably best to be mixed with others""" + """This is a very simple keyword search model - probably best to be mixed with others.""" name: str = "sparse-embed" ndim: int = 256 enc: Any = Field(default_factory=lambda: tiktoken.get_encoding("cl100k_base")) - async def embed_documents(self, client, texts) -> list[list[float]]: + async def embed_documents(self, client, texts) -> list[list[float]]: # noqa: ARG002 enc_batch = self.enc.encode_ordinary_batch(texts) # now get frequency of each token rel to length - packed = [ + return [ np.bincount([xi % self.ndim for xi in x], minlength=self.ndim) / len(x) for x in enc_batch ] - return packed class HybridEmbeddingModel(EmbeddingModel): @@ -154,7 +153,8 @@ async def acomplete(self, client: Any, prompt: str) -> str: async def acomplete_iter(self, client: Any, prompt: str) -> Any: """Return an async generator that yields chunks of the completion. - I cannot get mypy to understand the override, so marked as Any""" + I cannot get mypy to understand the override, so marked as Any + """ raise NotImplementedError async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: @@ -163,16 +163,17 @@ async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: """Return an async generator that yields chunks of the completion. - I cannot get mypy to understand the override, so marked as Any""" + I cannot get mypy to understand the override, so marked as Any + """ raise NotImplementedError - def infer_llm_type(self, client: Any) -> str: + def infer_llm_type(self, client: Any) -> str: # noqa: ARG002 return "completion" def count_tokens(self, text: str) -> int: return len(text) // 4 # gross approximation - def make_chain( + def make_chain( # noqa: C901, PLR0915 self, client: Any, prompt: str, @@ -181,7 +182,7 @@ def make_chain( ) -> Callable[ [dict, list[Callable[[str], None]] | None], Coroutine[Any, Any, LLMResult] ]: - """Create a function to execute a batch of prompts + """Create a function to execute a batch of prompts. This replaces the previous use of langchain for combining prompts and LLMs. 
@@ -201,12 +202,13 @@ def make_chain( if self.llm_type is None: self.llm_type = self.infer_llm_type(client) if self.llm_type == "chat": - system_message_prompt = dict(role="system", content=system_prompt) - human_message_prompt = dict(role="user", content=prompt) - if skip_system: - chat_prompt = [human_message_prompt] - else: - chat_prompt = [system_message_prompt, human_message_prompt] + system_message_prompt = {"role": "system", "content": system_prompt} + human_message_prompt = {"role": "user", "content": prompt} + chat_prompt = ( + [human_message_prompt] + if skip_system + else [system_message_prompt, human_message_prompt] + ) async def execute( data: dict, @@ -219,8 +221,8 @@ async def execute( ) messages = [] for m in chat_prompt: - messages.append( - dict(role=m["role"], content=m["content"].format(**data)) + messages.append( # noqa: PERF401 + {"role": m["role"], "content": m["content"].format(**data)} ) result.prompt = messages result.prompt_count = sum( @@ -254,11 +256,10 @@ async def execute( return result return execute - elif self.llm_type == "completion": - if skip_system: - completion_prompt = prompt - else: - completion_prompt = system_prompt + "\n\n" + prompt + elif self.llm_type == "completion": # noqa: RET505 + completion_prompt = ( + prompt if skip_system else system_prompt + "\n\n" + prompt + ) async def execute( data: dict, callbacks: list[Callable] | None = None @@ -306,7 +307,7 @@ async def execute( class OpenAILLMModel(LLMModel): - config: dict = Field(default=dict(model="gpt-3.5-turbo", temperature=0.1)) + config: dict = Field(default={"model": "gpt-3.5-turbo", "temperature": 0.1}) name: str = "gpt-3.5-turbo" def _check_client(self, client: Any) -> AsyncOpenAI: @@ -315,7 +316,7 @@ def _check_client(self, client: Any) -> AsyncOpenAI: "Your client is None - did you forget to set it after pickling?" ) if not isinstance(client, AsyncOpenAI): - raise ValueError( + raise ValueError( # noqa: TRY004 f"Your client is not a required AsyncOpenAI client. It is a {type(client)}" ) return client @@ -375,7 +376,7 @@ async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: class AnthropicLLMModel(LLMModel): config: dict = Field( - default=dict(model="claude-3-sonnet-20240229", temperature=0.1) + default={"model": "claude-3-sonnet-20240229", "temperature": 0.1} ) name: str = "claude-3-sonnet-20240229" @@ -390,7 +391,7 @@ def _check_client(self, client: Any) -> AsyncAnthropic: "Your client is None - did you forget to set it after pickling?" ) if not isinstance(client, AsyncAnthropic): - raise ValueError( + raise ValueError( # noqa: TRY004 f"Your client is not a required AsyncAnthropic client. 
It is a {type(client)}" ) return client @@ -464,18 +465,21 @@ async def embed_documents(self, client: Any, texts: list[str]) -> list[list[floa cast(AsyncOpenAI, client) async def process(texts: list[str]) -> list[float]: - for i in range(3): + for i in range(3): # noqa: B007 # access httpx client directly to avoid type casting response = await client._client.post( client.base_url.join("../embedding"), json={"content": texts} ) body = response.json() if len(texts) == 1: - if type(body) != dict or body.get("embedding") is None: + if ( + type(body) != dict # noqa: E721 + or body.get("embedding") is None + ): continue return [body["embedding"]] - else: - if type(body) != list or body[0] != "results": + else: # noqa: RET505 + if type(body) != list or body[0] != "results": # noqa: E721 continue return [e["embedding"] for e in body[1]] raise ValueError("Failed to embed documents - response was ", body) @@ -496,16 +500,19 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) try: from sentence_transformers import SentenceTransformer - except ImportError: - raise ImportError("Please install sentence-transformers to use this model") + except ImportError as exc: + raise ImportError( + "Please install sentence-transformers to use this model" + ) from exc self._model = SentenceTransformer(self.name) - async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: + async def embed_documents( + self, client: Any, texts: list[str] # noqa: ARG002 + ) -> list[list[float]]: from sentence_transformers import SentenceTransformer - embeddings = cast(SentenceTransformer, self._model).encode(texts) - return embeddings + return cast(SentenceTransformer, self._model).encode(texts) def cosine_similarity(a, b): @@ -514,7 +521,7 @@ def cosine_similarity(a, b): class VectorStore(BaseModel, ABC): - """Interface for vector store - very similar to LangChain's VectorStore to be compatible""" + """Interface for vector store - very similar to LangChain's VectorStore to be compatible.""" embedding_model: EmbeddingModel = Field(default=OpenAIEmbeddingModel()) # can be tuned for different tasks @@ -535,7 +542,7 @@ async def similarity_search( def clear(self) -> None: pass - async def max_marginal_relevance_search( + async def max_marginal_relevance_search( # noqa: D417 self, client: Any, query: str, k: int, fetch_k: int ) -> tuple[Sequence[Embeddable], list[float]]: """Vectorized implementation of Maximal Marginal Relevance (MMR) search. @@ -625,7 +632,7 @@ async def similarity_search( class LangchainLLMModel(LLMModel): - """A wrapper around the wrapper langchain""" + """A wrapper around the wrapper langchain.""" name: str = "langchain" @@ -673,7 +680,7 @@ async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: class LangchainEmbeddingModel(EmbeddingModel): - """A wrapper around the wrapper langchain""" + """A wrapper around the wrapper langchain.""" name: str = "langchain" @@ -682,7 +689,7 @@ async def embed_documents(self, client: Any, texts: list[str]) -> list[list[floa class LangchainVectorStore(VectorStore): - """A wrapper around the wrapper langchain + """A wrapper around the wrapper langchain. Note that if you this is cleared (e.g., by `Docs` having `jit_texts_index` set to True), this will calls the `from_texts` class method on the `store`. 
This means that any non-default @@ -692,7 +699,7 @@ class LangchainVectorStore(VectorStore): _store_builder: Any | None = None _store: Any | None = None # JIT Generics - store the class type (Doc or Text) - class_type: Type[Embeddable] = Field(default=Embeddable) + class_type: type[Embeddable] = Field(default=Embeddable) model_config = ConfigDict(extra="forbid") def __init__(self, **data): @@ -719,12 +726,12 @@ def candidate(x, y): def check_store_builder(cls, builder: Any) -> Any: # check it is a callable if not callable(builder): - raise ValueError("store_builder must be callable") + raise ValueError("store_builder must be callable") # noqa: TRY004 # check it takes two arguments # we don't use type hints because it could be # a partial sig = signature(builder) - if len(sig.parameters) != 2: + if len(sig.parameters) != 2: # noqa: PLR2004 raise ValueError("store_builder must take two arguments") return builder @@ -746,13 +753,13 @@ def add_texts_and_embeddings(self, texts: Sequence[Embeddable]) -> None: raise ValueError("You must set store_builder before adding texts") self.class_type = type(texts[0]) if self.class_type == Text: - vec_store_text_and_embeddings = list( - map(lambda x: (x.text, x.embedding), cast(list[Text], texts)) - ) + vec_store_text_and_embeddings = [ + (x.text, x.embedding) for x in cast(list[Text], texts) + ] elif self.class_type == Doc: - vec_store_text_and_embeddings = list( - map(lambda x: (x.citation, x.embedding), cast(list[Doc], texts)) - ) + vec_store_text_and_embeddings = [ + (x.citation, x.embedding) for x in cast(list[Doc], texts) + ] else: raise ValueError("Only embeddings of type Text are supported") if self._store is None: @@ -768,7 +775,7 @@ def add_texts_and_embeddings(self, texts: Sequence[Embeddable]) -> None: ) async def similarity_search( - self, client: Any, query: str, k: int + self, client: Any, query: str, k: int # noqa: ARG002 ) -> tuple[Sequence[Embeddable], list[float]]: if self._store is None: return [], [] @@ -795,16 +802,16 @@ def get_score(text: str) -> int: score = re.search(r"([0-9]+)\w*\/", text) if score: s = int(score.group(1)) - if s > 10: + if s > 10: # noqa: PLR2004 s = int(s / 10) # sometimes becomes out of 100 return s last_few = text[-15:] scores = re.findall(r"([0-9]+)", last_few) if scores: s = int(scores[-1]) - if s > 10: + if s > 10: # noqa: PLR2004 s = int(s / 10) # sometimes becomes out of 100 return s - if len(text) < 100: + if len(text) < 100: # noqa: PLR2004 return 1 return 5 diff --git a/paperqa/prompts.py b/paperqa/prompts.py index 44d12b612..690e8974a 100644 --- a/paperqa/prompts.py +++ b/paperqa/prompts.py @@ -11,7 +11,7 @@ ) summary_json_prompt = ( - "Excerpt from {citation}\n\n----\n\n{text}\n\n----\n\n" "Question: {question}\n\n" + "Excerpt from {citation}\n\n----\n\n{text}\n\n----\n\nQuestion: {question}\n\n" ) qa_prompt = ( diff --git a/paperqa/readers.py b/paperqa/readers.py index f23630863..06b614487 100644 --- a/paperqa/readers.py +++ b/paperqa/readers.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from math import ceil from pathlib import Path -from typing import List +from typing import List # noqa: F401 import tiktoken from html2text import html2text @@ -8,13 +10,13 @@ from .types import Doc, Text -def parse_pdf_fitz(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List[Text]: +def parse_pdf_fitz(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> list[Text]: import fitz file = fitz.open(path) split = "" - pages: List[str] = [] - texts: List[Text] = [] + pages: list[str] = [] + 
texts: list[Text] = [] for i in range(file.page_count): page = file.load_page(i) split += page.get_text("text", sort=True) @@ -41,14 +43,14 @@ def parse_pdf_fitz(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List return texts -def parse_pdf(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List[Text]: +def parse_pdf(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> list[Text]: import pypdf - pdfFileObj = open(path, "rb") + pdfFileObj = open(path, "rb") # noqa: SIM115 pdfReader = pypdf.PdfReader(pdfFileObj) split = "" - pages: List[str] = [] - texts: List[Text] = [] + pages: list[str] = [] + texts: list[Text] = [] for i, page in enumerate(pdfReader.pages): split += page.extract_text() pages.append(str(i + 1)) @@ -76,11 +78,11 @@ def parse_pdf(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List[Text def parse_txt( path: Path, doc: Doc, chunk_chars: int, overlap: int, html: bool = False -) -> List[Text]: +) -> list[Text]: """Parse a document into chunks, based on tiktoken encoding. NOTE: We get some byte continuation errors. - Currnetly ignored, but should explore more to make sure we + Currently ignored, but should explore more to make sure we don't miss anything. """ try: @@ -120,11 +122,10 @@ def parse_txt( return texts -def parse_code_txt(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List[Text]: +def parse_code_txt(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> list[Text]: """Parse a document into chunks, based on line numbers (for code).""" - split = "" - texts: List[Text] = [] + texts: list[Text] = [] last_line = 0 with open(path) as f: @@ -157,7 +158,7 @@ def read_doc( chunk_chars: int = 3000, overlap: int = 100, force_pypdf: bool = False, -) -> List[Text]: +) -> list[Text]: """Parse a document into chunks.""" str_path = str(path) if str_path.endswith(".pdf"): diff --git a/paperqa/types.py b/paperqa/types.py index ed1384c8f..fd203589f 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, Callable from uuid import UUID, uuid4 @@ -75,7 +77,7 @@ def __missing__(self, key: str) -> str: def get_formatted_variables(s: str) -> set[str]: - """Returns the set of variables implied by the format string""" + """Returns the set of variables implied by the format string.""" format_dict = _FormatDict() s.format_map(format_dict) return format_dict.key_set @@ -133,9 +135,8 @@ def check_select(cls, v: str) -> str: @field_validator("pre") @classmethod def check_pre(cls, v: str | None) -> str | None: - if v is not None: - if set(get_formatted_variables(v)) != set(["question"]): - raise ValueError("Pre prompt must have input variables: question") + if v is not None and set(get_formatted_variables(v)) != {"question"}: + raise ValueError("Pre prompt must have input variables: question") return v @field_validator("post") @@ -160,7 +161,7 @@ class Context(BaseModel): ) -def __str__(self) -> str: +def __str__(self) -> str: # noqa: N807 """Return the context as a string.""" return self.context @@ -203,11 +204,11 @@ def used_contexts(self) -> set[str]: return get_citenames(self.formatted_answer) def get_citation(self, name: str) -> str: - """Return the formatted citation for the gien docname.""" + """Return the formatted citation for the given docname.""" try: doc = next(filter(lambda x: x.text.name == name, self.contexts)).text.doc except StopIteration: - raise ValueError(f"Could not find docname {name} in contexts") + raise ValueError(f"Could not find docname {name} in contexts") 
# noqa: B904 return doc.citation def add_tokens(self, result: LLMResult): diff --git a/paperqa/utils.py b/paperqa/utils.py index de6162bd8..86974bcfc 100644 --- a/paperqa/utils.py +++ b/paperqa/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import inspect import math @@ -13,7 +15,7 @@ def name_in_text(name: str, text: str) -> bool: sname = name.strip() - pattern = r"\b({0})\b(?!\w)".format(re.escape(sname)) + pattern = rf"\b({re.escape(sname)})\b(?!\w)" if re.search(pattern, text): return True return False @@ -44,12 +46,7 @@ def maybe_is_pdf(file: BinaryIO) -> bool: def maybe_is_html(file: BinaryIO) -> bool: magic_number = file.read(4) file.seek(0) - return ( - magic_number == b"<htm" - or magic_number == b"<!DO" - or magic_number == b"<xsl" - or magic_number == b"<!X" - ) + return magic_number in (b"<htm", b"<!DO", b"<xsl", b"<!X") def strings_similarity(s1: str, s2: str) -> float: @@ -79,7 +76,7 @@ def md5sum(file_path: StrPath) -> str: import hashlib with open(file_path, "rb") as f: - return hashlib.md5(f.read()).hexdigest() + return hashlib.md5(f.read()).hexdigest() # noqa: S324 async def gather_with_concurrency(n: int, coros: list[Coroutine]) -> list[Any]: @@ -103,8 +100,7 @@ def strip_citations(text: str) -> str: # Combined regex for identifying citations (see unit tests for examples) citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)" # Remove the citations from the text - text = re.sub(citation_regex, "", text, flags=re.MULTILINE) - return text + return re.sub(citation_regex, "", text, flags=re.MULTILINE) def get_citenames(text: str) -> set[str]: @@ -117,12 +113,12 @@ results.extend(none_results) values = [] for citation in results: - citation = citation.strip("() ") + citation = citation.strip("() ") # noqa: PLW2901 for c in re.split(",|;", citation): if c == "Extra background information": continue # remove leading/trailing spaces - c = c.strip() + c = c.strip() # noqa: PLW2901 values.append(c) return set(values) @@ -141,13 +137,13 @@ def extract_doi(reference: str) -> str: # If DOI is found in the reference, return the DOI link if doi_match: return "https://doi.org/" + doi_match.group() - else: + else: # noqa: RET505 return "" def batch_iter(iterable: list, n: int = 1) -> Iterator[list]: """ - Batch an iterable into chunks of size n + Batch an iterable into chunks of size n. :param iterable: The iterable to batch :param n: The size of the batches @@ -160,7 +156,7 @@ def flatten(iteratble: list) -> list: """ - Flatten a list of lists + Flatten a list of lists. :param l: The list of lists to flatten :return: A flattened list @@ -180,6 +176,6 @@ def get_loop() -> asyncio.AbstractEventLoop: def is_coroutine_callable(obj): if inspect.isfunction(obj): return inspect.iscoroutinefunction(obj) - elif callable(obj): + elif callable(obj): # noqa: RET505 return inspect.iscoroutinefunction(obj.__call__) return False diff --git a/pyproject.toml b/pyproject.toml index b0ff3d329..4e193fb10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,10 @@ +[tool.codespell] +check-filenames = true +check-hidden = true +# SEE: https://github.com/codespell-project/codespell/issues/1212#issuecomment-1744768533 +ignore-regex = ".{1024}|.*codespell-ignore.*" +ignore-words-list = "aadd,ser" + [tool.mypy] # Type-checks the interior of functions without type annotations. 
check_untyped_defs = true @@ -47,3 +54,89 @@ module = [ "pyzotero", # SEE: https://github.com/urschrei/pyzotero/issues/110 "sentence_transformers", # SEE: https://github.com/UKPLab/sentence-transformers/issues/1723 ] + +[tool.ruff] +# Line length to use when enforcing long-lines violations (like `E501`). +line-length = 120 +# Enable application of unsafe fixes. +unsafe-fixes = true + +[tool.ruff.lint] +# Rule codes to ignore project-wide. Each entry keeps a short rationale so +# individual rules are easy to revisit later, rather than silently disabling +# whole rule categories. +ignore = [ + "ANN", # Don't care to enforce typing + "BLE001", # Don't care to enforce blind exception catching + "COM812", # Trailing comma with black leads to wasting lines + "D100", # D100, D101, D102, D103, D104, D105, D106, D107: don't always need docstrings + "D101", + "D102", + "D103", + "D104", + "D105", + "D106", + "D107", + "D203", # Keep docstring next to the class definition (covered by D211) + "D212", # Summary should be on second line (opposite of D213) + "D402", # It's nice to reuse the method name + "D406", # Google style requires ":" at end + "D407", # We aren't using numpy style + "D413", # Blank line after last section. -> No blank line + "DTZ", # Don't care to have timezone safety + "EM", # Overly pedantic + "ERA001", # Don't care to prevent commented code + "FBT001", # FBT001, FBT002: overly pedantic + "FBT002", + "FIX", # Don't care to prevent TODO, FIXME, etc. + "FLY002", # Can be less readable + "G004", # f-strings are convenient + "INP001", # Can use namespace packages + "N803", # Want to use names like 'N' or 'L' + "N806", # Want to use names like 'N' or 'L' + "PLR0913", # Allow many arguments to functions + "PTH", # Overly pedantic + "S311", # Ok to use python random + "SLF001", # Overly pedantic + "T201", # Overly pedantic + "TCH001", # TCH001, TCH002, TCH003: don't care to enforce type checking blocks + "TCH002", + "TCH003", + "TD002", # Don't care for TODO author + "TD003", # Don't care for TODO links + "TID252", # Allow relative imports for packaging + "TRY003", # Overly pedantic +] +select = ["ALL"] +unfixable = [ + "B007", # While debugging, unused loop variables can be useful + "ERA001", # While debugging, temporarily commenting code can be useful + "F401", # While debugging, unused imports can be useful + "F841", # While debugging, unused locals can be useful +] + +[tool.ruff.lint.flake8-annotations] +mypy-init-return = true + +[tool.ruff.lint.per-file-ignores] +"tests/*.py" = [ + "PLR2004", # Tests can have magic values + "S101", # Tests can have assertions +] + +[tool.ruff.lint.pycodestyle] +# The maximum line length to allow for line-length violations within +# documentation (W505), including standalone comments. +max-doc-length = 120 # Match line-length + +[tool.ruff.lint.pydocstyle] +# Whether to use Google-style or NumPy-style conventions or the PEP257 +# defaults when analyzing docstring sections. 
+convention = "google" + +[tool.tomlsort] +all = true +in_place = true +spaces_before_inline_comment = 2 # Match Python PEP 8 +spaces_indent_inline_array = 4 # Match Python PEP 8 +trailing_comma_inline_array = true diff --git a/setup.py b/setup.py index e94615bb2..6bc87745c 100644 --- a/setup.py +++ b/setup.py @@ -2,9 +2,9 @@ # for typing __version__ = "" -exec(open("paperqa/version.py").read()) +exec(open("paperqa/version.py").read()) # noqa: S102, SIM115 -with open("README.md", "r", encoding="utf-8") as fh: +with open("README.md", encoding="utf-8") as fh: long_description = fh.read() setup( diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index 0c48ebf58..f94a2d0c7 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pickle import tempfile @@ -171,7 +173,7 @@ def test_ablations(): docs.add_file(f, "Wellawatte et al, XAI Review, 2023") answer = docs.get_evidence( Answer( - question="Which page is the statement 'Deep learning (DL) is advancing the boundaries of computational" + question="Which page is the statement 'Deep learning (DL) is advancing the boundaries of computational" # noqa: ISC003 + "chemistry because it can accurately model non-linear structure-function relationships.' on?" ) ) @@ -180,7 +182,7 @@ def test_ablations(): ), "summarization not ablated" answer = docs.get_evidence( Answer( - question="Which page is the statement 'Deep learning (DL) is advancing the boundaries of computational" + question="Which page is the statement 'Deep learning (DL) is advancing the boundaries of computational" # noqa: ISC003 + "chemistry because it can accurately model non-linear structure-function relationships.' on?" ), disable_vector_search=True, @@ -195,7 +197,7 @@ def test_location_awareness(): docs.add_file(f, "Wellawatte et al, XAI Review, 2023") answer = docs.get_evidence( Answer( - question="Which page is the statement 'Deep learning (DL) is advancing the boundaries of computational" + question="Which page is the statement 'Deep learning (DL) is advancing the boundaries of computational" # noqa: ISC003 + "chemistry because it can accurately model non-linear structure-function relationships.' on?" ), detailed_citations=True, @@ -207,7 +209,9 @@ def test_maybe_is_text(): assert maybe_is_text("This is a test. The sample conc. 
was 1.0 mM (at 245 ^F)") assert not maybe_is_text("\\C0\\C0\\B1\x00") # get front page of wikipedia - r = requests.get("https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day" + ) assert maybe_is_text(r.text) assert maybe_is_html(BytesIO(r.text.encode())) @@ -407,7 +411,7 @@ def test_extract_score(): class TestChains(IsolatedAsyncioTestCase): async def test_chain_completion(self): client = AsyncOpenAI() - llm = OpenAILLMModel(config=dict(model="babbage-002", temperature=0.2)) + llm = OpenAILLMModel(config={"model": "babbage-002", "temperature": 0.2}) call = llm.make_chain( client, "The {animal} says", @@ -418,20 +422,20 @@ async def test_chain_completion(self): def accum(x): outputs.append(x) - completion = await call(dict(animal="duck"), callbacks=[accum]) # type: ignore[call-arg] + completion = await call({"animal": "duck"}, callbacks=[accum]) # type: ignore[call-arg] assert completion.seconds_to_first_token > 0 assert completion.prompt_count > 0 assert completion.completion_count > 0 assert str(completion) == "".join(outputs) - completion = await call(dict(animal="duck")) # type: ignore[call-arg] + completion = await call({"animal": "duck"}) # type: ignore[call-arg] assert completion.seconds_to_first_token == 0 assert completion.seconds_to_last_token > 0 async def test_chain_chat(self): client = AsyncOpenAI() llm = OpenAILLMModel( - config=dict(temperature=0, model="gpt-3.5-turbo", max_tokens=56) + config={"temperature": 0, "model": "gpt-3.5-turbo", "max_tokens": 56} ) call = llm.make_chain( client, @@ -443,21 +447,21 @@ async def test_chain_chat(self): def accum(x): outputs.append(x) - completion = await call(dict(animal="duck"), callbacks=[accum]) # type: ignore[call-arg] + completion = await call({"animal": "duck"}, callbacks=[accum]) # type: ignore[call-arg] assert completion.seconds_to_first_token > 0 assert completion.prompt_count > 0 assert completion.completion_count > 0 assert str(completion) == "".join(outputs) - completion = await call(dict(animal="duck")) # type: ignore[call-arg] + completion = await call({"animal": "duck"}) # type: ignore[call-arg] assert completion.seconds_to_first_token == 0 assert completion.seconds_to_last_token > 0 # check with mixed callbacks - async def ac(x): + async def ac(x): # noqa: ARG001 pass - completion = await call(dict(animal="duck"), callbacks=[accum, ac]) # type: ignore[call-arg] + completion = await call({"animal": "duck"}, callbacks=[accum, ac]) # type: ignore[call-arg] async def test_anthropic_chain(self): try: @@ -477,13 +481,13 @@ def accum(x): outputs.append(x) outputs: list[str] = [] - completion = await call(dict(animal="duck"), callbacks=[accum]) # type: ignore[call-arg] + completion = await call({"animal": "duck"}, callbacks=[accum]) # type: ignore[call-arg] assert completion.seconds_to_first_token > 0 assert completion.prompt_count > 0 assert completion.completion_count > 0 assert str(completion) == "".join(outputs) - completion = await call(dict(animal="duck")) # type: ignore[call-arg] + completion = await call({"animal": "duck"}) # type: ignore[call-arg] assert completion.seconds_to_first_token == 0 assert completion.seconds_to_last_token > 0 @@ -504,7 +508,9 @@ def test_evidence(): doc_path = "example.html" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + 
"https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) docs = Docs() docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") # type: ignore[arg-type] @@ -534,14 +540,16 @@ def test_json_evidence(): doc_path = "example.html" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) summary_llm = OpenAILLMModel( - config=dict( - model="gpt-3.5-turbo-1106", - response_format=dict(type="json_object"), - temperature=0.0, - ) + config={ + "model": "gpt-3.5-turbo-1106", + "response_format": {"type": "json_object"}, + "temperature": 0.0, + } ) docs = Docs( prompts=PromptCollection(json_summary=True), @@ -567,18 +575,20 @@ def test_custom_json_props(): doc_path = "example.html" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) summary_llm = OpenAILLMModel( - config=dict( - model="gpt-3.5-turbo-0125", - response_format=dict(type="json_object"), - temperature=0.0, - ) + config={ + "model": "gpt-3.5-turbo-0125", + "response_format": {"type": "json_object"}, + "temperature": 0.0, + } ) my_prompts = PromptCollection( json_summary=True, - summary_json_system="Provide a summary of the excerpt that could help answer the question based on the excerpt. " + summary_json_system="Provide a summary of the excerpt that could help answer the question based on the excerpt. " # noqa: E501 "The excerpt may be irrelevant. Do not directly answer the question - only summarize relevant information. 
" "Respond with the following JSON format:\n\n" '{{\n"summary": "...",\n"person_name": "...",\n"relevance_score": "..."}}\n\n' @@ -626,7 +636,7 @@ def test_answer_attributes(): assert len(used_citations) > 0 assert len(used_citations) < len(answer.contexts) assert ( - answer.get_citation(list(used_citations)[0]) + answer.get_citation(next(iter(used_citations))) == "WikiMedia Foundation, 2023, Accessed now" ) @@ -652,7 +662,7 @@ async def my_callback(result): dockey="test", ) docs.query("What is Frederick Bates's greatest accomplishment?") - assert any([x.name == "answer" for x in my_results]) + assert any(x.name == "answer" for x in my_results) assert len(my_results) > 1 @@ -687,7 +697,7 @@ def test_custom_embedding(): class MyEmbeds(EmbeddingModel): name: str = "my_embed" - async def embed_documents(self, client, texts): + async def embed_documents(self, client, texts): # noqa: ARG002 return [[1, 2, 3] for _ in texts] docs = Docs( @@ -774,7 +784,7 @@ def test_custom_llm(): class MyLLM(LLMModel): name: str = "myllm" - async def acomplete(self, client, prompt): + async def acomplete(self, client, prompt): # noqa: ARG002 assert client is None return "Echo" @@ -792,7 +802,7 @@ def test_custom_llm_stream(): class MyLLM(LLMModel): name: str = "myllm" - async def acomplete_iter(self, client, prompt): + async def acomplete_iter(self, client, prompt): # noqa: ARG002 assert client is None yield "Echo" @@ -803,7 +813,8 @@ async def acomplete_iter(self, client, prompt): dockey="test", ) evidence = docs.get_evidence( - Answer(question="Echo"), get_callbacks=lambda x: [lambda y: print(y, end="")] + Answer(question="Echo"), + get_callbacks=lambda x: [lambda y: print(y, end="")], # noqa: ARG005 ) assert "Echo" in evidence.context @@ -827,7 +838,7 @@ def test_langchain_llm(): docs.get_evidence( Answer(question="What is Frederick Bates's greatest accomplishment?"), - get_callbacks=lambda x: [lambda y: print(y, end="")], + get_callbacks=lambda x: [lambda y: print(y, end="")], # noqa: ARG005 ) assert docs.llm_model.llm_type == "chat" @@ -847,7 +858,7 @@ def test_langchain_llm(): ) docs.get_evidence( Answer(question="What is Frederick Bates's greatest accomplishment?"), - get_callbacks=lambda x: [lambda y: print(y, end="")], + get_callbacks=lambda x: [lambda y: print(y, end="")], # noqa: ARG005 ) assert docs.summary_llm_model.llm_type == "completion" # type: ignore[union-attr] @@ -859,14 +870,14 @@ def test_langchain_llm(): # now make sure we can pickle it docs_pickle = pickle.dumps(docs) - docs2 = pickle.loads(docs_pickle) + docs2 = pickle.loads(docs_pickle) # noqa: S301 assert docs2._client is None assert docs2.llm == "babbage-002" docs2.set_client(OpenAI(model="babbage-002")) assert docs2.summary_llm == "babbage-002" docs2.get_evidence( Answer(question="What is Frederick Bates's greatest accomplishment?"), - get_callbacks=lambda x: [lambda y: print(y)], + get_callbacks=lambda x: [lambda y: print(y)], # noqa: ARG005 ) @@ -911,19 +922,19 @@ async def test_langchain_vector_store(self): try: index = LangchainVectorStore() index.add_texts_and_embeddings(some_texts) - raise "Failed to check for builder" # type: ignore[misc] + raise "Failed to check for builder" # type: ignore[misc] # noqa: B016 except ValueError: pass try: - index = LangchainVectorStore(store_builder=lambda x: None) - raise "Failed to count arguments" # type: ignore[misc] + index = LangchainVectorStore(store_builder=lambda x: None) # noqa: ARG005 + raise "Failed to count arguments" # type: ignore[misc] # noqa: B016 except ValueError: pass try: 
index = LangchainVectorStore(store_builder="foo") - raise "Failed to check if builder is callable" # type: ignore[misc] + raise "Failed to check if builder is callable" # type: ignore[misc] # noqa: B016 except ValueError: pass @@ -981,7 +992,7 @@ async def test_langchain_vector_store(self): # make sure we can pickle it docs_pickle = pickle.dumps(docs) - pickle.loads(docs_pickle) + pickle.loads(docs_pickle) # noqa: S301 # will not work at this point - have to reset index @@ -1019,12 +1030,14 @@ def test_docs_pickle() -> None: # 1. Fill out docs with tempfile.NamedTemporaryFile(mode="r+", encoding="utf-8", suffix=".html") as f: # get front page of wikipedia - r = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day" + ) r.raise_for_status() f.write(r.text) docs = Docs( llm_model=OpenAILLMModel( - config=dict(temperature=0.0, model="gpt-3.5-turbo") + config={"temperature": 0.0, "model": "gpt-3.5-turbo"} ) ) assert docs._client is not None @@ -1034,7 +1047,7 @@ def test_docs_pickle() -> None: # 2. Pickle and unpickle, checking unpickled is in-tact docs_pickle = pickle.dumps(docs) - docs2 = pickle.loads(docs_pickle) + docs2 = pickle.loads(docs_pickle) # noqa: S301 with pytest.raises(ValueError, match="forget to set it after pickling"): docs2.query("What date is bring your dog to work in the US?") docs2.set_client() @@ -1078,7 +1091,9 @@ def test_bad_context(): doc_path = "example.html" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) docs = Docs() docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") # type: ignore[arg-type] @@ -1091,13 +1106,15 @@ def test_repeat_keys(): doc_path = "example.txt" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) docs = Docs( - llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="babbage-002")) + llm_model=OpenAILLMModel(config={"temperature": 0.0, "model": "babbage-002"}) ) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") # type: ignore[arg-type] - try: + try: # noqa: SIM105 docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") # type: ignore[arg-type] except ValueError: pass @@ -1124,7 +1141,7 @@ def test_repeat_keys(): def test_pdf_reader(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") - docs = Docs(llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="gpt-4"))) + docs = Docs(llm_model=OpenAILLMModel(config={"temperature": 0.0, "model": "gpt-4"})) docs.add(doc_path, "Wellawatte et al, XAI Review, 2023") # type: ignore[arg-type] answer = docs.query("Are counterfactuals actionable? 
[yes/no]") assert "yes" in answer.answer or "Yes" in answer.answer @@ -1143,7 +1160,9 @@ def test_fileio_reader_pdf(): def test_fileio_reader_txt(): # can't use curie, because it has trouble with parsed HTML docs = Docs() - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) if r.status_code != 200: raise ValueError("Could not download wikipedia page") docs.add_file( @@ -1182,7 +1201,9 @@ def test_prompt_length(): doc_path = "example.txt" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) docs = Docs() docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") # type: ignore[arg-type] @@ -1193,7 +1214,7 @@ def test_code(): # load this script doc_path = os.path.abspath(__file__) docs = Docs( - llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="babbage-002")) + llm_model=OpenAILLMModel(config={"temperature": 0.0, "model": "babbage-002"}) ) docs.add(doc_path, "test_paperqa.py", docname="test_paperqa.py", disable_check=True) # type: ignore[arg-type] assert len(docs.docs) == 1 @@ -1204,24 +1225,28 @@ def test_citation(): doc_path = "example.txt" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) docs = Docs() docs.add(doc_path) # type: ignore[arg-type] - assert ( - list(docs.docs.values())[0].docname == "Wikipedia2024" - or list(docs.docs.values())[0].docname == "Frederick2024" - or list(docs.docs.values())[0].docname == "Wikipedia" - or list(docs.docs.values())[0].docname == "Frederick" + assert next(iter(docs.docs.values())).docname in ( + "Wikipedia2024", + "Frederick2024", + "Wikipedia", + "Frederick", ) def test_dockey_filter(): - """Test that we can filter evidence with dockeys""" + """Test that we can filter evidence with dockeys.""" doc_path = "example2.txt" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) docs = Docs() docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") # type: ignore[arg-type] @@ -1235,11 +1260,13 @@ def test_dockey_filter(): def test_dockey_delete(): - """Test that we can filter evidence with dockeys""" + """Test that we can filter evidence with dockeys.""" doc_path = "example2.txt" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) docs = Docs() docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") # type: ignore[arg-type] @@ -1252,7 +1279,7 @@ def test_dockey_delete(): answer = docs.get_evidence( answer, max_sources=25, k=30 ) # we just have a lot so we get both docs - keys = set([c.text.doc.dockey for c in answer.contexts]) + keys = {c.text.doc.dockey for c in answer.contexts} assert len(keys) == 2 assert 
len(docs.docs) == 2 @@ -1264,16 +1291,18 @@ def test_dockey_delete(): assert len(docs.docs) == 1 assert len(docs.deleted_dockeys) == 1 answer = docs.get_evidence(answer, max_sources=25, k=30) - keys = set([c.text.doc.dockey for c in answer.contexts]) + keys = {c.text.doc.dockey for c in answer.contexts} assert len(keys) == 1 def test_query_filter(): - """Test that we can filter evidence with in query""" + """Test that we can filter evidence with in query.""" doc_path = "example2.txt" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) docs = Docs() docs.add( @@ -1304,7 +1333,7 @@ def test_too_much_evidence(): doc_path = "example2.txt" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Barack_Obama") + r = requests.get("https://en.wikipedia.org/wiki/Barack_Obama") # noqa: S113 f.write(r.text) docs = Docs(llm="gpt-3.5-turbo", summary_llm="gpt-3.5-turbo") docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") # type: ignore[arg-type] @@ -1333,7 +1362,9 @@ def test_custom_prompts(): doc_path = "example.html" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") # type: ignore[arg-type] answer = docs.query("What country is Frederick Bates from?") @@ -1341,14 +1372,16 @@ def test_custom_prompts(): def test_pre_prompt(): - pre = "Provide context you have memorized " "that could help answer '{question}'. " + pre = "Provide context you have memorized that could help answer '{question}'. 
" docs = Docs(prompts=PromptCollection(pre=pre)) doc_path = "example.txt" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") # type: ignore[arg-type] docs.query("What country is Bates from?") @@ -1368,7 +1401,9 @@ def test_post_prompt(): doc_path = "example.txt" with open(doc_path, "w", encoding="utf-8") as f: # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") + r = requests.get( # noqa: S113 + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)" + ) f.write(r.text) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") # type: ignore[arg-type] docs.query("What country is Bates from?") @@ -1412,7 +1447,7 @@ def disabled_test_memory(): def test_add_texts(): - llm_config = dict(temperature=0.1, model="babbage-02") + llm_config = {"temperature": 0.1, "model": "babbage-02"} docs = Docs(llm_model=OpenAILLMModel(config=llm_config)) docs.add_url( "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day", @@ -1424,7 +1459,7 @@ def test_add_texts(): texts = [Text(**dict(t)) for t in docs.texts] for t in texts: t.embedding = None - docs2.add_texts(texts, list(docs.docs.values())[0]) + docs2.add_texts(texts, next(iter(docs.docs.values()))) for t1, t2 in zip(docs2.texts, docs.texts): assert t1.text == t2.text @@ -1432,7 +1467,7 @@ def test_add_texts(): docs2._build_texts_index() # now do it again to test after text index is already built - llm_config = dict(temperature=0.1, model="babbage-02") + llm_config = {"temperature": 0.1, "model": "babbage-02"} docs = Docs(llm_model=OpenAILLMModel(config=llm_config)) docs.add_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", @@ -1443,7 +1478,7 @@ def test_add_texts(): texts = [Text(**dict(t)) for t in docs.texts] for t in texts: t.embedding = None - docs2.add_texts(texts, list(docs.docs.values())[0]) + docs2.add_texts(texts, next(iter(docs.docs.values()))) assert len(docs2.docs) == 2 docs2.query("What country was Bates Born in?")