Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework search query building #278

Merged
merged 4 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 44 additions & 29 deletions services/search/api.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
"""
Brief explanation how full text search is implemented in the smbacked.
Brief explanation how full text search is implemented in the smbackend.
- Currently search is performed to following models, Unit, Service,
munigeo_Address, munigeo_Administrative_division.
- For every model that is include in the search a search column is added
- For every model that is included in the search a search column is added
for every language of type SearchVector. These are also defined as a Gindex.
The models that are searched also implements a function called get_search_column_indexing
where the name, configuration(language) and weight of the columns that will be indexed
are defined. This function is used by the indexing script and signals when
the search_column is populated.
- A view called search_view is created and it contains the search_columns of the models
and a couple auxilary columns: id. type_name and name. This view is created by a
and a couple auxiliary columns: id. type_name and name. This view is created by a
raw SQL migration 008X_create_search_view.py.
- The search if performed by quering the views search_columns.
- The search if performed by querying the views search_columns.
- For models included in the search a post_save signal is connected and the
search_column is updated when they are saved.
- The search_columns can be manually updated with the index_search_columns
Expand Down Expand Up @@ -228,6 +228,35 @@ def to_representation(self, obj):
return representation


def build_search_query(query: str):
result = ""
query = query.strip(" |&,")
query = query.replace("\\", "\\\\")
or_operands = re.split(r"(?:\s*\|+\s*)+", query)

for or_operand in or_operands:
# Whitespace, comma and ampersand are all considered AND operators
and_operands = re.split(r"[\s,&]+", or_operand)
expression = ""
for and_operand in and_operands:
if re.fullmatch(r"'+", and_operand):
# Skip any operands that are just repeating single-quotes
continue
if expression:
expression += f" & {and_operand}:*"
else:
expression = f"{and_operand}:*"

if not expression:
continue
if result:
result += f" | {expression}"
else:
result = expression

return result


@extend_schema(
parameters=[
OpenApiParameter(
Expand Down Expand Up @@ -274,7 +303,7 @@ def to_representation(self, obj):
OpenApiParameter(
name="use_websearch",
location=OpenApiParameter.QUERY,
description="Use websearch_to_tsquery instead of to_tsquery if exlusion rules are defined for the search.",
description="Use websearch_to_tsquery instead of to_tsquery if exclusion rules are defined for the search.",
required=False,
type=bool,
),
Expand Down Expand Up @@ -381,9 +410,9 @@ def get(self, request):
if not q_val:
raise ParseError("Supply search terms with 'q=' ' or input=' '")

if not re.match(r"^[\w\såäö.'+&|-]+$", q_val):
if not re.match(r"^[\w\såäö.,'+&|-]+$", q_val):
raise ParseError(
"Invalid search terms, only letters, numbers, spaces and .'+-&| allowed."
"Invalid search terms, only letters, numbers, spaces and .,'+-&| allowed."
)

types_str = ",".join([elem for elem in QUERY_PARAM_TYPE_NAMES])
Expand Down Expand Up @@ -484,34 +513,20 @@ def get(self, request):
)

config_language = LANGUAGES[language_short]
search_query_str = None # Used in the raw sql
# Replace multiple consecutive vertical bars with a single vertical bar to be used as an OR operator.
q_val = re.sub(r"\|+", "|", q_val)
# Remove vertical bars that are not between words to avoid errors in the query.
q_val = re.sub(r"(?<!\w)\|+|\|+(?!\w)", "", q_val)
# Build conditional query string that is used in the SQL query.
# split by "," or whitespace
q_vals = re.split(r",\s+|\s+", q_val)
for q in q_vals:
if search_query_str:
# if ends with "|"" make it a or
if q[-1] == "|":
search_query_str += f"| {q[:-1]}:*"
# else make it an and.
else:
search_query_str += f"& {q}:*"
else:
search_query_str = f"{q}:*"

if has_exclusion_word_in_query(q_vals, language_short):
all_operands = re.split(r"[\s|&,]+", q_val)
if has_exclusion_word_in_query(all_operands, language_short):
return Response(
f"Search query {q_vals} would return too many results",
f"Search query {q_val} would return too many results",
status=status.HTTP_400_BAD_REQUEST,
)
search_query_str = build_search_query(q_val)

search_fn = "to_tsquery"
if use_websearch:
exclusions = self.get_search_exclusions(q)
# NOTE: The check is done on the last operand only on purpose.
# This is a fix for some old edge case.
exclusions = self.get_search_exclusions(all_operands[-1])
if exclusions:
search_fn = "websearch_to_tsquery"
search_query_str += f" {exclusions}"
Expand All @@ -531,7 +546,7 @@ def get(self, request):
try:
cursor.execute(sql, [search_query_str])
except Exception as e:
logger.error(f"Error in search query: {e}")
logger.error(f"Error in search query: {e}", exc_info=e)
raise ParseError("Search query failed.")
# Note, fetchall() consumes the results and once called returns None.
all_results = cursor.fetchall()
Expand Down
183 changes: 115 additions & 68 deletions services/search/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
from rest_framework.reverse import reverse

from services.search.api import build_search_query


@pytest.mark.django_db
def test_search(
Expand Down Expand Up @@ -179,92 +181,137 @@ def test_search_service_order(api_client, units, services):
assert results[2]["unit_count"]["total"] == 0


@pytest.mark.parametrize(
"query",
[
"halli|museo", # Test |
"halli&museo", # Test &
"linja-auto", # Test -
"Keskustakirjasto Oodi", # Test whitespace
"Keskustakirjasto+Oodi", # Test +
# Test ääkköset
"lääkäri",
"röntgen",
"åbo",
"123", # Test number
"halli.museo", # Test .
"halli's", # Test '
"halli,museo", # Test ,
],
)
@pytest.mark.django_db
def test_search_input_query_validation(api_client):
# Test that | is allowed in query
url = reverse("search") + "?q=halli|museo"
response = api_client.get(url)
assert response.status_code == 200

# Test that & is allowed in query
url = reverse("search") + "?q=halli&museo"
response = api_client.get(url)
assert response.status_code == 200

# Test that - is allowed in query
url = reverse("search") + "?q=linja-auto"
response = api_client.get(url)
assert response.status_code == 200

# Test that " " is allowed in query
url = reverse("search") + "?q=Keskustakirjasto Oodi"
response = api_client.get(url)
assert response.status_code == 200

# Test that + is allowed in query
url = reverse("search") + "?q=Keskustakirjasto+Oodi"
response = api_client.get(url)
def test_search_input_query_valid_characters(api_client, query):
url = reverse("search")
response = api_client.get(url, {"q": query})
assert response.status_code == 200

# Test that "ääkköset" are allowed in query
url = reverse("search") + "?q=lääkäri"
response = api_client.get(url)
assert response.status_code == 200
url = reverse("search") + "?q=röntgen"
response = api_client.get(url)
assert response.status_code == 200
url = reverse("search") + "?q=åbo"
response = api_client.get(url)
assert response.status_code == 200

# Test that numbers are allowed in query
url = reverse("search") + "?q=123"
response = api_client.get(url)
assert response.status_code == 200

# Test that . is allowed in query
url = reverse("search") + "?q=halli.museo"
response = api_client.get(url)
assert response.status_code == 200

# Test that ' is allowed in query
url = reverse("search") + "?q=halli's"
response = api_client.get(url)
assert response.status_code == 200

# Test that special characters are not allowed in query
url = reverse("search") + "?q=halli("
response = api_client.get(url)
@pytest.mark.django_db
def test_search_input_query_disallowed_characters(api_client):
url = reverse("search")
response = api_client.get(url, {"q": "halli("})
assert response.status_code == 400
assert (
response.json()["detail"]
== "Invalid search terms, only letters, numbers, spaces and .'+-&| allowed."
== "Invalid search terms, only letters, numbers, spaces and .,'+-&| allowed."
)


@pytest.mark.django_db
def test_search_with_vertical_bar_in_query(api_client, units):
# Test that a single vertical bar in query is treated as OR operator
url = reverse("search") + "?q=terveysasema|museo&type=unit"
# Test that a vertical bars that are not between search terms do not cause an error
url = reverse("search") + "?q=|terveysasema||''||'"
response = api_client.get(url)
assert response.status_code == 200
assert len(response.json()["results"]) == 2
assert response.json()["results"][0]["name"]["fi"] == "Terveysasema"
assert response.json()["results"][1]["name"]["fi"] == "Biologinen museo"
assert response.status_code == 200, f"{response} {response.json()}"

# Test that multiple vertical bars in query are treated as OR operators
url = reverse("search") + "?q=terveysasema||museo&type=unit"
response = api_client.get(url)
assert response.status_code == 200

@pytest.mark.parametrize(
"query,expected",
[
# Single-operand expression
("a", "a:*"),
("|a|", "a:*"),
("&a&", "a:*"),
(" a ", "a:*"),
("& |a || &|& &&", "a:*"),
# Two-operand expressions with AND operator
("a b", "a:* & b:*"),
("a,b", "a:* & b:*"),
("a, b", "a:* & b:*"),
("a,,, , ,, , , , , , ,, , b", "a:* & b:*"),
("a & b", "a:* & b:*"),
("a& b", "a:* & b:*"),
("a &b", "a:* & b:*"),
("a&b", "a:* & b:*"),
("a&&&&&&&&&&&&&b", "a:* & b:*"),
("a , &&&&&&&, & ,, & & & & &&&&& , , ,,,, b", "a:* & b:*"),
# Two-operand expressions with OR operator
("a | b", "a:* | b:*"),
("a | b", "a:* | b:*"),
("a| b", "a:* | b:*"),
("a |b", "a:* | b:*"),
("a|b", "a:* | b:*"),
("a|||||||||||||b", "a:* | b:*"),
("a ||| ||| || | ||| b", "a:* | b:*"),
# >=3 operands
("a | b | c", "a:* | b:* | c:*"),
("a, b, c", "a:* & b:* & c:*"),
("a & b c, d", "a:* & b:* & c:* & d:*"),
# Mixed OR and AND operators
("a, b | c, d", "a:* & b:* | c:* & d:*"),
("a, &&& , & b || || |||| |c,,,, d", "a:* & b:* | c:* & d:*"),
# Expression with repeating single-quotes
("','','''',a,b'c,d''e,f'''g,','','''", "a:* & b'c:* & d''e:* & f'''g:*"),
],
)
def test_build_search_query(query, expected):
assert build_search_query(query) == expected


@pytest.mark.parametrize(
"query",
[
"palloilu, halli",
"palloilu,halli",
"palloilu, halli",
"palloilu,,,,,,halli",
"palloilu halli",
"palloilu halli",
",palloilu,halli",
"palloilu,halli,,",
],
)
@pytest.mark.django_db
def test_search_input_and_operator(api_client, units, query):
url = reverse("search")
response = api_client.get(url, {"q": query, "type": "unit"})
assert response.status_code == 200, f"{response} {response.json()}"
assert len(response.json()["results"]) == 1
assert response.json()["results"][0]["name"]["fi"] == "Palloiluhalli"


@pytest.mark.parametrize(
"query",
[
"terveysasema|museo",
"terveysasema | museo",
"terveysasema |museo",
"terveysasema| museo",
" terveysasema | museo ",
"terveysasema||museo",
"terveysasema|||||||||museo",
"|||terveysasema|museo||",
],
)
@pytest.mark.django_db
def test_search_input_or_operator(api_client, units, query):
url = reverse("search")
response = api_client.get(url, {"q": query, "type": "unit"})
assert response.status_code == 200, f"{response} {response.json()}"
assert len(response.json()["results"]) == 2
assert response.json()["results"][0]["name"]["fi"] == "Terveysasema"
assert response.json()["results"][1]["name"]["fi"] == "Biologinen museo"

# Test that a vertical bars that are not between search terms do not cause an error
url = reverse("search") + "?q=|terveysasema||''||'"
response = api_client.get(url)
assert response.status_code == 200


@pytest.mark.django_db
def test_search_with_bbox_parameter(api_client, units):
Expand Down