Skip to content

Commit

Permalink
Fix sparse query types
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Dec 17, 2024
1 parent df0e8e4 commit c618fff
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 28 deletions.
30 changes: 4 additions & 26 deletions pinecone/data/request_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,6 @@
logger = logging.getLogger(__name__)


def parse_sparse_values_arg(
sparse_values: Optional[Union[SparseValues, SparseVectorTypedDict]],
) -> Optional[SparseValues]:
if sparse_values is None:
return None

if isinstance(sparse_values, SparseValues):
return sparse_values

if (
not isinstance(sparse_values, dict)
or "indices" not in sparse_values
or "values" not in sparse_values
):
raise ValueError(
"Invalid sparse values argument. Expected a dict of: {'indices': List[int], 'values': List[float]}."
f"Received: {sparse_values}"
)

return SparseValues(indices=sparse_values["indices"], values=sparse_values["values"])


def non_openapi_kwargs(kwargs):
return {k: v for k, v in kwargs.items() if k not in OPENAPI_ENDPOINT_PARAMS}

Expand All @@ -67,7 +45,7 @@ def query_request(
if vector is not None and id is not None:
raise ValueError("Cannot specify both `id` and `vector`")

sparse_vector = SparseValuesFactory.build(sparse_vector)
sparse_vector_normalized = SparseValuesFactory.build(sparse_vector)
args_dict = parse_non_empty_args(
[
("vector", vector),
Expand All @@ -78,7 +56,7 @@ def query_request(
("filter", filter),
("include_values", include_values),
("include_metadata", include_metadata),
("sparse_vector", sparse_vector),
("sparse_vector", sparse_vector_normalized),
]
)

Expand Down Expand Up @@ -131,13 +109,13 @@ def update_request(
**kwargs,
) -> UpdateRequest:
_check_type = kwargs.pop("_check_type", False)
sparse_values = parse_sparse_values_arg(sparse_values)
sparse_values_normalized = SparseValuesFactory.build(sparse_values)
args_dict = parse_non_empty_args(
[
("values", values),
("set_metadata", set_metadata),
("namespace", namespace),
("sparse_values", sparse_values),
("sparse_values", sparse_values_normalized),
]
)

Expand Down
5 changes: 3 additions & 2 deletions pinecone/data/sparse_values_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Union, Dict, Optional
from typing import Union, Optional

from ..utils import convert_to_list

Expand All @@ -10,6 +10,7 @@
)

from .dataclasses import SparseValues
from .types import SparseVectorTypedDict
from pinecone.core.openapi.db_data.models import SparseValues as OpenApiSparseValues


Expand All @@ -18,7 +19,7 @@ class SparseValuesFactory:

@staticmethod
def build(
input: Union[Dict, Optional[SparseValues], OpenApiSparseValues],
input: Optional[Union[SparseValues, OpenApiSparseValues, SparseVectorTypedDict]],
) -> Optional[OpenApiSparseValues]:
if input is None:
return input
Expand Down

0 comments on commit c618fff

Please sign in to comment.