Skip to content

Commit

Permalink
Create custom matrix to remove scipy
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Dec 23, 2023
1 parent 2ab96c7 commit 580db64
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 54 deletions.
77 changes: 46 additions & 31 deletions searcharray/postings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Doc,Term -> freq
# Note scipy sparse switching to *_array, which is more numpy like
# However, as of now, these don't seem fully baked
from scipy.sparse import dok_matrix
from searcharray.utils.mat_set import SparseMatSet

logger = logging.getLogger(__name__)

Expand All @@ -44,10 +44,15 @@ class PostingsRow:
https://github.com/pandas-dev/pandas/issues/17777
"""

def __init__(self, postings, posns: dict = None, encoded=False):
def __init__(self,
postings,
doc_len: int = 0,
posns: dict = None,
encoded=False):
self.postings = postings
self.posns = None
self.encoded = encoded
self.doc_len = doc_len

if posns is not None:
for term, term_posns in posns.items():
Expand Down Expand Up @@ -115,7 +120,9 @@ def __eq__(self, other):
# to get a boolean array back
if isinstance(other, PostingsArray):
return other == self
return isinstance(other, PostingsRow) and self.postings == other.postings
same_postings = isinstance(other, PostingsRow) and self.postings == other.postings
if same_postings and self.doc_len == other.doc_len:
return True

def __lt__(self, other):
# return isinstance(other, PostingsRow) and hash(self) < hash(other)
Expand Down Expand Up @@ -207,7 +214,8 @@ def _build_index_from_dict(postings):
"""Bulid an index from postings that are already tokenized and point at their term frequencies."""
start = perf_counter()
term_dict = TermDict()
freqs_table = defaultdict(int)
term_doc = SparseMatSet()
doc_lens = []
avg_doc_length = 0
num_postings = 0
add_term_time = 0
Expand All @@ -228,22 +236,24 @@ def _build_index_from_dict(postings):
# this is faster that directly using the matrix
# https://www.austintripp.ca/blog/2018/09/12/sparse-matrices-tips1
for doc_id, tokenized in enumerate(postings):
term_doc.ensure_capacity(doc_id)
if isinstance(tokenized, dict):
tokenized = PostingsRow(tokenized)
tokenized = PostingsRow(tokenized, doc_len=len(tokenized))
elif not isinstance(tokenized, PostingsRow):
raise TypeError("Expected a PostingsRow or a dict")

if tokenized.encoded:
posns = posns_enc

avg_doc_length += len(tokenized)
doc_lens.append(tokenized.doc_len)
avg_doc_length += doc_lens[-1]
for token, term_freq in tokenized.terms():
add_term_start = perf_counter()
term_id = term_dict.add_term(token)
add_term_time += perf_counter() - add_term_start

set_time_start = perf_counter()
freqs_table[doc_id, term_id] += term_freq
term_doc[doc_id, term_id] = 1
set_time += perf_counter() - set_time_start

get_posns_start = perf_counter()
Expand All @@ -270,31 +280,23 @@ def _build_index_from_dict(postings):

logger.debug(f"Indexed {num_postings} documents in {perf_counter() - start} seconds")

# COPY 2
freqs_dok = dok_matrix((num_postings, len(term_dict)), dtype=np.uint32)
dict.update(freqs_dok, freqs_table)
logger.debug(f"DOK 1 took {perf_counter() - start} seconds to build")

freqs_csr = freqs_dok.tocsr()
logger.debug(f"CSR 1 took {perf_counter() - start} seconds to build")

bit_posns = posns.build()
logger.info(f"Bitwis Posn memory usage: {bit_posns.nbytes / 1024 / 1024} MB")

return RowViewableMatrix(freqs_csr), bit_posns, term_dict, avg_doc_length
return RowViewableMatrix(term_doc), bit_posns, term_dict, avg_doc_length, np.array(doc_lens)


def _row_to_postings_row(row, term_dict, posns: PosnBitArray):
def _row_to_postings_row(doc_id, row, doc_len, term_dict, posns: PosnBitArray):
tfs = {}
non_zeros = row.nonzero()
labeled_posns = {}
for row_idx, term_idx in zip(non_zeros[0], non_zeros[1]):
for term_idx in row.cols:
term = term_dict.get_term(term_idx)
tfs[term] = int(row[row_idx, term_idx])
enc_term_posns = posns.doc_encoded_posns(term_idx, doc_id=row_idx)
tfs[term] = 1
enc_term_posns = posns.doc_encoded_posns(term_idx, doc_id=doc_id)
labeled_posns[term] = enc_term_posns

result = PostingsRow(tfs, labeled_posns, encoded=True)
result = PostingsRow(tfs, posns=labeled_posns,
doc_len=doc_len, encoded=True)
# TODO add positions
return result

Expand All @@ -318,7 +320,8 @@ def __init__(self, postings, tokenizer=ws_tokenizer):

self.tokenizer = tokenizer
self.term_freqs, self.posns, \
self.term_dict, self.avg_doc_length = _build_index_from_dict(postings)
self.term_dict, self.avg_doc_length, \
self.doc_lens = _build_index_from_dict(postings)

@classmethod
def index(cls, array, tokenizer=ws_tokenizer):
Expand All @@ -339,7 +342,9 @@ def tokenized_docs(docs):
positions = defaultdict(list)
for posn in range(len(token_stream)):
positions[token_stream[posn]].append(posn)
yield PostingsRow(term_freqs, positions)
yield PostingsRow(term_freqs,
doc_len=len(token_stream),
posns=positions)

return cls(tokenized_docs(array), tokenizer)

Expand All @@ -364,15 +369,20 @@ def memory_usage(self, deep=False):

@property
def nbytes(self):
return self.term_freqs.nbytes + self.posns.nbytes
return self.term_freqs.nbytes + self.posns.nbytes + self.doc_lens.nbytes

def __getitem__(self, key):
key = pd.api.indexers.check_array_indexer(self, key)
# Want to take rows of term freqs
if isinstance(key, numbers.Integral):
try:
rows = self.term_freqs[key]
return _row_to_postings_row(rows[0], self.term_dict, self.posns)
doc_len = self.doc_lens[key]
doc_id = key
if doc_id < 0:
doc_id += len(self)
return _row_to_postings_row(doc_id, rows[0], doc_len,
self.term_dict, self.posns)
except IndexError:
raise IndexError("index out of bounds")
else:
Expand All @@ -381,6 +391,7 @@ def __getitem__(self, key):
sliced_posns = self.posns.slice(key)
arr = PostingsArray([], tokenizer=self.tokenizer)
arr.term_freqs = sliced_tfs
arr.doc_lens = self.doc_lens[key]
arr.posns = sliced_posns
arr.term_dict = self.term_dict
arr.avg_doc_length = self.avg_doc_length
Expand Down Expand Up @@ -409,18 +420,23 @@ def __setitem__(self, key, value):
is_encoded = False
posns = None
term_freqs = np.asarray([])
doc_lens = np.asarray([])
if isinstance(value, float):
term_freqs = np.asarray([value])
doc_lens = np.asarray([0])
elif isinstance(value, PostingsRow):
term_freqs = np.asarray([value.tf_to_dense(self.term_dict)])
doc_lens = np.asarray([value.doc_len])
is_encoded = value.encoded
posns = [value.raw_positions(self.term_dict)]
elif isinstance(value, np.ndarray):
term_freqs = np.asarray([x.tf_to_dense(self.term_dict) for x in value])
doc_lens = np.asarray([x.doc_len for x in value])
is_encoded = value[0].encoded if len(value) > 0 else False
posns = [x.raw_positions(self.term_dict) for x in value]
np.nan_to_num(term_freqs, copy=False, nan=0)
self.term_freqs[key] = term_freqs
self.doc_lens[key] = doc_lens

if posns is not None:
self.posns.insert(key, posns, is_encoded)
Expand Down Expand Up @@ -485,7 +501,7 @@ def __eq__(self, other):
# Compatible term dicts, and same term freqs
# (not looking at positions, maybe we should?)
if self.term_dict.compatible(other.term_dict):
return self.term_freqs == other.term_freqs
return (self.term_freqs == other.term_freqs) & (self.doc_lens == other.doc_lens)
else:
return np.zeros(len(self), dtype=bool)
# return np.array(self[:]) == np.array(other[:])
Expand Down Expand Up @@ -513,9 +529,7 @@ def __eq__(self, other):

def isna(self):
# Every row with all 0s
key_slice_all = slice(None)
sliced = self.term_freqs.slice(key_slice_all)
empties = np.asarray((sliced.sum(axis=1) == 0).flatten())[0]
empties = self.doc_lens == 0
return empties

def take(self, indices, allow_fill=False, fill_value=None):
Expand Down Expand Up @@ -543,6 +557,7 @@ def take(self, indices, allow_fill=False, fill_value=None):

def copy(self):
postings_arr = PostingsArray([], tokenizer=self.tokenizer)
postings_arr.doc_lens = self.doc_lens.copy()
postings_arr.posns = self.posns.copy()
postings_arr.term_freqs = self.term_freqs.copy()
postings_arr.term_dict = self.term_dict.copy()
Expand Down Expand Up @@ -601,7 +616,7 @@ def doc_freq(self, token):
return np.sum(term_freq > 0)

def doc_lengths(self):
return np.array(self.term_freqs.sum(axis=1).flatten())[0]
return self.doc_lens

def match(self, token, slop=1):
"""Return a boolean numpy array indicating which elements contain the given term."""
Expand Down
128 changes: 128 additions & 0 deletions searcharray/utils/mat_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import numpy as np
import pandas as pd
import numbers


class SparseMatSet:
"""Sparse matrix that only stores the set of row/col indices that are set to 1."""

def __init__(self, cols=None, rows=None):
if rows is None:
rows = np.asarray([0], dtype=np.uint32)
cols = np.asarray([], dtype=np.uint32)
self.cols = cols.astype(np.uint32) # col indices.
self.rows = rows.astype(np.uint32) # indices into cols
assert self.rows[-1] == len(self.cols)

def __getitem__(self, key):
# Iterate keys
beg_keys = self.rows[:-1][key]
end_keys = self.rows[1:][key]

if not isinstance(beg_keys, np.ndarray):
beg_keys = np.asarray([beg_keys])
end_keys = np.asarray([end_keys])

cols = [self.cols[beg:end] for beg, end in zip(beg_keys, end_keys)]
rows = [0] + [len(row) for row in cols]
rows = np.asarray(rows).flatten()
rows = np.cumsum(rows)
try:
cols = np.concatenate(cols)
except ValueError:
cols = np.asarray([], dtype=np.uint32)
return SparseMatSet(cols, rows)

def ensure_capacity(self, row):
if row >= len(self):
append_amt = row - (len(self.rows) - 1) + 1
new_row_ptrs = [len(self.cols)] * append_amt
self.rows = np.concatenate([self.rows, new_row_ptrs])

def _set_cols(self, row, cols, overwrite=False):
row = np.int32(row)
self.ensure_capacity(row)

cols_for_row = self.rows[:-1][row]
cols_for_row_next = self.rows[1:][row]
front_cols = np.asarray([], dtype=np.int64)
trailing_cols = np.asarray([], dtype=np.int64)
if row > 0:
front_cols = self.cols[:cols_for_row]
if cols_for_row_next != self.rows[-1]:
trailing_cols = self.cols[cols_for_row_next:]

existing_set_cols = self.cols[cols_for_row:cols_for_row_next]
cols_added = np.int32(len(np.setdiff1d(cols, existing_set_cols)))
if not overwrite:
existing_set_cols = np.unique(np.concatenate([cols, existing_set_cols]))

self.cols = np.concatenate([front_cols, existing_set_cols, trailing_cols], dtype=np.int64)
else:
cols_added = np.int32(len(cols) - len(existing_set_cols))
self.cols = np.concatenate([front_cols, cols, trailing_cols], dtype=np.int64)

if cols_added < 0:
# TODO some casting nonsense makes this necessary
self.rows[row + 1:] -= np.abs(cols_added)
else:
self.rows[row + 1:] += cols_added

def __setitem__(self, index, value):
if isinstance(index, numbers.Integral):
if len(value.shape) == 1:
value = value.reshape(1, -1)
set_rows, set_cols = value.nonzero()
if not (value[set_rows, set_cols] == 1).all():
raise ValueError("This sparse matrix only supports setting 1")
self._set_cols(index, set_cols, overwrite=True)

# Multidimensional indexing
elif isinstance(index, tuple):
row, col = index
if value != 1:
raise ValueError("This sparse matrix only supports setting 1")
self._set_cols(row, np.asarray([col]))
# Multiple rows
elif pd.api.types.is_list_like(index):
if len(index) == len(value):
for idx, val in zip(index, value):
self[idx] = val
elif len(value) == 1:
for idx in index:
self[idx] = value
else:
raise ValueError("Index and value must be same length")

def copy(self):
return SparseMatSet(self.cols.copy(), self.rows.copy())

@property
def nbytes(self):
return self.cols.nbytes + self.rows.nbytes

@property
def shape(self):
rows = len(self.rows) - 1
cols = 0
if len(self.cols) > 0:
cols = np.max(self.cols)
return (rows, cols)

def num_cols_per_row(self):
return np.diff(self.rows)

def __len__(self):
return len(self.rows) - 1

def __eq__(self, other):
return np.all(self.rows == other.rows) and np.all(self.cols == other.cols)

def __repr__(self):
return f"SparseMatSet(shape={self.shape})"

def __str__(self):
as_str = [""]
for idx, (row, row_next) in enumerate(zip(self.rows, self.rows[1:])):
as_str.append(f"{idx}: {self.cols[row:row_next]}")
return "\n".join(as_str)
Loading

0 comments on commit 580db64

Please sign in to comment.