Cleanup type annotations
softwaredoug committed Dec 24, 2023
1 parent 39c56bf commit f69beb2
Showing 2 changed files with 16 additions and 15 deletions.
29 changes: 15 additions & 14 deletions searcharray/postings.py
@@ -9,6 +9,7 @@
import warnings
import logging
from time import perf_counter
+ from typing import List, Union


import numpy as np
@@ -594,7 +595,7 @@ def _check_token_arg(self, token):
# ***********************************************************
# Naive implementations of search functions to clean up later
# ***********************************************************
- def term_freq(self, token):
+ def term_freq(self, token: Union[List[str], str]) -> np.ndarray:
token = self._check_token_arg(token)
if isinstance(token, list):
return self.phrase_freq(token)
@@ -610,17 +611,17 @@ def term_freq(self, token):
except TermMissingError:
return np.zeros(len(self), dtype=int)

- def doc_freq(self, token):
+ def doc_freq(self, token: str) -> int:
if not isinstance(token, str):
raise TypeError("Expected a string")
# Count number of rows where the term appears
term_freq = self.term_freq(token)
return np.sum(term_freq > 0)

- def doc_lengths(self):
+ def doc_lengths(self) -> np.ndarray:
return self.doc_lens

- def match(self, token, slop=1):
+ def match(self, token: Union[List[str], str], slop: int = 1) -> np.ndarray:
"""Return a boolean numpy array indicating which elements contain the given term."""
token = self._check_token_arg(token)
if isinstance(token, list):
@@ -629,7 +630,7 @@ def match(self, token, slop=1):
term_freq = self.term_freq(token)
return term_freq > 0

- def bm25_idf(self, token, doc_stats=None):
+ def bm25_idf(self, token: Union[List[str], str], doc_stats=None) -> float:
"""Calculate the (Lucene) idf for a term.
idf, computed as log(1 + (N - n + 0.5) / (n + 0.5))
@@ -642,24 +643,24 @@ def bm25_idf(self, token, doc_stats=None):
num_docs = len(self)
return np.log(1 + (num_docs - df + 0.5) / (df + 0.5))

- def bm25_phrase_idf(self, tokens):
+ def bm25_phrase_idf(self, tokens: List[str]) -> float:
"""Calculate the idf for a phrase.
This is the sum of the idfs of the individual terms.
"""
idfs = [self.bm25_idf(term) for term in tokens]
return np.sum(idfs)

- def bm25_tf(self, token, k1=1.2, b=0.75, slop=1):
- """Calculate the (Lucene) BM25 tf for a term.
+ def bm25_tf(self, token: Union[List[str], str], k1=1.2, b=0.75, slop=1) -> float:
+ """Calculate the (Lucene) BM25 tf for a term or phrase.
tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl))
"""
tf = self.term_freq(token)
score = tf / (tf + k1 * (1 - b + b * self.doc_lengths() / self.avg_doc_length))
return score

- def bm25(self, token, doc_stats=None, k1=1.2, b=0.75):
+ def bm25(self, token: Union[str, List[str]], doc_stats=None, k1=1.2, b=0.75):
"""Score each doc using BM25.
Parameters
@@ -673,24 +674,24 @@ def bm25(self, token, doc_stats=None, k1=1.2, b=0.75):
token = self._check_token_arg(token)
return self.bm25_idf(token, doc_stats=doc_stats) * self.bm25_tf(token)

- def positions(self, token, key=None):
+ def positions(self, token: str, key=None) -> List[np.ndarray]:
"""Return a list of lists of positions of the given term."""
term_id = self.term_dict.get_term_id(token)
posns = self.posns.positions(term_id, key=key)
return posns

- def and_query(self, tokens):
+ def and_query(self, tokens: List[str]) -> np.ndarray:
"""Return a mask on the postings array indicating which elements contain all terms."""
masks = [self.match(term) for term in tokens]
mask = np.ones(len(self), dtype=bool)
for curr_mask in masks:
mask = mask & curr_mask
return mask

- def phrase_freq(self, tokens, slop=1):
+ def phrase_freq(self, tokens: List[str], slop=1) -> np.ndarray:
return self.phrase_freq_every_diff(tokens, slop=slop)

- def phrase_freq_scan(self, tokens, mask=None, slop=1):
+ def phrase_freq_scan(self, tokens: List[str], mask=None, slop=1) -> np.ndarray:
if mask is None:
mask = self.and_query(tokens)

@@ -704,7 +705,7 @@ def phrase_freq_scan(self, tokens, mask=None, slop=1):
phrase_freqs[mask] = scan_merge_ins(posns, phrase_freqs[mask], slop=slop)
return phrase_freqs

- def phrase_freq_every_diff(self, tokens, slop=1):
+ def phrase_freq_every_diff(self, tokens: List[str], slop=1) -> np.ndarray:
"""Batch up calls to _phrase_freq_every_diff."""
phrase_freqs = -np.ones(len(self))

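The docstrings above spell out the Lucene-flavoured BM25 formulas these annotated methods compute: idf = log(1 + (N - n + 0.5) / (n + 0.5)) and tf = freq / (freq + k1 * (1 - b + b * dl / avgdl)). As a quick illustration, here is a minimal standalone sketch of that arithmetic in NumPy; the sample numbers and variable names are made up for the example and are not taken from the repository:

    import numpy as np

    # Toy inputs: 4 documents, per-document frequency of one term, document lengths.
    freq = np.array([0, 1, 3, 2])         # what term_freq() would return
    doc_lens = np.array([10, 12, 8, 20])  # what doc_lengths() would return
    num_docs = len(freq)                  # N
    df = np.sum(freq > 0)                 # n, i.e. doc_freq: docs containing the term
    avg_doc_length = doc_lens.mean()
    k1, b = 1.2, 0.75

    # idf, computed as log(1 + (N - n + 0.5) / (n + 0.5))
    idf = np.log(1 + (num_docs - df + 0.5) / (df + 0.5))

    # tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl))
    tf = freq / (freq + k1 * (1 - b + b * doc_lens / avg_doc_length))

    bm25 = idf * tf  # one score per document; zero wherever the term is absent
    print(bm25)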
2 changes: 1 addition & 1 deletion test/test_msmarco.py
@@ -108,7 +108,7 @@ def ws_punc_tokenizer(text):

# Memory usage
#
- #Indexed in 14.7362s
+ # Indexed in 14.7362s
# [postings.py:303 - _build_index_from_dict() ] Padded Posn memory usage: 4274.036334991455 MB
# [postings.py:304 - _build_index_from_dict() ] Bitwis Posn memory usage: 800.7734680175781 MB

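Reading the annotated signatures in postings.py together also shows how the query side is meant to be called: a plain string scores a single term, while a list of strings routes through phrase_freq. The helper below is a hedged sketch of such a call site; `arr` stands for an already-built postings array (its construction is not part of this commit) and the function name is purely illustrative:

    from typing import List, Union

    import numpy as np

    def score_query(arr, tokens: Union[str, List[str]]) -> np.ndarray:
        """Illustrative only: BM25-score docs, zeroing docs that do not match."""
        scores = arr.bm25(tokens)         # bm25_idf(token) * bm25_tf(token), per doc
        if isinstance(tokens, list):
            mask = arr.and_query(tokens)  # docs containing every token
        else:
            mask = arr.match(tokens)      # docs where term_freq(token) > 0
        return np.where(mask, scores, 0.0)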
