This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit

chore: resolve typings
Anush008 committed Feb 19, 2024
1 parent 3eb2b1e commit e027521
Showing 3 changed files with 37 additions and 25 deletions.
25 changes: 13 additions & 12 deletions src/canopy/knowledge_base/qdrant/converter.py
@@ -47,8 +47,8 @@ def encoded_docs_to_points(

             if sparse_vector:
                 vector[SPARSE_VECTOR_NAME] = models.SparseVector(
-                    indices=sparse_vector["indices"],  # type: ignore
-                    values=sparse_vector["values"],  # type: ignore
+                    indices=sparse_vector["indices"],
+                    values=sparse_vector["values"],
                 )

             points.append(
@@ -64,8 +64,8 @@ def encoded_docs_to_points(
     def scored_point_to_scored_doc(
         scored_point: models.ScoredPoint,
     ) -> "KBDocChunkWithScore":
-        metadata: Dict[str, Any] = deepcopy(scored_point.payload)  # type: ignore
-        _id = metadata.pop("chunk_id")  # type: ignore
+        metadata: Dict[str, Any] = deepcopy(scored_point.payload or {})
+        _id = metadata.pop("chunk_id")
         text = metadata.pop("text", "")
         document_id = metadata.pop("document_id")
         return KBDocChunkWithScore(
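The "or {}" guard is what makes the two removed "# type: ignore" comments unnecessary: ScoredPoint.payload is an optional dict in qdrant_client, and substituting an empty dict removes the None case before the pops. A minimal standalone sketch of the idea (outside the canopy classes, so the names here are illustrative only):

    from copy import deepcopy
    from typing import Any, Dict, Optional

    payload: Optional[Dict[str, Any]] = {"chunk_id": "c1", "document_id": "d1", "text": "hello"}

    # Without the guard, deepcopy(payload) is still typed Optional[...] and a type
    # checker flags the pops below; "or {}" narrows the value to a plain dict first.
    metadata: Dict[str, Any] = deepcopy(payload or {})
    chunk_id = metadata.pop("chunk_id")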
@@ -82,16 +82,17 @@ def kb_query_to_search_vector(
         query: KBQuery,
     ) -> "Union[models.NamedVector, models.NamedSparseVector]":
         # Use dense vector if available, otherwise use sparse vector
-        query_vector: Union[models.NamedSparseVector, models.NamedVector] = (
-            models.NamedVector(name=DENSE_VECTOR_NAME, vector=query.values)
-            if query.values is not None
-            else models.NamedSparseVector(
+        query_vector: Union[models.NamedVector, models.NamedSparseVector]
+        if query.values:
+            query_vector = models.NamedVector(name=DENSE_VECTOR_NAME, vector=query.values)  # noqa: E501
+        elif query.sparse_values:
+            query_vector = models.NamedSparseVector(
                 name=SPARSE_VECTOR_NAME,
                 vector=models.SparseVector(
-                    indices=query.sparse_values["indices"],  # type: ignore
-                    values=query.sparse_values["values"],  # type: ignore
+                    indices=query.sparse_values["indices"],
+                    values=query.sparse_values["values"],
                 ),
             )
-        )
-
+        else:
+            raise ValueError("Query should have either dense or sparse vector.")
         return query_vector
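Rewritten as an explicit if/elif/else, each branch assigns a single well-typed value, the "# type: ignore" comments on the sparse indices and values are no longer needed, and a query that carries neither a dense nor a sparse vector now fails with an explicit ValueError. A standalone sketch of the same control flow, using plain dicts in place of the qdrant_client model classes (the function and field names below are illustrative only):

    from typing import Any, Dict, List, Optional

    def pick_search_vector(
        values: Optional[List[float]],
        sparse_values: Optional[Dict[str, list]],
    ) -> Dict[str, Any]:
        # Mirrors the branching above: dense first, then sparse, otherwise error.
        if values:
            return {"kind": "dense", "vector": values}
        elif sparse_values:
            return {
                "kind": "sparse",
                "indices": sparse_values["indices"],
                "values": sparse_values["values"],
            }
        else:
            raise ValueError("Query should have either dense or sparse vector.")

    print(pick_search_vector([0.1, 0.2], None))                         # dense wins
    print(pick_search_vector(None, {"indices": [3], "values": [0.7]}))  # sparse fallback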
35 changes: 23 additions & 12 deletions src/canopy/knowledge_base/qdrant/qdrant_knowledge_base.py
@@ -25,7 +25,7 @@

 from qdrant_client import models as models
 from qdrant_client.http.exceptions import UnexpectedResponse
-from grpc import RpcError  # type: ignore
+from grpc import RpcError  # type: ignore[import-untyped]
 from tqdm import tqdm

@@ -76,7 +76,7 @@ def __init__(
         https: Optional[bool] = None,
         api_key: Optional[str] = None,
         prefix: Optional[str] = None,
-        timeout: Optional[float] = None,
+        timeout: Optional[int] = None,
         host: Optional[str] = None,
         path: Optional[str] = None,
         force_disable_check_same_thread: bool = False,
@@ -208,7 +208,10 @@ def verify_index_connection(self) -> None:
             ) from e

     def query(
-        self, queries: List[Query], global_metadata_filter: Optional[dict] = None
+        self,
+        queries: List[Query],
+        global_metadata_filter: Optional[dict] = None,
+        namespace: Optional[str] = None,
     ) -> List[QueryResult]:
         """
         Query the knowledge base to retrieve document chunks.
@@ -223,6 +226,7 @@ def query(
             queries: A list of queries to run against the knowledge base.
             global_metadata_filter: A payload filter to apply to all queries, in addition to any query-specific filters.
                                     Reference: https://qdrant.tech/documentation/concepts/filtering/
+            namespace: This argument is not used by Qdrant.
         Returns:
             A list of QueryResult objects.

Expand Down Expand Up @@ -467,7 +471,8 @@ async def adelete(self, document_ids: List[str], namespace: str = "") -> None:
>>> kb = QdrantKnowledgeBase(collection_name="my_collection")
>>> await kb.adelete(document_ids=["doc1", "doc2"])
""" # noqa: E501
await self._async_client.delete( # type: ignore
# @sync_fallback will call the sync method if the async client is None
self._async_client and await self._async_client.delete(
self.collection_name,
points_selector=models.Filter(
must=[
@@ -693,13 +698,18 @@ async def _aquery_collection(
         # Use dense vector if available, otherwise use sparse vector
         query_vector = QdrantConverter.kb_query_to_search_vector(query)

-        results = await self._async_client.search(  # type: ignore
-            self.collection_name,
-            query_vector=query_vector,
-            limit=top_k,
-            query_filter=metadata_filter,
-            with_payload=True,
-            **query_params,
+        # @sync_fallback will call the sync method if the async client is None
+        results = (
+            await self._async_client.search(
+                self.collection_name,
+                query_vector=query_vector,
+                limit=top_k,
+                query_filter=metadata_filter,
+                with_payload=True,
+                **query_params,
+            )
+            if self._async_client
+            else []
         )
         documents: List[KBDocChunkWithScore] = []
         for result in results:
@@ -743,7 +753,8 @@ async def _aupsert_collection(
             document_batch,
         )

-        await self._async_client.upsert(  # type: ignore
+        # @sync_fallback will call the sync method if the async client is None
+        self._async_client and await self._async_client.upsert(
             collection_name=self.collection_name,
             points=batch,
         )
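The three async hunks above rely on the same trick: "self._async_client and await self._async_client.method(...)" short-circuits, so when the async client is None no coroutine is ever created, and, per the comments, a @sync_fallback decorator then routes the call to the synchronous method instead. The decorator itself is not part of this diff, so the following is only a sketch of how such a fallback might look, with hypothetical names:

    import asyncio
    import functools


    def sync_fallback(method):
        """Run the sync twin of an async method when no async client is configured."""

        @functools.wraps(method)
        async def wrapper(self, *args, **kwargs):
            result = await method(self, *args, **kwargs)
            if self._async_client is None:
                # The async body short-circuited on "self._async_client and ...",
                # so redirect to the synchronous counterpart ("adelete" -> "delete").
                return getattr(self, method.__name__[1:])(*args, **kwargs)
            return result

        return wrapper


    class Example:
        def __init__(self, async_client=None):
            self._async_client = async_client

        def delete(self, ids):
            print(f"sync delete of {ids}")

        @sync_fallback
        async def adelete(self, ids):
            # "and" short-circuits: with no async client, no coroutine is created here.
            self._async_client and await self._async_client.delete(ids)


    asyncio.run(Example().adelete(["doc1"]))  # prints: sync delete of ['doc1']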
2 changes: 1 addition & 1 deletion src/canopy/knowledge_base/qdrant/utils.py
@@ -41,7 +41,7 @@ def generate_clients(
     https: Optional[bool] = None,
     api_key: Optional[str] = None,
     prefix: Optional[str] = None,
-    timeout: Optional[float] = None,
+    timeout: Optional[int] = None,
     host: Optional[str] = None,
     path: Optional[str] = None,
     force_disable_check_same_thread: bool = False,

0 comments on commit e027521
