Skip to content

Commit

Permalink
Add field centric fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Dec 28, 2023
1 parent c9af718 commit 0c353cb
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 21 deletions.
73 changes: 52 additions & 21 deletions searcharray/solr.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def parse_query_terms(frame: pd.DataFrame,

search_terms: Dict[str, List[str]] = {}
num_search_terms = 0
term_centric = True

for field in query_fields:
arr = get_field(frame, field)
Expand All @@ -101,9 +102,50 @@ def parse_query_terms(frame: pd.DataFrame,
if num_search_terms == 0:
num_search_terms = field_num_search_terms
elif field_num_search_terms != num_search_terms:
raise ValueError("All qf field tokenizers must emit the same number of terms")
term_centric = False

return num_search_terms, search_terms
return num_search_terms, search_terms, term_centric


def _edismax_term_centric(frame: pd.DataFrame,
                          query_fields: Dict[str, float],
                          num_search_terms: int,
                          search_terms: Dict[str, List[str]],
                          min_should_match: int):
    """Score docs term-centrically: per term position, take the best
    (boost-weighted) BM25 score across all query fields, then sum those
    per-term maxima per document.

    Documents matching fewer than ``min_should_match`` term positions
    (a positive score at that position counts as a match) get a score of 0.
    """
    per_term_best = []
    for posn in range(num_search_terms):
        best_so_far = np.zeros(len(frame))
        for field, boost in query_fields.items():
            weight = 1 if boost is None else boost
            token = search_terms[field][posn]
            post_arr = get_field(frame, field)
            best_so_far = np.maximum(best_so_far, post_arr.bm25(token) * weight)
        per_term_best.append(best_so_far)

    score_matrix = np.asarray(per_term_best)
    # Count, per document, how many term positions scored > 0.
    meets_mm = (score_matrix > 0).sum(axis=0) >= min_should_match
    qf_scores = np.sum(per_term_best, axis=0)
    qf_scores[~meets_mm] = 0
    return qf_scores


def _edismax_field_centric(frame: pd.DataFrame,
                           query_fields: Dict[str, float],
                           num_search_terms: int,
                           search_terms: Dict[str, List[str]],
                           min_should_match: int) -> np.ndarray:
    """Score docs field-centrically: within each field, sum the BM25 scores
    of that field's search terms (boost-weighted), then take the maximum
    field score per document as the qf score.

    Documents matching fewer than ``min_should_match`` of a field's terms
    contribute 0 for that field. ``num_search_terms`` is unused here since
    each field may tokenize the query into a different number of terms.
    """
    field_scores = []
    for field, boost in query_fields.items():
        post_arr = get_field(frame, field)
        # Shape: (num_terms_in_field, num_docs) — one BM25 row per term.
        term_scores = np.array([post_arr.bm25(term) for term in search_terms[field]])
        # min_should_match counts matching *terms per document*, so reduce
        # over the term axis (axis=0). The previous axis=1 counted matching
        # documents per term and then zeroed whole term rows, which is wrong.
        matches_gt_mm = np.sum(term_scores > 0, axis=0) >= min_should_match
        sum_terms_bm25 = np.sum(term_scores, axis=0)
        sum_terms_bm25[~matches_gt_mm] = 0
        field_scores.append(sum_terms_bm25 * (1 if boost is None else boost))
    # Take maximum field scores as qf
    qf_scores = np.asarray(field_scores)
    qf_scores = np.max(qf_scores, axis=0)
    return qf_scores


def edismax(frame: pd.DataFrame,
Expand Down Expand Up @@ -143,31 +185,20 @@ def listify(x):

query_fields = parse_field_boosts(listify(qf))
phrase_fields = parse_field_boosts(listify(pf)) if pf else {}

# bigram_fields = parse_field_boosts(pf2) if pf2 else {}
# trigram_fields = parse_field_boosts(pf3) if pf3 else {}

num_search_terms, search_terms = parse_query_terms(frame, q, list(query_fields.keys()))

term_scores = []
for term_posn in range(num_search_terms):
max_scores = np.zeros(len(frame))
for field, boost in query_fields.items():
term = search_terms[field][term_posn]
field_term_score = frame[field].array.bm25(term) * (1 if boost is None else boost)
max_scores = np.maximum(max_scores, field_term_score)
term_scores.append(max_scores)

if mm is None:
mm = "1"
if q_op == "AND":
mm = "100%"

# bigram_fields = parse_field_boosts(pf2) if pf2 else {}
# trigram_fields = parse_field_boosts(pf3) if pf3 else {}

num_search_terms, search_terms, term_centric = parse_query_terms(frame, q, list(query_fields.keys()))
min_should_match = parse_min_should_match(num_search_terms, spec=mm)
qf_scores = np.asarray(term_scores)
matches_gt_mm = np.sum(qf_scores > 0, axis=0) >= min_should_match
qf_scores = np.sum(term_scores, axis=0)
qf_scores[~matches_gt_mm] = 0
if term_centric:
qf_scores = _edismax_term_centric(frame, query_fields, num_search_terms, search_terms, min_should_match)
else:
qf_scores = _edismax_field_centric(frame, query_fields, num_search_terms, search_terms, min_should_match)

phrase_scores = []
for field, boost in phrase_fields.items():
Expand Down
19 changes: 19 additions & 0 deletions test/test_solr.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def everythings_a_b_tokenizer(text: str) -> List[str]:
return ["b"] * len(text.split())


def just_lowercasing_tokenizer(text: str) -> List[str]:
"""Lowercase and return a list of tokens."""
return [text.lower()]


edismax_scenarios = {
"base": {
"frame": {
Expand All @@ -88,6 +93,20 @@ def everythings_a_b_tokenizer(text: str) -> List[str]:
0],
"params": {'q': "foo bar", 'qf': ["title", "body"]},
},
"field_centric": {
"frame": {
'title': lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"]),
'body': lambda: PostingsArray.index(["foo bar", "data2", "data3 bar", "bunny funny wunny"],
tokenizer=just_lowercasing_tokenizer)
},
"expected": [lambda frame: max(sum([frame['title'].array.bm25("foo")[0],
frame['title'].array.bm25("bar")[0]]),
frame['body'].array.bm25("foo bar")[0]),
0,
lambda frame: frame['title'].array.bm25("bar")[2],
0],
"params": {'q': "foo bar", 'qf': ["title", "body"]},
},
"boost_title": {
"frame": {
'title': lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"]),
Expand Down

0 comments on commit 0c353cb

Please sign in to comment.