Skip to content

Commit

Permalink
ENH: Add custom tsquery from websearch function and related tests (#838)
Browse files Browse the repository at this point in the history
* Add custom tsquery from websearch function and related tests

* update tests and logic

* fix style issues

---------

Co-authored-by: James Kent <[email protected]>
  • Loading branch information
adelavega and jdkent authored Nov 11, 2024
1 parent efdf3d9 commit ea5cdc4
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 2 deletions.
10 changes: 8 additions & 2 deletions store/neurostore/resources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from flask import abort, request, current_app # jsonify
from flask.views import MethodView

from psycopg2 import errors

import sqlalchemy as sa
import sqlalchemy.sql.expression as sae
from sqlalchemy.orm import (
Expand All @@ -21,7 +23,7 @@

from ..core import cache
from ..database import db
from .utils import get_current_user
from .utils import get_current_user, validate_search_query, pubmed_to_tsquery
from ..models import (
StudysetStudy,
AnnotationAnalysis,
Expand Down Expand Up @@ -613,7 +615,11 @@ def search(self):
if s is not None and s.isdigit():
q = q.filter_by(pmid=s)
elif s is not None and self._fulltext_fields:
tsquery = sa.func.websearch_to_tsquery("english", s)
try:
validate_search_query(s)
except errors.SyntaxError as e:
abort(400, description=e.args[0])
tsquery = pubmed_to_tsquery(s)
q = q.filter(m._ts_vector.op("@@")(tsquery))

# Alternatively (or in addition), search on individual fields.
Expand Down
181 changes: 181 additions & 0 deletions store/neurostore/resources/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re

from connexion.context import context
from psycopg2 import errors

from .. import models
from .. import schemas
Expand Down Expand Up @@ -44,3 +45,183 @@ class ClassView(cls):
ClassView.__name__ = cls.__name__

return ClassView


def validate_search_query(query: str) -> bool:
    """
    Validate a search query string.

    Args:
        query (str): The query string to validate.

    Returns:
        bool: True if the query is valid.

    Raises:
        psycopg2.errors.SyntaxError: If the query contains unmatched
            parentheses or ends with a boolean operator (AND/OR/NOT).
    """
    # Reject queries with unbalanced parentheses before any parsing
    if not validate_parentheses(query):
        raise errors.SyntaxError("Unmatched parentheses")

    # A trailing operator would leave the right-hand operand missing
    if not validate_query_end(query):
        raise errors.SyntaxError("Query cannot end with an operator")

    return True


def validate_parentheses(query: str) -> bool:
    """
    Check that all parentheses in a query string are balanced.

    Args:
        query (str): The query string to validate.

    Returns:
        bool: True if parentheses are balanced, False otherwise.
    """
    depth = 0
    for ch in query:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                # Closing parenthesis with no matching opener
                return False
    # Leftover depth means an opening parenthesis was never closed
    return depth == 0


def validate_query_end(query: str) -> bool:
    """
    Check that a query does not end with a boolean operator.

    Args:
        query (str): The query string to validate.

    Returns:
        bool: False if the final whitespace-delimited token is AND, OR,
        or NOT; True otherwise (an empty query is considered valid here).
    """
    operators = ("AND", "OR", "NOT")

    # split() with no argument handles runs of spaces and tabs;
    # the previous split(" ") missed operators preceded by a tab.
    tokens = query.split()
    return not tokens or tokens[-1] not in operators


def count_chars(target, query: str) -> int:
    """Count occurrences of *target* in a query string, ignoring any
    characters that appear inside double-quoted phrases."""
    total = 0
    quoted = False
    for ch in query:
        if ch == '"':
            # Toggle before the match test: an opening quote is treated
            # as inside the phrase, a closing quote as outside.
            quoted = not quoted
        if ch == target and not quoted:
            total += 1
    return total


def pubmed_to_tsquery(query: str) -> str:
    """
    Convert a PubMed-like search query to PostgreSQL tsquery format.

    Quoted phrases (single- or double-quoted) are joined with the <->
    (followed-by) operator for proximity search, parenthesized groups are
    converted recursively, AND/OR/NOT become & / | / &!, and an implicit
    & is inserted between adjacent terms not already joined by an operator.

    Args:
        query (str): The search query.

    Returns:
        str: The PostgreSQL tsquery equivalent.
    """

    query = query.upper()  # Normalize case so operator matching is uniform

    # Step 1: Split into tokens, preserving quoted phrases as single tokens
    tokens = re.findall(r'"[^"]*"|\'[^\']*\'|\S+', query)

    # Step 2: Combine tokens in parentheses into single tokens
    def combine_parentheses(tokens: list) -> list:
        """
        Combine tokens within parentheses into a single token.
        Args:
            tokens (list): List of tokens to process.
        Returns:
            list: Processed list with tokens inside parentheses combined.
        """
        combined_tokens = []
        buffer = []
        paren_count = 0
        for token in tokens:
            # A non-empty buffer means we are inside a parenthesized group
            if len(buffer) > 0:
                buffer.append(token)

                # Track nesting depth; parens inside quotes are ignored
                paren_count += count_chars("(", token) - count_chars(")", token)

                if paren_count < 1:
                    # Group closed: emit the whole run as one token
                    combined_tokens.append(" ".join(buffer))
                    buffer = []  # Clear the buffer
                    paren_count = 0

            else:
                n_paren = count_chars("(", token) - count_chars(")", token)
                # Token opens a group that is not closed within itself:
                # start capturing tokens inside parentheses
                if token[0] == "(" and n_paren > 0:
                    paren_count += n_paren
                    buffer.append(token)
                else:
                    combined_tokens.append(token)

        # If the list ends without a closing parenthesis (invalid input)
        # append buffer contents (fallback)
        if buffer:
            combined_tokens.append(" ".join(buffer))

        return combined_tokens

    tokens = combine_parentheses(tokens)
    for i, token in enumerate(tokens):
        if token[0] == "(" and token[-1] == ")":
            # RECURSIVE: Process the contents of the parentheses
            token_res = pubmed_to_tsquery(token[1:-1])
            tokens[i] = "(" + token_res + ")"

        # Step 4: Handle both single-quoted and double-quoted phrases,
        # grouping them with <-> (proximity operator)
        elif token[0] in ('"', "'"):
            # Split quoted text into individual words and join with <-> for
            # proximity search
            words = re.findall(r"\w+", token)
            tokens[i] = "<->".join(words)

        # Step 3: Replace logical operators AND, OR, NOT
        else:
            if token == "AND":
                tokens[i] = "&"
            elif token == "OR":
                tokens[i] = "|"
            elif token == "NOT":
                tokens[i] = "&!"

    processed_tokens = []
    last_token = None
    for token in tokens:
        # Step 5: Add & between consecutive terms that aren't already
        # connected by an operator
        stripped_token = token.strip()
        if stripped_token not in ("&", "|", "!", "&!"):
            # Strip punctuation that tsquery cannot parse
            stripped_token = re.sub(r"[\[\],;:!?@#]", "", stripped_token)
            if stripped_token == "":
                continue  # Ignore empty tokens from splitting

        if last_token and last_token not in ("&", "|", "!", "&!"):
            if stripped_token not in ("&", "|", "!", "&!"):
                # Insert an implicit AND (&) between two non-operator tokens
                processed_tokens.append("&")

        processed_tokens.append(stripped_token)
        last_token = stripped_token

    return " ".join(processed_tokens)
15 changes: 15 additions & 0 deletions store/neurostore/tests/api/test_query_params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from ...models import Study
from ...schemas.data import StudysetSchema, StudySchema, AnalysisSchema, StringOrNested
from ..conftest import valid_queries, invalid_queries


@pytest.mark.parametrize("nested", ["true", "false"])
Expand Down Expand Up @@ -99,3 +100,17 @@ def test_multiword_queries(auth_client, ingest_neurosynth, session):

multi_word_search = auth_client.get(f"/api/studies/?search={multiple_words}")
assert multi_word_search.status_code == 200


@pytest.mark.parametrize("query, expected", valid_queries)
def test_valid_pubmed_queries(query, expected, auth_client, ingest_neurosynth, session):
search = auth_client.get(f"/api/studies/?search={query}")
assert search.status_code == 200


@pytest.mark.parametrize("query, expected", invalid_queries)
def test_invalid_pubmed_queries(
query, expected, auth_client, ingest_neurosynth, session
):
search = auth_client.get(f"/api/studies/?search={query}")
assert search.status_code == 400
65 changes: 65 additions & 0 deletions store/neurostore/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,3 +586,68 @@ def simple_neurosynth_annotation(session, ingest_neurosynth):
session.commit()

return smol_annot


"""
Queries for testing
"""
invalid_queries = [
(
'("autism" OR "ASD" OR "autistic") AND (("decision*" OR "choice*" ',
"Unmatched parentheses",
),
('"autism" OR "ASD" OR "autistic" OR ', "Query cannot end with an operator"),
(
'(("Autism Spectrum Disorder" OR "autism spectrum disorder") OR ("Autism" OR "autism") '
'OR ("ASD")) AND (("decision*" OR "Dec',
"Unmatched parentheses",
),
]

valid_queries = [
(
'"Mild Cognitive Impairment" or "Early Cognitive Decline" or "Pre-Dementia" or '
'"Mild Neurocognitive Disorder"',
"MILD<->COGNITIVE<->IMPAIRMENT | EARLY<->COGNITIVE<->DECLINE | PRE<->DEMENTIA | "
"MILD<->NEUROCOGNITIVE<->DISORDER",
),
(
'("autism" OR "ASD" OR "autistic") AND ("decision" OR "choice")',
"(AUTISM | ASD | AUTISTIC) & (DECISION | CHOICE)",
),
(
"stroop and depression or back and depression or go",
"STROOP & DEPRESSION | BACK & DEPRESSION | GO",
),
(
'("autism" OR "ASD" OR "autistic") AND (("decision" OR "decision-making" OR "choice" OR '
'"selection" OR "option" OR "value") OR ("feedback" OR "feedback-related" OR "reward" OR '
'"error" OR "outcome" OR "punishment" OR "reinforcement"))',
"(AUTISM | ASD | AUTISTIC) & ((DECISION | DECISION<->MAKING | CHOICE | SELECTION | OPTION "
"| VALUE) | (FEEDBACK | FEEDBACK<->RELATED | REWARD | ERROR | OUTCOME | PUNISHMENT | "
"REINFORCEMENT))",
),
(
'"dyslexia" or "Reading Disorder" or "Language-Based Learning Disability" or '
'"Phonological Processing Disorder" or "Word Blindness"',
"DYSLEXIA | READING<->DISORDER | LANGUAGE<->BASED<->LEARNING<->DISABILITY | "
"PHONOLOGICAL<->PROCESSING<->DISORDER | WORD<->BLINDNESS",
),
("emotion and pain -physical -touch", "EMOTION & PAIN & -PHYSICAL & -TOUCH"),
(
'("Schizophrenia"[Mesh] OR schizophrenia )',
"(SCHIZOPHRENIA & MESH | SCHIZOPHRENIA)",
),
("Bipolar Disorder", "BIPOLAR & DISORDER"),
('"quchi" or "LI11"', "QUCHI | LI11"),
('"rubber hand illusion"', "RUBBER<->HAND<->ILLUSION"),
]

weird_queries = [
(
"[Major Depressive Disorder (MDD)] or [Clinical Depression] or [Unipolar Depression]",
"MAJOR & DEPRESSIVE & DISORDER & (MDD) | CLINICAL & DEPRESSION | UNIPOLAR & DEPRESSION",
),
]

validate_queries = invalid_queries + [(q, True) for q, _ in valid_queries]
23 changes: 23 additions & 0 deletions store/neurostore/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

from ..resources.utils import pubmed_to_tsquery, validate_search_query
from .conftest import valid_queries, validate_queries, weird_queries


@pytest.mark.parametrize("query, expected", valid_queries)
def test_pubmed_to_tsquery(query, expected):
assert pubmed_to_tsquery(query) == expected


@pytest.mark.parametrize("query, expected", validate_queries)
def test_validate_search_query(query, expected):
if expected is True:
assert validate_search_query(query) == expected
else:
with pytest.raises(Exception):
validate_search_query(query)


@pytest.mark.parametrize("query, expected", weird_queries)
def test_pubmed_to_tsquery_weird(query, expected):
assert pubmed_to_tsquery(query) == expected

0 comments on commit ea5cdc4

Please sign in to comment.