diff --git a/test/test_tmdb.py b/test/test_tmdb.py
index 63ddf98..c52acba 100644
--- a/test/test_tmdb.py
+++ b/test/test_tmdb.py
@@ -6,7 +6,8 @@ import numpy as np
 import sys
 
 from searcharray.postings import SearchArray
-from test_utils import Profiler, profile_enabled
+from searcharray.solr import edismax
+from test_utils import Profiler, profile_enabled, naive_find_term
 
 
 should_profile = '--benchmark-disable' in sys.argv
@@ -123,6 +124,93 @@ def test_term_freqs(tmdb_data, term, expected_matches):
     assert np.all(term_freqs == 1)
 
 
+# Free-text queries run through edismax and cross-checked against a naive scan.
+queries = [
+    "Star Wars",
+    "the next generation",
+    "bartender fights a cow and",
+    "to be or not to be",
+    "the quick brown fox jumps over the lazy dog",
+    "bill and ted's excellent adventure",
+    "thirty years after defeating the galactic empire",
+    "a film about a daughter of a refugee family",
+    "have one thing in mind: to find a way to kill each other without risk. After listening to a radio show, Paul decided",
+    "executive who can't stop his career downspiral is invited into his daughter's imaginary world, where solutions to his"
+]
+
+
+@pytest.mark.parametrize("query", queries)
+def test_tmdb_expected_edismax(query, tmdb_data):
+    """Check that edismax with mm=1 matches any doc holding a query term.
+
+    The expected set comes from naively tokenizing the raw title and
+    overview text and OR-ing the per-term hits across both fields.
+    """
+    title_tokenizer = tmdb_data['title_tokens'].array.tokenizer
+    overview_tokenizer = tmdb_data['overview_tokens'].array.tokenizer
+    title_has_term = np.any([naive_find_term(tmdb_data['title'],
+                                             query_term,
+                                             title_tokenizer)
+                             for query_term in title_tokenizer(query)], axis=0)
+    overview_has_term = np.any([naive_find_term(tmdb_data['overview'],
+                                                query_term,
+                                                overview_tokenizer)
+                                for query_term in overview_tokenizer(query)], axis=0)
+
+    matches, _ = edismax(tmdb_data, q=query,
+                         qf=["title_tokens^2", "overview_tokens"],
+                         pf=["title_tokens^2", "overview_tokens"],
+                         pf2=["title_tokens^2", "overview_tokens"],
+                         tie=0.1,
+                         mm=1)
+    matches = tmdb_data[matches > 0]
+    expected_matches = tmdb_data[title_has_term | overview_has_term].index
+    print(f"Query - {query} | Expected: {len(expected_matches)}")
+    # array_equal turns a length mismatch into a plain assertion failure
+    # instead of depending on an element-wise index comparison.
+    assert np.array_equal(matches.index, expected_matches)
+
+
+@pytest.mark.parametrize("query", queries)
+def test_tmdb_expected_edismax_and_query(query, tmdb_data):
+    """Check that edismax with mm="100%" matches only docs holding all terms.
+
+    A term counts as present when a naive scan finds it in either the raw
+    title or the raw overview.
+    """
+    title_tokenizer = tmdb_data['title_tokens'].array.tokenizer
+    overview_tokenizer = tmdb_data['overview_tokens'].array.tokenizer
+    title_terms = title_tokenizer(query)
+    overview_terms = overview_tokenizer(query)
+    title_matches = np.asarray([naive_find_term(tmdb_data['title'],
+                                                query_term,
+                                                title_tokenizer)
+                                for query_term in title_terms]) > 0
+    overview_matches = np.asarray([naive_find_term(tmdb_data['overview'],
+                                                   query_term,
+                                                   overview_tokenizer)
+                                   for query_term in overview_terms]) > 0
+
+    # One row per query term; a term counts if either field contains it.
+    either_has_term = title_matches | overview_matches
+    # Take the term count from the tokenizer output that built the rows,
+    # not from str.split(), so it cannot drift from the number of rows.
+    num_terms = len(title_terms)
+    all_terms_have_match = np.sum(either_has_term, axis=0) == num_terms
+
+    matches, _ = edismax(tmdb_data, q=query,
+                         qf=["title_tokens^2", "overview_tokens"],
+                         pf=["title_tokens^2", "overview_tokens"],
+                         pf2=["title_tokens^2", "overview_tokens"],
+                         tie=0.1,
+                         mm="100%")
+    matches = tmdb_data[matches > 0]
+
+    expected_matches = tmdb_data[all_terms_have_match].index
+    print(f"Query - {query} | Expected: {len(expected_matches)}")
+    assert np.array_equal(matches.index, expected_matches)
+
+
 tmdb_phrase_matches = [
     (["Star", "Wars"], ['11', '330459', '76180']),
     (["Black", "Mirror:"], ['374430']),
diff --git a/test/test_utils.py b/test/test_utils.py
index f28be1e..61e073a 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,5 +1,6 @@
 import pytest
 import datetime
+import pandas as pd
 from typing import Dict, Any, cast, Sequence, Type, Union
 import cProfile
 import sys
@@ -46,6 +47,17 @@ def run(self, func, *args, **kwargs):
         return rval
 
 
+def naive_find_term(text: pd.Series, term: str,
+                    tokenizer) -> pd.Series:
+    """Flag, row by row, whether *term* is among the tokens of each text.
+
+    Each entry of ``text`` is passed through ``tokenizer`` and checked for
+    membership of ``term``. Serves as a slow reference implementation for
+    validating index-based search results in tests.
+    """
+    return text.apply(lambda doc: term in tokenizer(doc))
+
+
 Profiler: Union[Type[JustBenchmarkProfiler], Type[CProfileProfiler]]
 
 if '--benchmark-disable' in sys.argv: