Skip to content

Commit

Permalink
Add tie param
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Apr 19, 2024
1 parent a3b10d2 commit 9bf53d9
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
17 changes: 15 additions & 2 deletions searcharray/solr.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,25 @@ def _edismax_term_centric(frame: pd.DataFrame,
num_search_terms: int,
search_terms: Dict[str, List[str]],
mm: str,
tie: float,
similarity: Dict[str, Similarity]) -> Tuple[np.ndarray, str]:
explain = []
term_scores = []
for term_posn in range(num_search_terms):
max_scores = np.zeros(len(frame))
sum_scores = np.zeros(len(frame))
term_explain = []
for field, boost in query_fields.items():
term = search_terms[field][term_posn]
post_arr = get_field(frame, field)
field_term_score = post_arr.score(term, similarity=similarity[field]) * (1 if boost is None else boost)
boost_exp = f"{boost}" if boost is not None else "1"
term_explain.append(f"{field}:{term}^{boost_exp}")
sum_scores += field_term_score
max_scores = np.maximum(max_scores, field_term_score)
term_scores.append(max_scores)

remainder_scores = sum_scores - max_scores
term_scores.append(max_scores + remainder_scores * tie)
explain.append("(" + " | ".join(term_explain) + ")")

min_should_match = parse_min_should_match(num_search_terms, spec=mm)
Expand All @@ -142,6 +147,7 @@ def _edismax_field_centric(frame: pd.DataFrame,
num_search_terms: int,
search_terms: Dict[str, List[str]],
mm: str,
tie: float,
similarity: Dict[str, Similarity]) -> Tuple[np.ndarray, str]:
field_scores = []
explain = []
Expand All @@ -162,8 +168,10 @@ def _edismax_field_centric(frame: pd.DataFrame,
explain.append(exp)
# Combine field scores dismax-style: maximum field score plus tie * (sum of the other field scores)
qf_scores = np.asarray(field_scores)
summed_scores = np.sum(qf_scores, axis=0)
qf_scores = np.max(qf_scores, axis=0)
return qf_scores, " | ".join(explain)
qf_with_tie_scores = qf_scores + (summed_scores - qf_scores) * tie
return qf_with_tie_scores, " | ".join(explain)


def edismax(frame: pd.DataFrame,
Expand All @@ -173,6 +181,7 @@ def edismax(frame: pd.DataFrame,
pf: Optional[List[str]] = None,
pf2: Optional[List[str]] = None,
pf3: Optional[List[str]] = None,
tie: float = 0.0,
q_op: str = "OR",
similarity: Union[Similarity, Dict[str, Similarity]] = default_bm25) -> Tuple[np.ndarray, str]:
"""Run edismax search over dataframe with searcharray fields.
Expand All @@ -193,6 +202,8 @@ def edismax(frame: pd.DataFrame,
The fields to search for trigram matches
q_op : str, optional
The default operator, by default "OR"
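tie : float, optional
The tie breaker: the best-scoring field counts in full and the scores of the other matching fields are added multiplied by this factor (0.0 uses only the maximum field score, 1.0 sums all field scores), by default 0.0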
similarity : Union[Similarity, Dict[str, Similarity]], optional
The similarity to use per field, by default default_bm25
Expand Down Expand Up @@ -227,10 +238,12 @@ def listify(x):
if term_centric:
qf_scores, explain = _edismax_term_centric(frame, query_fields,
num_search_terms, search_terms, mm,
tie=tie,
similarity=similarity)
else:
qf_scores, explain = _edismax_field_centric(frame, query_fields,
num_search_terms, search_terms, mm,
tie=tie,
similarity=similarity)

phrase_scores = []
Expand Down
31 changes: 31 additions & 0 deletions test/test_solr.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ def just_lowercasing_tokenizer(text: str) -> List[str]:
0],
"params": {'q': "foo bar", 'qf': ["title", "body"]},
},
"field_centric_tie": {
"frame": {
'title': lambda: SearchArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"]),
'body': lambda: SearchArray.index(["foo bar", "data2", "data3 bar", "bunny funny wunny"],
tokenizer=just_lowercasing_tokenizer)
},
"expected": [lambda frame: sum([sum([frame['title'].array.score("foo")[0],
frame['title'].array.score("bar")[0]]),
0.1 * frame['body'].array.score("foo bar")[0]]),
0,
lambda frame: frame['title'].array.score("bar")[2],
0],
"params": {'q': "foo bar", 'qf': ["title", "body"], 'tie': 0.1},
},
"field_centric_mm": {
"frame": {
'title': lambda: SearchArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"]),
Expand Down Expand Up @@ -163,6 +177,17 @@ def just_lowercasing_tokenizer(text: str) -> List[str]:
"params": {'q': "foo bar", 'qf': ["title", "body"],
'pf': ["title"]}
},
"with_tie": {
"frame": {
'title': lambda: SearchArray.index(["foo bar bar baz"]),
'body': lambda: SearchArray.index(["foo"])
},
"expected": [lambda frame: sum([0.1 * frame['title'].array.score("foo")[0],
frame['body'].array.score("foo")[0]])],
"params": {'q': "foo",
'qf': ["title", "body"],
'tie': 0.1}
},
"different_analyzers": {
"frame": {
'title': lambda: SearchArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"]),
Expand Down Expand Up @@ -225,6 +250,9 @@ def always_tiny_similarity(term_freqs: np.ndarray, doc_freqs: np.ndarray,

@w_scenarios(edismax_scenarios)
def test_edismax_custom_similarity(frame, expected, params):
if 'tie' in params:
return

frame = build_df(frame)
expected = list(compute_expected(expected, frame))
params['similarity'] = always_one_similarity
Expand All @@ -234,6 +262,9 @@ def test_edismax_custom_similarity(frame, expected, params):

@w_scenarios(edismax_scenarios)
def test_edismax_custom_similarity_per_field(frame, expected, params):
if 'tie' in params:
return

frame = build_df(frame)
expected = list(compute_expected(expected, frame))
params['similarity'] = {"title": always_one_similarity,
Expand Down

0 comments on commit 9bf53d9

Please sign in to comment.