Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add array_contains and array_overlaps operators for pgvector metadata filtering. #1352

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from itertools import chain
from typing import Any, Dict, List, Literal, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple

from haystack.errors import FilterError
from pandas import DataFrame
Expand Down Expand Up @@ -99,13 +99,13 @@ def _parse_comparison_condition(condition: Dict[str, Any]) -> Tuple[str, List[An
field = f"({field})::jsonb"

if field.startswith("meta."):
field = _treat_meta_field(field, value)
field = _treat_meta_field(field, value, operator)

field, value = COMPARISON_OPERATORS[operator](field, value)
return field, [value]


def _treat_meta_field(field: str, value: Any) -> str:
def _treat_meta_field(field: str, value: Any, operator: Optional[str] = None) -> str:
"""
Internal method that modifies the field str
to make the meta JSONB field queryable.
Expand All @@ -116,10 +116,17 @@ def _treat_meta_field(field: str, value: Any) -> str:

>>> _treat_meta_field(field="meta.name", value="my_name")
"meta->>'name'"
"""

# use the ->> operator to access keys in the meta JSONB field
>>> _treat_meta_field(field="meta.tags", value=["tag1", "tag2"], operator="array_contains")
"meta->'tags'"
"""
field_name = field.split(".", 1)[-1]

# For array operations, we need to use the -> operator
if operator and operator.startswith("array_"):
return f"meta->'{field_name}'"

# use the ->> operator to access keys in the meta JSONB field as text
field = f"meta->>'{field_name}'"

# meta fields are stored as strings in the JSONB field,
Expand Down Expand Up @@ -246,6 +253,24 @@ def _not_like(field: str, value: Any) -> Tuple[str, Any]:
return f"{field} NOT LIKE %s", value


def _array_contains(field: str, value: Any) -> Tuple[str, Any]:
    """
    Build the SQL condition for the 'array_contains' operator.

    :param field: SQL expression identifying the field to filter on.
    :param value: The list to look for; anything other than a list is rejected.
    :returns: A tuple of the `<field> @> %s` condition and the value wrapped in `Jsonb`.
    :raises FilterError: If `value` is not a list.
    """
    if isinstance(value, list):
        # The right-hand operand of @> has to be JSONB, hence the Jsonb wrapper.
        return f"{field} @> %s", Jsonb(value)

    error_message = f"{field}'s value must be a list when using 'array_contains' operator"
    raise FilterError(error_message)


def _array_overlaps(field: str, value: Any) -> Tuple[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'array_overlaps' operator"
raise FilterError(msg)

# ?| expects value to be a text array
return f"{field} ?| %s", value


COMPARISON_OPERATORS = {
"==": _equal,
"!=": _not_equal,
Expand All @@ -257,4 +282,6 @@ def _not_like(field: str, value: Any) -> Tuple[str, Any]:
"not in": _not_in,
"like": _like,
"not like": _not_like,
"array_contains": _array_contains,
"array_overlaps": _array_overlaps,
}
49 changes: 49 additions & 0 deletions integrations/pgvector/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,28 @@ def test_not_like_operator_nb_chars(self, document_store, filterable_docs):
],
)

def test_array_contains_filter(self, document_store):
    """Only the document whose tags include both requested tags is returned."""
    tag_sets = (["tag1", "tag2"], ["tag2", "tag3"], ["tag1", "tag3"])
    documents = [
        Document(content=f"doc{position}", meta={"tags": tags})
        for position, tags in enumerate(tag_sets, start=1)
    ]
    document_store.write_documents(documents)

    matches = document_store.filter_documents(
        filters={"field": "meta.tags", "operator": "array_contains", "value": ["tag1", "tag2"]}
    )

    # doc1 is the only one tagged with both tag1 and tag2
    self.assert_documents_are_equal(matches, documents[:1])

def test_array_overlaps_filter(self, document_store):
    """Documents sharing at least one tag with the requested list are returned."""
    tag_sets = (["tag1", "tag2"], ["tag2", "tag3"], ["tag4"])
    documents = [
        Document(content=f"doc{position}", meta={"tags": tags})
        for position, tags in enumerate(tag_sets, start=1)
    ]
    document_store.write_documents(documents)

    matches = document_store.filter_documents(
        filters={"field": "meta.tags", "operator": "array_overlaps", "value": ["tag1", "tag3"]}
    )

    # doc1 matches through tag1 and doc2 through tag3; doc3 shares no tag
    self.assert_documents_are_equal(matches, documents[:2])

def test_complex_filter(self, document_store, filterable_docs):
document_store.write_documents(filterable_docs)
filters = {
Expand Down Expand Up @@ -141,12 +163,39 @@ def test_treat_meta_field():
assert _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean"
assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == "(meta->>'bool')::boolean"

# Array operators should keep JSON type
assert _treat_meta_field(field="meta.tags", value=["a", "b"], operator="array_contains") == "meta->'tags'"
assert _treat_meta_field(field="meta.tags", value=["a", "b"], operator="array_overlaps") == "meta->'tags'"

# do not cast the field if its value is not one of the known types, an empty list or None
assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == "meta->>'other'"
assert _treat_meta_field(field="meta.empty_list", value=[]) == "meta->>'empty_list'"
assert _treat_meta_field(field="meta.name", value=None) == "meta->>'name'"


def test_array_contains_operator():
    """The 'array_contains' condition is parsed into an @> clause with a Jsonb parameter."""
    expected_param = Jsonb(["tag1", "tag2"])

    sql, params = _parse_comparison_condition(
        {"field": "meta.tags", "operator": "array_contains", "value": ["tag1", "tag2"]}
    )

    assert sql == "meta->'tags' @> %s"
    # compare the wrapped payloads of the two Jsonb objects
    assert params[0].obj == expected_param.obj


def test_array_overlaps_operator():
    """The 'array_overlaps' condition is parsed into a ?| clause with a plain list parameter."""
    sql, params = _parse_comparison_condition(
        {"field": "meta.tags", "operator": "array_overlaps", "value": ["tag1", "tag2"]}
    )
    first_param = params[0]

    assert sql == "meta->'tags' ?| %s"
    # the value is passed through as a list, not wrapped
    assert isinstance(first_param, list)
    assert first_param == ["tag1", "tag2"]


def test_array_operators_require_list():
    """Both array operators raise FilterError when given a value that is not a list."""
    for operator in ("array_contains", "array_overlaps"):
        bad_condition = {"field": "meta.tags", "operator": operator, "value": "not_a_list"}
        expected_message = f"must be a list when using '{operator}' operator"

        with pytest.raises(FilterError, match=expected_message):
            _parse_comparison_condition(bad_condition)


def test_comparison_condition_dataframe_jsonb_conversion():
dataframe = DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
condition = {"field": "meta.df", "operator": "==", "value": dataframe}
Expand Down