From 6cfcc8dd0614802d6d16738c00ca648c8d826f34 Mon Sep 17 00:00:00 2001 From: yaakov Date: Wed, 27 Dec 2023 11:47:06 +0100 Subject: [PATCH] fix grpc query response --- pinecone/grpc/utils.py | 69 +++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 35 deletions(-) diff --git a/pinecone/grpc/utils.py b/pinecone/grpc/utils.py index 4e8057ec..5e948f3e 100644 --- a/pinecone/grpc/utils.py +++ b/pinecone/grpc/utils.py @@ -35,7 +35,7 @@ def parse_sparse_values(sparse_values: dict): return ( SparseValues(indices=sparse_values["indices"], values=sparse_values["values"]) if sparse_values - else SparseValues(indices=[], values=[]) + else None ) @@ -49,7 +49,7 @@ def parse_fetch_response(response: dict): id=vec["id"], values=vec["values"], sparse_values=parse_sparse_values(vec.get("sparseValues")), - metadata=vec.get("metadata", None), + metadata=vec.get("metadata"), _check_type=False, ) vd[id] = v_obj @@ -58,49 +58,48 @@ def parse_fetch_response(response: dict): def parse_query_response(response: dict, unary_query: bool, _check_type: bool = False): - res = [] - - # TODO: consider deleting this deprecated case - for match in response.get("results", []): - namespace = match.get("namespace", "") + if unary_query: m = [] - if "matches" in match: - for item in match["matches"]: - sc = ScoredVector( - id=item["id"], - score=item.get("score", 0.0), - values=item.get("values", []), - sparse_values=parse_sparse_values(item.get("sparseValues")), - metadata=item.get("metadata", {}), - ) - m.append(sc) - res.append(SingleQueryResults(matches=m, namespace=namespace)) + for item in response.get("matches"): + sc = ScoredVector( + id=item["id"], + score=item.get("score"), + values=item.get("values"), + sparse_values=parse_sparse_values(item.get("sparseValues")), + metadata=item.get("metadata"), + _check_type=_check_type, + ) + m.append(sc) - m = [] - for item in response.get("matches", []): - sc = ScoredVector( - id=item["id"], - score=item.get("score", 0.0), - values=item.get("values", []), - sparse_values=parse_sparse_values(item.get("sparseValues")), - metadata=item.get("metadata", {}), - _check_type=_check_type, - ) - m.append(sc) - - if unary_query: namespace = response.get("namespace", "") matches = m - results = [] + results = None else: - namespace = "" - matches = [] + # TODO: consider deleting this deprecated case + res = [] + for match in response.get("results", []): + namespace = match.get("namespace", "") + m = [] + if "matches" in match: + for item in match["matches"]: + sc = ScoredVector( + id=item["id"], + score=item.get("score"), + values=item.get("values"), + sparse_values=parse_sparse_values(item.get("sparseValues")), + metadata=item.get("metadata"), + _check_type=_check_type, + ) + m.append(sc) + res.append(SingleQueryResults(matches=m, namespace=namespace)) + namespace = None + matches = None results = res kw = QueryResponseKwargs(_check_type, namespace, matches, results) kw_dict = kw._asdict() kw_dict["_check_type"] = kw.check_type - return QueryResponse(**kw._asdict()) + return QueryResponse(**kw_dict) def parse_stats_response(response: dict):