From ea5cdc4c3ba217f39d7cf16c07c6de2d64058275 Mon Sep 17 00:00:00 2001
From: Alejandro de la Vega
Date: Mon, 11 Nov 2024 12:59:33 -0600
Subject: [PATCH] ENH: Add custom tsquery from websearch function and related
 tests (#838)

* Add custom tsquery from websearch function and related tests

* update tests and logic

* fix style issues

---------

Co-authored-by: James Kent
---
 store/neurostore/resources/base.py            |  10 +-
 store/neurostore/resources/utils.py           | 181 ++++++++++++++++++
 .../neurostore/tests/api/test_query_params.py |  15 ++
 store/neurostore/tests/conftest.py            |  65 +++++++
 store/neurostore/tests/test_utils.py          |  23 +++
 5 files changed, 292 insertions(+), 2 deletions(-)
 create mode 100644 store/neurostore/tests/test_utils.py

diff --git a/store/neurostore/resources/base.py b/store/neurostore/resources/base.py
index dc18f441..b5ca5f72 100644
--- a/store/neurostore/resources/base.py
+++ b/store/neurostore/resources/base.py
@@ -8,6 +8,8 @@
 from flask import abort, request, current_app  # jsonify
 from flask.views import MethodView
 
+from psycopg2 import errors
+
 import sqlalchemy as sa
 import sqlalchemy.sql.expression as sae
 from sqlalchemy.orm import (
@@ -21,7 +23,7 @@
 
 from ..core import cache
 from ..database import db
-from .utils import get_current_user
+from .utils import get_current_user, validate_search_query, pubmed_to_tsquery
 from ..models import (
     StudysetStudy,
     AnnotationAnalysis,
@@ -613,7 +615,11 @@ def search(self):
         if s is not None and s.isdigit():
             q = q.filter_by(pmid=s)
         elif s is not None and self._fulltext_fields:
-            tsquery = sa.func.websearch_to_tsquery("english", s)
+            try:
+                validate_search_query(s)
+            except errors.SyntaxError as e:
+                abort(400, description=e.args[0])
+            tsquery = pubmed_to_tsquery(s)
             q = q.filter(m._ts_vector.op("@@")(tsquery))
 
         # Alternatively (or in addition), search on individual fields.
diff --git a/store/neurostore/resources/utils.py b/store/neurostore/resources/utils.py
index 9fa0fc68..08974e35 100644
--- a/store/neurostore/resources/utils.py
+++ b/store/neurostore/resources/utils.py
@@ -5,6 +5,7 @@
 import re
 
 from connexion.context import context
+from psycopg2 import errors
 
 from .. import models
 from .. import schemas
@@ -44,3 +45,183 @@ class ClassView(cls):
     ClassView.__name__ = cls.__name__
 
     return ClassView
+
+
+def validate_search_query(query: str) -> bool:
+    """
+    Validate a search query string.
+
+    Args:
+        query (str): The query string to validate.
+
+    Returns:
+        bool: True if the query is valid, False otherwise.
+    """
+    # Check for valid parentheses
+    if not validate_parentheses(query):
+        raise errors.SyntaxError("Unmatched parentheses")
+
+    # Check for valid query end
+    if not validate_query_end(query):
+        raise errors.SyntaxError("Query cannot end with an operator")
+
+    return True
+
+
+def validate_parentheses(query: str) -> bool:
+    """
+    Validate the parentheses in a query string.
+
+    Args:
+        query (str): The query string to validate.
+
+    Returns:
+        bool: True if parentheses are valid, False otherwise.
+    """
+    stack = []
+    for char in query:
+        if char == "(":
+            stack.append(char)
+        elif char == ")":
+            if not stack:
+                return False  # Unmatched closing parenthesis
+            stack.pop()
+    return not stack  # Ensure all opening parentheses are closed
+
+
+def validate_query_end(query: str) -> bool:
+    """Query should not end with an operator"""
+    operators = ("AND", "OR", "NOT")
+
+    if query.strip().split(" ")[-1] in operators:
+        return False
+    return True
+
+
+def count_chars(target, query: str) -> int:
+    """Count the number of chars in a query string.
+    Excluding those in quoted phrases."""
+    count = 0
+    in_quotes = False
+    for char in query:
+        if char == '"':
+            in_quotes = not in_quotes
+        if char == target and not in_quotes:
+            count += 1
+    return count
+
+
+def pubmed_to_tsquery(query: str) -> str:
+    """
+    Convert a PubMed-like search query to PostgreSQL tsquery format,
+    grouping both single-quoted and double-quoted text with the <-> operator
+    for proximity search.
+
+    Additionally, automatically adds & between non-explicitly connected terms
+    and handles NOT terms.
+
+    Args:
+        query (str): The search query.
+
+    Returns:
+        str: The PostgreSQL tsquery equivalent.
+    """
+
+    query = query.upper()  # Ensure uniformity
+
+    # Step 1: Split into tokens (preserving quoted phrases)
+    # Regex pattern: match quoted phrases or non-space sequences
+    tokens = re.findall(r'"[^"]*"|\'[^\']*\'|\S+', query)
+
+    # Step 2: Combine tokens in parentheses into single tokens
+    def combine_parentheses(tokens: list) -> list:
+        """
+        Combine tokens within parentheses into a single token.
+
+        Args:
+            tokens (list): List of tokens to process.
+
+        Returns:
+            list: Processed list with tokens inside parentheses combined.
+        """
+        combined_tokens = []
+        buffer = []
+        paren_count = 0
+        for token in tokens:
+            # If buffer is not empty, we are inside parentheses
+            if len(buffer) > 0:
+                buffer.append(token)
+
+                # Adjust the count of parentheses
+                paren_count += count_chars("(", token) - count_chars(")", token)
+
+                if paren_count < 1:
+                    # Combine all tokens in parentheses
+                    combined_tokens.append(" ".join(buffer))
+                    buffer = []  # Clear the buffer
+                    paren_count = 0
+
+            else:
+                n_paren = count_chars("(", token) - count_chars(")", token)
+                # If not in parentheses, but token contains opening parentheses
+                # Start capturing tokens inside parentheses
+                if token[0] == "(" and n_paren > 0:
+                    paren_count += n_paren
+                    buffer.append(token)  # Start capturing tokens in parens
+                    print(buffer)
+                else:
+                    combined_tokens.append(token)
+
+        # If the list ends without a closing parenthesis (invalid input)
+        # append buffer contents (fallback)
+        if buffer:
+            combined_tokens.append(" ".join(buffer))
+
+        return combined_tokens
+
+    tokens = combine_parentheses(tokens)
+    print(tokens)
+    for i, token in enumerate(tokens):
+        if token[0] == "(" and token[-1] == ")":
+            # RECURSIVE: Process the contents of the parentheses
+            token_res = pubmed_to_tsquery(token[1:-1])
+            token = "(" + token_res + ")"
+            tokens[i] = token
+
+        # Step 3: Handle both single-quoted and double-quoted phrases,
+        # grouping them with <-> (proximity operator)
+        elif token[0] in ('"', "'"):
+            # Split quoted text into individual words and join with <-> for
+            # proximity search
+            words = re.findall(r"\w+", token)
+            tokens[i] = "<->".join(words)
+
+        # Step 4: Replace logical operators AND, OR, NOT
+        else:
+            if token == "AND":
+                tokens[i] = "&"
+            elif token == "OR":
+                tokens[i] = "|"
+            elif token == "NOT":
+                tokens[i] = "&!"
+
+    processed_tokens = []
+    last_token = None
+    for token in tokens:
+        # Step 5: Add & between consecutive terms that aren't already
+        # connected by an operator
+        stripped_token = token.strip()
+        if stripped_token not in ("&", "|", "!", "&!"):
+            stripped_token = re.sub(r"[\[\],;:!?@#]", "", stripped_token)
+            if stripped_token == "":
+                continue  # Ignore empty tokens from splitting
+
+        if last_token and last_token not in ("&", "|", "!", "&!"):
+            if stripped_token not in ("&", "|", "!", "&!"):
+                # Insert an implicit AND (&) between two non-operator tokens
+                processed_tokens.append("&")
+
+        processed_tokens.append(stripped_token)
+        last_token = stripped_token
+
+    return " ".join(processed_tokens)
diff --git a/store/neurostore/tests/api/test_query_params.py b/store/neurostore/tests/api/test_query_params.py
index 5d8c0500..ad8f2bd5 100644
--- a/store/neurostore/tests/api/test_query_params.py
+++ b/store/neurostore/tests/api/test_query_params.py
@@ -1,6 +1,7 @@
 import pytest
 from ...models import Study
 from ...schemas.data import StudysetSchema, StudySchema, AnalysisSchema, StringOrNested
+from ..conftest import valid_queries, invalid_queries
 
 
 @pytest.mark.parametrize("nested", ["true", "false"])
@@ -99,3 +100,17 @@ def test_multiword_queries(auth_client, ingest_neurosynth, session):
     multi_word_search = auth_client.get(f"/api/studies/?search={multiple_words}")
 
     assert multi_word_search.status_code == 200
+
+
+@pytest.mark.parametrize("query, expected", valid_queries)
+def test_valid_pubmed_queries(query, expected, auth_client, ingest_neurosynth, session):
+    search = auth_client.get(f"/api/studies/?search={query}")
+    assert search.status_code == 200
+
+
+@pytest.mark.parametrize("query, expected", invalid_queries)
+def test_invalid_pubmed_queries(
+    query, expected, auth_client, ingest_neurosynth, session
+):
+    search = auth_client.get(f"/api/studies/?search={query}")
+    assert search.status_code == 400
diff --git a/store/neurostore/tests/conftest.py b/store/neurostore/tests/conftest.py
index 099ff129..1189e214 100644
--- a/store/neurostore/tests/conftest.py
+++ b/store/neurostore/tests/conftest.py
@@ -586,3 +586,68 @@ def simple_neurosynth_annotation(session, ingest_neurosynth):
     session.commit()
 
     return smol_annot
+
+
+"""
+Queries for testing
+"""
+invalid_queries = [
+    (
+        '("autism" OR "ASD" OR "autistic") AND (("decision*" OR "choice*" ',
+        "Unmatched parentheses",
+    ),
+    ('"autism" OR "ASD" OR "autistic" OR ', "Query cannot end with an operator"),
+    (
+        '(("Autism Spectrum Disorder" OR "autism spectrum disorder") OR ("Autism" OR "autism") '
+        'OR ("ASD")) AND (("decision*" OR "Dec',
+        "Unmatched parentheses",
+    ),
+]
+
+valid_queries = [
+    (
+        '"Mild Cognitive Impairment" or "Early Cognitive Decline" or "Pre-Dementia" or '
+        '"Mild Neurocognitive Disorder"',
+        "MILD<->COGNITIVE<->IMPAIRMENT | EARLY<->COGNITIVE<->DECLINE | PRE<->DEMENTIA | "
+        "MILD<->NEUROCOGNITIVE<->DISORDER",
+    ),
+    (
+        '("autism" OR "ASD" OR "autistic") AND ("decision" OR "choice")',
+        "(AUTISM | ASD | AUTISTIC) & (DECISION | CHOICE)",
+    ),
+    (
+        "stroop and depression or back and depression or go",
+        "STROOP & DEPRESSION | BACK & DEPRESSION | GO",
+    ),
+    (
+        '("autism" OR "ASD" OR "autistic") AND (("decision" OR "decision-making" OR "choice" OR '
+        '"selection" OR "option" OR "value") OR ("feedback" OR "feedback-related" OR "reward" OR '
+        '"error" OR "outcome" OR "punishment" OR "reinforcement"))',
+        "(AUTISM | ASD | AUTISTIC) & ((DECISION | DECISION<->MAKING | CHOICE | SELECTION | OPTION "
+        "| VALUE) | (FEEDBACK | FEEDBACK<->RELATED | REWARD | ERROR | OUTCOME | PUNISHMENT | "
+        "REINFORCEMENT))",
+    ),
+    (
+        '"dyslexia" or "Reading Disorder" or "Language-Based Learning Disability" or '
+        '"Phonological Processing Disorder" or "Word Blindness"',
+        "DYSLEXIA | READING<->DISORDER | LANGUAGE<->BASED<->LEARNING<->DISABILITY | "
+        "PHONOLOGICAL<->PROCESSING<->DISORDER | WORD<->BLINDNESS",
+    ),
+    ("emotion and pain -physical -touch", "EMOTION & PAIN & -PHYSICAL & -TOUCH"),
+    (
+        '("Schizophrenia"[Mesh] OR schizophrenia )',
+        "(SCHIZOPHRENIA & MESH | SCHIZOPHRENIA)",
+    ),
+    ("Bipolar Disorder", "BIPOLAR & DISORDER"),
+    ('"quchi" or "LI11"', "QUCHI | LI11"),
+    ('"rubber hand illusion"', "RUBBER<->HAND<->ILLUSION"),
+]
+
+weird_queries = [
+    (
+        "[Major Depressive Disorder (MDD)] or [Clinical Depression] or [Unipolar Depression]",
+        "MAJOR & DEPRESSIVE & DISORDER & (MDD) | CLINICAL & DEPRESSION | UNIPOLAR & DEPRESSION",
+    ),
+]
+
+validate_queries = invalid_queries + [(q, True) for q, _ in valid_queries]
diff --git a/store/neurostore/tests/test_utils.py b/store/neurostore/tests/test_utils.py
new file mode 100644
index 00000000..1ac8359c
--- /dev/null
+++ b/store/neurostore/tests/test_utils.py
@@ -0,0 +1,23 @@
+import pytest
+
+from ..resources.utils import pubmed_to_tsquery, validate_search_query
+from .conftest import valid_queries, validate_queries, weird_queries
+
+
+@pytest.mark.parametrize("query, expected", valid_queries)
+def test_pubmed_to_tsquery(query, expected):
+    assert pubmed_to_tsquery(query) == expected
+
+
+@pytest.mark.parametrize("query, expected", validate_queries)
+def test_validate_search_query(query, expected):
+    if expected is True:
+        assert validate_search_query(query) == expected
+    else:
+        with pytest.raises(Exception):
+            validate_search_query(query)
+
+
+@pytest.mark.parametrize("query, expected", weird_queries)
+def test_pubmed_to_tsquery_weird(query, expected):
+    assert pubmed_to_tsquery(query) == expected
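A minimal usage sketch (not part of the patch), assuming the neurostore package from this diff is importable (e.g. with store/ on PYTHONPATH) and psycopg2 is installed; the example queries and expected strings are taken from the conftest.py fixtures above:

    from neurostore.resources.utils import pubmed_to_tsquery, validate_search_query

    # Quoted phrases become <-> proximity groups; OR maps to |, AND to &.
    print(pubmed_to_tsquery('"rubber hand illusion"'))
    # RUBBER<->HAND<->ILLUSION
    print(pubmed_to_tsquery('("autism" OR "ASD" OR "autistic") AND ("decision" OR "choice")'))
    # (AUTISM | ASD | AUTISTIC) & (DECISION | CHOICE)

    # Malformed queries raise psycopg2.errors.SyntaxError, which the
    # /api/studies/?search= endpoint turns into a 400 response.
    try:
        validate_search_query('"autism" OR "ASD" OR "autistic" OR ')
    except Exception as exc:
        print(exc)  # Query cannot end with an operator

Validation runs before the query ever reaches PostgreSQL, so syntax problems surface as a clear 400 error rather than a database-level failure.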