Skip to content

Commit

Permalink
Add tests for phrases in tmdb
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed May 16, 2024
1 parent c25ac38 commit ea57cd5
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 1 deletion.
70 changes: 69 additions & 1 deletion test/test_tmdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import numpy as np
import sys
from searcharray.postings import SearchArray
from test_utils import Profiler, profile_enabled
from searcharray.solr import edismax
from test_utils import Profiler, profile_enabled, naive_find_term


should_profile = '--benchmark-disable' in sys.argv
Expand Down Expand Up @@ -123,6 +124,73 @@ def test_term_freqs(tmdb_data, term, expected_matches):
assert np.all(term_freqs == 1)


queries = [
"Star Wars",
"the next generation",
"bartender fights a cow and",
"to be or not to be",
"the quick brown fox jumps over the lazy dog",
"bill and ted's excellent adventure",
"thirty years after defeating the galactic empire",
"a film about a daughter of a refugee family",
"have one thing in mind: to find a way to kill each other without risk. After listening to a radio show, Paul decided",
"executive who can't stop his career downspiral is invited into his daughter's imaginary world, where solutions to his"
]


@pytest.mark.parametrize("query", queries)
def test_tmdb_expected_edismax(query, tmdb_data):

title_tokenizer = tmdb_data['title_tokens'].array.tokenizer
overview_tokenizer = tmdb_data['overview_tokens'].array.tokenizer
title_has_term = np.sum([naive_find_term(tmdb_data['title'],
query_term,
title_tokenizer) for query_term in title_tokenizer(query)], axis=0) > 0
overview_has_term = np.sum([naive_find_term(tmdb_data['overview'],
query_term,
overview_tokenizer) for query_term in overview_tokenizer(query)], axis=0) > 0

matches, _ = edismax(tmdb_data, q=query,
qf=["title_tokens^2", "overview_tokens"],
pf=["title_tokens^2", "overview_tokens"],
pf2=["title_tokens^2", "overview_tokens"],
tie=0.1,
mm=1)
matches = tmdb_data[matches > 0]
expected_matches = tmdb_data[title_has_term | overview_has_term].index
print(f"Query - {query} | Expected: {len(expected_matches)}")
assert np.all(matches.index == expected_matches)


@pytest.mark.parametrize("query", queries)
def test_tmdb_expected_edismax_and_query(query, tmdb_data):

title_tokenizer = tmdb_data['title_tokens'].array.tokenizer
overview_tokenizer = tmdb_data['overview_tokens'].array.tokenizer
num_terms = len(query.split())
title_matches = np.asarray([naive_find_term(tmdb_data['title'],
query_term,
title_tokenizer) for query_term in title_tokenizer(query)]) > 0
overview_matches = np.asarray([naive_find_term(tmdb_data['overview'],
query_term,
overview_tokenizer) for query_term in overview_tokenizer(query)]) > 0

either_has_term = (title_matches + overview_matches)
all_terms_have_match = np.sum(either_has_term, axis=0) == num_terms

matches, _ = edismax(tmdb_data, q=query,
qf=["title_tokens^2", "overview_tokens"],
pf=["title_tokens^2", "overview_tokens"],
pf2=["title_tokens^2", "overview_tokens"],
tie=0.1,
mm="100%")
matches = tmdb_data[matches > 0]

expected_matches = tmdb_data[all_terms_have_match].index
print(f"Query - {query} | Expected: {len(expected_matches)}")
assert np.all(matches.index == expected_matches)


tmdb_phrase_matches = [
(["Star", "Wars"], ['11', '330459', '76180']),
(["Black", "Mirror:"], ['374430']),
Expand Down
7 changes: 7 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import datetime
import pandas as pd
from typing import Dict, Any, cast, Sequence, Type, Union
import cProfile
import sys
Expand Down Expand Up @@ -46,6 +47,12 @@ def run(self, func, *args, **kwargs):
return rval


def naive_find_term(text: pd.Series, term: str,
tokenizer):
text_as_tokens = text.apply(tokenizer)
return text_as_tokens.apply(lambda tokens: term in tokens)


Profiler: Union[Type[JustBenchmarkProfiler], Type[CProfileProfiler]]

if '--benchmark-disable' in sys.argv:
Expand Down

0 comments on commit ea57cd5

Please sign in to comment.