diff --git a/searcharray/solr.py b/searcharray/solr.py index 6af7a0f..cbc06bc 100644 --- a/searcharray/solr.py +++ b/searcharray/solr.py @@ -107,8 +107,11 @@ def edismax(frame: pd.DataFrame, The search results """ terms = q.split() - query_fields = parse_field_boosts(qf) - phrase_fields = parse_field_boosts(pf) if pf else {} + + def listify(x): + return x if isinstance(x, list) else [x] + query_fields = parse_field_boosts(listify(qf)) + phrase_fields = parse_field_boosts(listify(pf)) if pf else {} # bigram_fields = parse_field_boosts(pf2) if pf2 else {} # trigram_fields = parse_field_boosts(pf3) if pf3 else {} diff --git a/test/test_solr.py b/test/test_solr.py index 9792d29..908b63b 100644 --- a/test/test_solr.py +++ b/test/test_solr.py @@ -83,6 +83,21 @@ def test_complex_conditional_spec_with_percentage(): 0], "params": {'q': "foo bar", 'qf': ["title", "body"]}, }, + "pf_title": { + "frame": { + 'title': lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"]), + 'body': lambda: PostingsArray.index(["buzz", "data2", "data3 bar", "bunny funny wunny"]) + }, + "expected": [lambda frame: sum([frame['title'].array.bm25(["foo", "bar"])[0], + frame['title'].array.bm25("foo")[0], + frame['title'].array.bm25("bar")[0]]), + 0, + lambda frame: max(frame['title'].array.bm25("bar")[2], + frame['body'].array.bm25("bar")[2]), + 0], + "params": {'q': "foo bar", 'qf': ["title", "body"], + 'pf': ["title"]} + }, }