Skip to content

Commit

Permalink
Fix array concat metadata tracking bug
Browse files: browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Jul 12, 2024
1 parent 156b65f commit 26ae9a1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 12 deletions.
6 changes: 5 additions & 1 deletion searcharray/phrase/memmap_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,14 @@ def concat(lhs: 'ArrayDict', rhs: 'ArrayDict', sort=True):
if key not in fetched_keys:
curr_offset += value.size
lst_of_arrays.append(value)
metadata[key] = {'offset': last_offset, 'length': curr_offset}
metadata[key] = {'offset': last_offset, 'length': curr_offset - last_offset}
last_offset = curr_offset

arr.metadata = metadata
for curr_arr in lst_of_arrays:
if sort:
curr_arr.sort()

arr.data = np.concatenate(lst_of_arrays)
return arr

Expand Down
35 changes: 24 additions & 11 deletions test/test_tmdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,6 @@ def ensure_data_dir_exists():
pass


@pytest.fixture(scope="session", autouse=True)
def clean_up():
    """Session-scoped, auto-used fixture that brackets the test session.

    Calls ``ensure_data_dir_exists()`` before yielding and
    ``clean_data_dir()`` after control returns from the yield.
    """
    # Setup phase: executed before the fixture yields.
    ensure_data_dir_exists()
    yield
    # Teardown phase: executed once the fixture is resumed after the yield.
    clean_data_dir()


@pytest.fixture(scope="session")
def tmdb_raw_data():
path = 'fixtures/tmdb.json.gz'
Expand Down Expand Up @@ -72,9 +65,11 @@ def tmdb_pd_data(tmdb_raw_data):
return df


@pytest.fixture(scope="module", params=["full", "ends_empty", "memmap", "small_batch",
"smallbatch_memmap"])
@pytest.fixture(scope="session", params=["full", "ends_empty", "memmap", "small_batch",
"smallbatch_memmap"])
def tmdb_data(tmdb_pd_data, request):
ensure_data_dir_exists()
print(f"Rebuilding index with {request.param}")
df = tmdb_pd_data
indexed = SearchArray.index(df['title'],
batch_size=5000 if request.param in ["small_batch", "smallbatch_memmap"] else 100000,
Expand All @@ -90,7 +85,8 @@ def tmdb_data(tmdb_pd_data, request):
batch_size=5000 if request.param in ["small_batch", "smallbatch_memmap"] else 100000,
data_dir=DATA_DIR if request.param == "memmap" else None)
df['overview_tokens'] = indexed
return df
yield df
clean_data_dir()


def test_tokenize_tmdb(tmdb_raw_data):
Expand Down Expand Up @@ -187,19 +183,36 @@ def test_tmdb_expected_edismax(query, tmdb_data):
title_has_term = np.sum([naive_find_term(tmdb_data['title'],
query_term,
title_tokenizer) for query_term in title_tokenizer(query)], axis=0) > 0
tmdb_data['title_has_term'] = title_has_term
overview_has_term = np.sum([naive_find_term(tmdb_data['overview'],
query_term,
overview_tokenizer) for query_term in overview_tokenizer(query)], axis=0) > 0
tmdb_data['overview_has_term'] = overview_has_term
tmdb_data['sum_has_term'] = title_has_term + overview_has_term

title_has_term2 = np.sum([naive_find_term(tmdb_data['title'],
query_term,
title_tokenizer) for query_term in title_tokenizer(query)], axis=0) > 0
tmdb_data['title_has_term2'] = title_has_term2
overview_has_term2 = np.sum([naive_find_term(tmdb_data['overview'],
query_term,
overview_tokenizer) for query_term in overview_tokenizer(query)], axis=0) > 0
tmdb_data['overview_has_term'] = overview_has_term2
tmdb_data['sum_has_term2'] = title_has_term + overview_has_term
assert np.all(title_has_term == title_has_term2)
assert np.all(overview_has_term == overview_has_term2)
assert np.all(tmdb_data['sum_has_term'] == tmdb_data['sum_has_term2'])

matches, _ = edismax(tmdb_data, q=query,
qf=["title_tokens^2", "overview_tokens"],
pf=["title_tokens^2", "overview_tokens"],
pf2=["title_tokens^2", "overview_tokens"],
tie=0.1,
mm=1)
matches = tmdb_data[matches > 0]
tmdb_data['matches'] = matches
expected_matches = tmdb_data[title_has_term | overview_has_term].index
print(f"Query - {query} | Expected: {len(expected_matches)}")
matches = tmdb_data[matches > 0]
assert np.all(matches.index == expected_matches)


Expand Down

0 comments on commit 26ae9a1

Please sign in to comment.