v0.0.21
softwaredoug committed Dec 27, 2023
1 parent 89bf685 commit 12908c2
Showing 3 changed files with 66 additions and 14 deletions.
55 changes: 42 additions & 13 deletions searcharray/solr.py
@@ -2,7 +2,7 @@
import re
import pandas as pd
import numpy as np
from typing import List, Optional
from typing import List, Optional, Dict
from searcharray.postings import PostingsArray


@@ -74,6 +74,38 @@ def parse_field_boosts(field_lists: List[str]) -> dict:
return out


def get_field(frame, field) -> PostingsArray:
if field not in frame.columns:
raise ValueError(f"Field {field} not in dataframe")
if not isinstance(frame[field].array, PostingsArray):
raise ValueError(f"Field {field} is not a searcharray field")
return frame[field].array


def parse_query_terms(frame: pd.DataFrame,
query: str,
query_fields: List[str]):

search_terms: Dict[str, List[str]] = {}
num_search_terms = 0

for field in query_fields:
arr = get_field(frame, field)

tokenizer = arr.tokenizer
search_terms[field] = []
field_num_search_terms = 0
for posn, term in enumerate(tokenizer(query)):
search_terms[field].append(term)
field_num_search_terms += 1
if num_search_terms == 0:
num_search_terms = field_num_search_terms
elif field_num_search_terms != num_search_terms:
raise ValueError("All qf field tokenizers must emit the same number of terms")

return num_search_terms, search_terms
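
The two helpers above are the heart of this change: each qf field tokenizes the raw query with its own analyzer, so fields indexed with different tokenizers can still be scored term by term. A minimal sketch of the intended behavior (not part of the diff; the DataFrame, field names, and the assumption that the default tokenizer splits on whitespace are illustrative only):

import pandas as pd
from searcharray.postings import PostingsArray
from searcharray.solr import parse_query_terms

# Hypothetical two-field index; both columns use the default tokenizer here.
df = pd.DataFrame({
    "title": PostingsArray.index(["red shoes", "blue socks"]),
    "body": PostingsArray.index(["cheap red shoes", "warm blue socks"]),
})

num_terms, terms = parse_query_terms(df, "red shoes", ["title", "body"])
# num_terms == 2
# terms == {"title": ["red", "shoes"], "body": ["red", "shoes"]}
# get_field raises ValueError for a missing or non-searcharray column, and
# parse_query_terms raises if the qf tokenizers emit different term counts.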


def edismax(frame: pd.DataFrame,
q: str,
qf: List[str],
@@ -106,26 +138,22 @@ def edismax(frame: pd.DataFrame,
np.ndarray
The search results
"""
terms = q.split()

def listify(x):
return x if isinstance(x, list) else [x]

query_fields = parse_field_boosts(listify(qf))
phrase_fields = parse_field_boosts(listify(pf)) if pf else {}

# bigram_fields = parse_field_boosts(pf2) if pf2 else {}
# trigram_fields = parse_field_boosts(pf3) if pf3 else {}

def check_field(frame, field):
if field not in frame.columns:
raise ValueError(f"Field {field} not in dataframe")
if not isinstance(frame[field].array, PostingsArray):
raise ValueError(f"Field {field} is not a searcharray field")
num_search_terms, search_terms = parse_query_terms(frame, q, qf)

term_scores = []
for term in terms:
for term_posn in range(num_search_terms):
max_scores = np.zeros(len(frame))
for field, boost in query_fields.items():
check_field(frame, field)
term = search_terms[field][term_posn]
field_term_score = frame[field].array.bm25(term) * (1 if boost is None else boost)
max_scores = np.maximum(max_scores, field_term_score)
term_scores.append(max_scores)
@@ -135,16 +163,17 @@ def check_field(frame, field):
if q_op == "AND":
mm = "100%"

min_should_match = parse_min_should_match(len(terms), spec=mm)
min_should_match = parse_min_should_match(num_search_terms, spec=mm)
qf_scores = np.asarray(term_scores)
matches_gt_mm = np.sum(qf_scores > 0, axis=0) >= min_should_match
qf_scores = np.sum(term_scores, axis=0)
qf_scores[~matches_gt_mm] = 0

phrase_scores = []
for field, boost in phrase_fields.items():
check_field(frame, field)
field_phrase_score = frame[field].array.bm25(terms) * (1 if boost is None else boost)
arr = get_field(frame, field)
terms = search_terms[field]
field_phrase_score = arr.bm25(terms) * (1 if boost is None else boost)
phrase_scores.append(field_phrase_score)

if len(phrase_scores) > 0:
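Putting the pieces together, a hedged usage sketch of the updated edismax, reusing the df from the sketch above (the full signature is truncated in this diff, so treat the call below as an assumption based on the visible parameters q, qf, and pf and on the docstring's promise of an np.ndarray of scores):

from searcharray.solr import edismax

# Per term position, take the best BM25 score across the qf fields (dismax),
# sum over term positions, zero out rows that fail the mm spec, then add any
# pf phrase scores on top.
scores = edismax(df, q="red shoes", qf=["title", "body"], pf=["title"])
ranked = df.assign(score=scores).sort_values("score", ascending=False)
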
2 changes: 1 addition & 1 deletion setup.py
@@ -36,7 +36,7 @@
# For a discussion on single-sourcing the version across setup.py and the
# project code, see
# https://packaging.python.org/guides/single-sourcing-package-version/
version="0.0.20", # Required
version="0.0.21", # Required
# This is a one-line description or tagline of what your project does. This
# corresponds to the "Summary" metadata field:
# https://packaging.python.org/specifications/core-metadata/#summary
23 changes: 23 additions & 0 deletions test/test_solr.py
@@ -1,5 +1,6 @@
"""Tests for solr dsl helpers."""
import pytest
from typing import List
from test_utils import w_scenarios
import pandas as pd
import numpy as np
@@ -68,6 +69,11 @@ def test_complex_conditional_spec_with_percentage():
assert parse_min_should_match(10, "2<2 5<3 7<40%") == 4


def everythings_a_b_tokenizer(text: str) -> List[str]:
"""Split on whitespace and return a list of tokens."""
return ["b"] * len(text.split())


edismax_scenarios = {
"base": {
"frame": {
Expand Down Expand Up @@ -97,6 +103,23 @@ def test_complex_conditional_spec_with_percentage():
"params": {'q': "foo bar", 'qf': ["title", "body"],
'pf': ["title"]}
},
"different_analyzers": {
"frame": {
'title': lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"]),
'body': lambda: PostingsArray.index(["buzz", "data2", "data3 bar", "bunny funny wunny"],
tokenizer=everythings_a_b_tokenizer)
},
"expected": [lambda frame: max(frame['title'].array.bm25("bar")[0],
frame['body'].array.bm25("b")[0]),

lambda frame: frame['body'].array.bm25("b")[1],

lambda frame: max(frame['title'].array.bm25("bar")[2],
frame['body'].array.bm25("b")[2]),

lambda frame: frame['body'].array.bm25("b")[3]],
"params": {'q': "bar", 'qf': ["title", "body"]},
},
}
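
The different_analyzers scenario ties it together: the query "bar" tokenizes to "bar" against title but to "b" against body, so each document's expected score is the per-field maximum of the two BM25 scores (documents 1 and 3 match only via body). A rough manual equivalent of what the scenario checks, assuming the imports already present in this test module and that edismax accepts these keyword arguments and returns the score array its docstring describes:

df = pd.DataFrame({
    "title": PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"]),
    "body": PostingsArray.index(["buzz", "data2", "data3 bar", "bunny funny wunny"],
                                tokenizer=everythings_a_b_tokenizer),
})
scores = edismax(df, q="bar", qf=["title", "body"])
# Document 0 matches "bar" in title and "b" in body; dismax keeps the larger score.
assert scores[0] == pytest.approx(max(df["title"].array.bm25("bar")[0],
                                      df["body"].array.bm25("b")[0]))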


