diff --git a/paperqa/types.py b/paperqa/types.py
index ef53df6e5..cd3306bae 100644
--- a/paperqa/types.py
+++ b/paperqa/types.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from collections.abc import Collection
 from typing import Any, Callable
 from uuid import UUID, uuid4
 
@@ -233,6 +234,20 @@ def add_tokens(self, result: LLMResult):
         self.token_counts[result.model][0] += result.prompt_count
         self.token_counts[result.model][1] += result.completion_count
 
+    def get_unique_docs_from_contexts(
+        self, score_threshold: int = 0
+    ) -> Collection[Doc]:
+        """Return the unique docs of contexts scoring at or above a threshold.
+
+        Args:
+            score_threshold: Inclusive lower bound on a context's score;
+                contexts with a lower score are skipped.
+
+        Returns:
+            A set holding each distinct ``c.text.doc`` of the kept contexts.
+        """
+        return {c.text.doc for c in self.contexts if c.score >= score_threshold}
+
 
 class ChunkMetadata(BaseModel):
     """Metadata for chunking algorithm."""