Skip to content

Commit

Permalink
Add Solr edismax helper
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Dec 26, 2023
1 parent 0ab4b00 commit 43f8216
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 0 deletions.
153 changes: 153 additions & 0 deletions searcharray/solr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""Utility functions for Solr users of searcharray."""
import re
import pandas as pd
import numpy as np
from typing import List, Optional
from searcharray.postings import PostingsArray


def parse_min_should_match(num_clauses: int, spec: str) -> int:
"""Parse Solr's min should match (ie mm) spec.
See this ChatGPT translation of mm code from Solr's Java code for parsing this
https://chat.openai.com/share/76642aec-7e05-420f-a53a-83b8e2eea8fb
Parameters
----------
num_clauses : int
spec : str
Returns
-------
int : the number of clauses that must match
"""
def checked_parse_int(value, error_message):
try:
return int(value)
except ValueError:
raise ValueError(error_message)

result = num_clauses
spec = spec.strip()

if '<' in spec:
# we have conditional spec(s)
space_around_less_than_pattern = re.compile(r'\s*<\s*')
spec = space_around_less_than_pattern.sub('<', spec)
for s in spec.split():
parts = s.split('<', 1)
if len(parts) < 2:
raise ValueError("Invalid 'mm' spec: '" + s + "'. Expecting values before and after '<'")
upper_bound = checked_parse_int(parts[0], "Invalid 'mm' spec. Expecting an integer.")
if num_clauses <= upper_bound:
return result
else:
result = parse_min_should_match(num_clauses, parts[1])
return result

# otherwise, simple expression
if '%' in spec:
# percentage - assume the % was the last char. If not, let int() fail.
spec = spec[:-1]
percent = checked_parse_int(spec, "Invalid 'mm' spec. Expecting an integer.")
calc = (result * percent) * (1 / 100)
result = result + int(calc) if calc < 0 else int(calc)
else:
calc = checked_parse_int(spec, "Invalid 'mm' spec. Expecting an integer.")
result = result + calc if calc < 0 else calc

return min(num_clauses, max(result, 0))


def parse_field_boosts(field_lists: List[str]) -> dict:
"""Parse Solr's qf, pf, pf2, pf3 field boosts."""
if not field_lists:
return {}

out = {}
carat_pattern = re.compile(r'\^')

for field in field_lists:
parts = carat_pattern.split(field)
out[parts[0]] = None if len(parts) == 1 else float(parts[1])

return out


def edismax(frame: pd.DataFrame,
q: str,
qf: List[str],
mm: Optional[str] = None,
pf: Optional[List[str]] = None,
pf2: Optional[List[str]] = None,
pf3: Optional[List[str]] = None,
q_op: str = "OR") -> np.ndarray:
"""Run edismax search over dataframe with searcharray fields.
Parameters
----------
q : str
The query string
mm : str
The minimum should match spec
qf : list
The fields to search
pf : list
The fields to search for phrase matches
pf2 : list
The fields to search for bigram matches
pf3 : list
The fields to search for trigram matches
q_op : str, optional
The default operator, by default "OR"
Returns
-------
np.ndarray
The search results
"""
terms = q.split()
query_fields = parse_field_boosts(qf)
phrase_fields = parse_field_boosts(pf) if pf else {}
# bigram_fields = parse_field_boosts(pf2) if pf2 else {}
# trigram_fields = parse_field_boosts(pf3) if pf3 else {}

def check_field(frame, field):
if field not in frame.columns:
raise ValueError(f"Field {field} not in dataframe")
if not isinstance(frame[field].array, PostingsArray):
raise ValueError(f"Field {field} is not a searcharray field")

term_scores = []
for term in terms:
max_scores = np.zeros(len(frame))
for field, boost in query_fields.items():
check_field(frame, field)
field_term_score = frame[field].array.bm25(term) * (1 if boost is None else boost)
max_scores = np.maximum(max_scores, field_term_score)
term_scores.append(max_scores)

if mm is None:
mm = "1"
if q_op == "AND":
mm = "100%"

min_should_match = parse_min_should_match(len(terms), spec=mm)
qf_scores = np.asarray(term_scores)
matches_gt_mm = np.sum(qf_scores > 0, axis=0) >= min_should_match
qf_scores = np.sum(term_scores, axis=0)
qf_scores[~matches_gt_mm] = 0

phrase_scores = []
for field, boost in phrase_fields.items():
check_field(frame, field)
field_phrase_score = frame[field].array.bm25(terms) * (1 if boost is None else boost)
phrase_scores.append(field_phrase_score)

if len(phrase_scores) > 0:
phrase_scores = np.sum(phrase_scores, axis=0)
# Add where term_scores > 0
term_match_idx = np.where(qf_scores)[0]

qf_scores[term_match_idx] += phrase_scores[term_match_idx]
return qf_scores
111 changes: 111 additions & 0 deletions test/test_solr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Tests for solr dsl helpers."""
import pytest
from test_utils import w_scenarios
import pandas as pd
import numpy as np
import numbers

from searcharray.solr import parse_min_should_match, edismax
from searcharray.postings import PostingsArray


def test_standard_percentage():
assert parse_min_should_match(10, "50%") == 5


def test_over_100_percentage():
assert parse_min_should_match(10, "150%") == 10


def test_negative_percentage():
assert parse_min_should_match(10, "-50%") == 5


def test_standard_integer():
assert parse_min_should_match(10, "3") == 3


def test_negative_integer():
assert parse_min_should_match(10, "-3") == 7


def test_integer_exceeding_clause_count():
assert parse_min_should_match(10, "15") == 10


def test_conditional_spec_less_than_clause_count():
assert parse_min_should_match(10, "5<70%") == 7


def test_conditional_spec_greater_than_clause_count():
assert parse_min_should_match(10, "15<70%") == 10


def test_complex_conditional_spec():
assert parse_min_should_match(10, "3<50% 5<30%") == 3


def test_invalid_spec_percentage():
with pytest.raises(ValueError):
parse_min_should_match(10, "five%")


def test_invalid_spec_integer():
with pytest.raises(ValueError):
parse_min_should_match(10, "five")


def test_invalid_spec_conditional():
with pytest.raises(ValueError):
parse_min_should_match(10, "5<")


def test_empty_spec():
with pytest.raises(ValueError):
parse_min_should_match(10, "")


def test_complex_conditional_spec_with_percentage():
assert parse_min_should_match(10, "2<2 5<3 7<40%") == 4


edismax_scenarios = {
"base": {
"frame": {
'title': lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"]),
'body': lambda: PostingsArray.index(["buzz", "data2", "data3 bar", "bunny funny wunny"])
},
"expected": [lambda frame: sum([frame['title'].array.bm25("foo")[0],
frame['title'].array.bm25("bar")[0]]),
0,
lambda frame: max(frame['title'].array.bm25("bar")[2],
frame['body'].array.bm25("bar")[2]),
0],
"params": {'q': "foo bar", 'qf': ["title", "body"]},
},
}


def build_df(frame):
for k, v in frame.items():
if hasattr(v, '__call__'):
frame[k] = v()
frame = pd.DataFrame(frame)
return frame


def compute_expected(expected, frame):
for idx, exp in enumerate(expected):
if hasattr(exp, '__call__'):
comp_expected = exp(frame)
yield comp_expected
else:
yield exp


@w_scenarios(edismax_scenarios)
def test_edismax(frame, expected, params):
frame = build_df(frame)
expected = list(compute_expected(expected, frame))
scores = edismax(frame, **params)
assert np.allclose(scores, expected)

0 comments on commit 43f8216

Please sign in to comment.