Attempt to make terms lazy
softwaredoug committed Jun 9, 2024
1 parent 0538075 commit 63ee287
Showing 3 changed files with 84 additions and 18 deletions.
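The gist of the change: instead of eagerly decoding a full postings row on every scalar access, SearchArray.__getitem__ hands back a LazyTerms view that holds only a doc id, a reference to the shared positions index, and the doc's term ids. A minimal sketch of the pattern (LazyRow, backing_index, and lookup are illustrative stand-ins, not searcharray API):

    class LazyRow:
        """A view over one document: construction is O(1); the payload is
        decoded from the shared index only when actually requested."""

        def __init__(self, doc_id, backing_index):
            self.doc_id = doc_id          # just an id...
            self._index = backing_index   # ...and a reference; nothing decoded yet

        def positions(self, term_id):
            # The expensive decode happens here, on first use
            return self._index.lookup(term_id, doc_id=self.doc_id)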
90 changes: 73 additions & 17 deletions searcharray/postings.py
@@ -8,7 +8,7 @@
 from collections import Counter
 import warnings
 import logging
-from typing import List, Union, Optional, Iterable
+from typing import List, Union, Optional, Iterable, Iterator, Any


 import numpy as np
@@ -141,11 +141,67 @@ def __hash__(self):
         return hash(json.dumps(self.postings, sort_keys=True))


+class LazyTerms:
+    """A view onto one doc in the postings that only fetches
+    the postings when needed."""
+
+    def __init__(self, doc_id=-1, posns=None, terms=None):
+        self.posns = posns      # shared positions index, already built
+        self.doc_id = doc_id
+        self.terms = terms      # this doc's term ids (a term-matrix row)
+        if self.terms is None:
+            self.terms = []
+
+    def __eq__(self, other):
+        # Flip to the other implementation if we're comparing to a SearchArray
+        # to get a boolean array back
+        if isinstance(other, SearchArray):
+            return other == self
+        # np.array_equal collapses the elementwise comparison to a single bool
+        return isinstance(other, LazyTerms) and np.array_equal(self.terms.cols, other.terms.cols)
+
+    def __len__(self):
+        return len(self.terms)
+
+    def __repr__(self):
+        return f"LazyTerms(doc_id={self.doc_id})"
+
+    def __str__(self):
+        return f"LazyTerms(doc_id={self.doc_id})"
+
+    def __lt__(self, other):
+        return self.doc_id < other.doc_id
+
+    def __le__(self, other):
+        return self.doc_id <= other.doc_id
+
+    def __gt__(self, other):
+        return self.doc_id > other.doc_id
+
+    def __hash__(self):
+        return hash(str(self.doc_id))
+
+    def raw_positions(self, term_dict, term=None):
+        if self.posns is None:
+            return {}
+        # The lazy part: encoded positions are pulled from the shared index
+        # only now, and only for this doc
+        if term is None:
+            raw_posns = [(term_idx, self.posns.doc_encoded_posns(term_idx, doc_id=self.doc_id))
+                         for term_idx in self.terms]
+        else:
+            term_id = term_dict.get_term_id(term)
+            raw_posns = [(term_id, self.posns.doc_encoded_posns(term_id, doc_id=self.doc_id))]
+        return raw_posns
+
+
 class TermsDtype(ExtensionDtype):
     """Pandas dtype for terms."""

     name = 'tokenized_text'
-    type = Terms
+    type = LazyTerms
     kind = 'O'

     @classmethod
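A hedged usage sketch of the new view's semantics (assuming __getitem__, changed further down, returns these views): construction and comparison never touch the postings, since ordering and hashing key on doc_id alone.

    arr = SearchArray.index(["foo bar", "bar baz", "baz foo"])
    doc = arr[1]                  # O(1): a LazyTerms view, nothing decoded
    print(doc)                    # LazyTerms(doc_id=1)
    sorted([arr[2], arr[0]])      # __lt__/__le__ compare doc ids only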
@@ -170,10 +226,10 @@ def __repr__(self):

     @property
     def na_value(self):
-        return Terms({})
+        return LazyTerms()

     def valid_value(self, value):
-        return isinstance(value, dict) or pd.isna(value) or isinstance(value, Terms)
+        return isinstance(value, dict) or pd.isna(value) or isinstance(value, LazyTerms)


 register_extension_dtype(TermsDtype)
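For context, a sketch of how the dtype change surfaces in pandas; the DataFrame and column names are made up, but the 'tokenized_text' name and na_value come straight from the class above:

    import pandas as pd

    df = pd.DataFrame({"body": ["see spot run", "run spot run"]})
    df["body_tokens"] = SearchArray.index(df["body"])
    df["body_tokens"].dtype                  # tokenized_text
    df["body_tokens"].array.dtype.na_value   # LazyTerms(), the new missing-value sentinel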
@@ -219,7 +275,7 @@ def __init__(self, postings, tokenizer=ws_tokenizer, avoid_copies=True):
         self.tokenizer = tokenizer
         self.term_mat, self.posns, \
             self.term_dict, self.avg_doc_length, \
-            self.doc_lens = build_index_from_terms_list(postings, Terms)
+            self.doc_lens = build_index_from_terms_list(postings, LazyTerms)
         self.corpus_size = len(self.doc_lens)

     @classmethod
@@ -258,6 +314,7 @@ def index(cls, array: Iterable,
             build_index_from_tokenizer(array, tokenizer, batch_size=batch_size,
                                        truncate=truncate,
                                        workers=workers)

         if autowarm:
             posns.warm()
@@ -303,13 +360,11 @@ def __getitem__(self, key):
         # Want to take rows of term freqs
         if isinstance(key, numbers.Integral):
             try:
-                rows = self.term_mat[key]
-                doc_len = self.doc_lens[key]
                 doc_id = key
                 if doc_id < 0:
                     doc_id += len(self)
-                return _row_to_postings_row(doc_id, rows[0], doc_len,
-                                            self.term_dict, self.posns)
+                return LazyTerms(doc_id, self.posns, self.term_mat[key])
             except IndexError:
                 raise IndexError("index out of bounds")
         else:
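Continuing the earlier sketch, the payoff of this change: scalar indexing is now a constant-time view construction, and positions are only decoded if raw_positions is actually called.

    doc = arr[1]                             # no postings decoded yet
    raw = doc.raw_positions(arr.term_dict)   # encoded positions fetched only here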
@@ -356,7 +411,7 @@ def __setitem__(self, key, value):
         if isinstance(value, float):
             term_mat = np.asarray([value])
             doc_lens = np.asarray([0])
-        elif isinstance(value, Terms):
+        elif isinstance(value, LazyTerms):
             term_mat = np.asarray([value.tf_to_dense(self.term_dict)])
             doc_lens = np.asarray([value.doc_len])
             is_encoded = value.encoded
@@ -386,7 +441,7 @@ def _add_new_terms(self, key, value):
             warnings.warn(msg)

         scan_value = value
-        if isinstance(value, Terms):
+        if isinstance(value, LazyTerms):
             scan_value = np.asarray([value])
         for row in scan_value:
             for term in row.terms():
@@ -401,8 +456,9 @@ def value_counts(
         dropna: bool = True,
     ):
         if dropna:
             counts = Counter(self[:])
-            counts.pop(Terms({}), None)
+            counts.pop(LazyTerms(), None)
         else:
             counts = Counter(self[:])
         return pd.Series(counts)
@@ -439,7 +495,7 @@ def __eq__(self, other):
             # return np.array(self[:]) == np.array(other[:])

         # When other is a scalar value
-        elif isinstance(other, Terms):
+        elif isinstance(other, LazyTerms):
             other = SearchArray([other], tokenizer=self.tokenizer)
             warnings.warn("Comparing a scalar value to a SearchArray. This is slow.")
             return np.array(self[:]) == np.array(other[:])
@@ -477,7 +533,7 @@ def take(self, indices, allow_fill=False, fill_value=None):

         if allow_fill and -1 in result_indices:
             if fill_value is None or pd.isna(fill_value):
-                fill_value = Terms({}, encoded=True)
+                fill_value = LazyTerms()

             to_fill_mask = result_indices == -1
             # This is slow as it rebuilds all the term dictionaries
@@ -518,7 +574,7 @@ def _from_factorized(cls, values, original):
     def _values_for_factorize(self):
         """Return an array and missing value suitable for factorization (ie grouping)."""
         arr = np.asarray(self[:], dtype=object)
-        return arr, Terms({})
+        return arr, LazyTerms()

     def _check_token_arg(self, token):
         if isinstance(token, str):
2 changes: 1 addition & 1 deletion setup.py
@@ -37,7 +37,7 @@
     # For a discussion on single-sourcing the version across setup.py and the
     # project code, see
     # https://packaging.python.org/guides/single-sourcing-package-version/
-    version="0.0.56",  # Required
+    version="0.0.57",  # Required
     # This is a one-line description or tagline of what your project does. This
     # corresponds to the "Summary" metadata field:
     # https://packaging.python.org/specifications/core-metadata/#summary
10 changes: 10 additions & 0 deletions test/test_tmdb.py
@@ -302,3 +302,13 @@ def test_eq_benchmark(benchmark, tmdb_data):
     assert np.sum(results) == compare_amount

     # eq = benchmark(tmdb_data['overview_tokens'].array.__eq__, idx_again)
+
+
+@pytest.mark.skipif(not profile_enabled, reason="Profiling disabled")
+def test_iterrows_benchmark(benchmark, tmdb_data):
+    prof = Profiler(benchmark)
+
+    def loop_over():
+        for idx, row in tmdb_data.iterrows():
+            pass
+    prof.run(loop_over)
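The new benchmark drives DataFrame.iterrows, which materializes every cell of every row; with the change above, each tokenized cell becomes a constant-time LazyTerms view rather than a fully decoded postings row. Roughly what it exercises per cell (column name taken from the existing tests; the equivalence is assumed):

    def loop_over_explicit():
        arr = tmdb_data['overview_tokens'].array
        for i in range(len(arr)):
            arr[i]   # per-cell __getitem__: now returns a LazyTerms view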
