Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow a callable as the input sets. #6

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ python:
- "3.4"
- "3.5"
- "3.6"
- "nightly"
- "3.7"
- "3.8"
install:
- pip install -e .
script:
Expand Down
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ meet the similarity threshold.
```python
from SetSimilaritySearch import all_pairs

# The input sets must be a Python list of iterables (i.e., lists or sets).
# The input sets must be a Python list of iterables (i.e., lists or sets)
# or a callable (e.g. a function) that returns an iterator of such iterables.
sets = [[1,2,3], [3,4,5], [2,3,4], [5,6,7]]
# all_pairs returns an iterable of tuples.
pairs = all_pairs(sets, similarity_func_name="jaccard",
Expand All @@ -121,7 +122,8 @@ supports a static collection of sets with no updates.
```python
from SetSimilaritySearch import SearchIndex

# The input sets must be a Python list of iterables (i.e., lists or sets).
# The input sets must be a Python list of iterables (i.e., lists or sets)
# or a callable (e.g. a function) that returns an iterator of such iterables.
sets = [[1,2,3], [3,4,5], [2,3,4], [5,6,7]]
# The search index cannot be updated.
index = SearchIndex(sets, similarity_func_name="jaccard",
Expand Down
10 changes: 5 additions & 5 deletions SetSimilaritySearch/all_pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ def all_pairs(sets, similarity_func_name="jaccard",
position filter enhancement.

Args:
sets (list): a list of sets, each entry is an iterable representing a
set.
sets (list or callable): a list of sets or a callable that returns an
iterator of sets. Each entry of the list or the returned iterator
is an iterable representing a set. Note that an iterator cannot
be accepted here because `sets` must be scanned twice.
similarity_func_name (str): the name of the similarity function used;
this function currently supports `"jaccard"` and `"cosine"`.
similarity_threshold (float): the threshold used, must be a float
between 0 and 1.0.

Returns:
pairs (Iterator[tuple]): an iterator of tuples `(x, y, similarity)`
where `x` and `y` are the indices of sets in the input list `sets`.
where `x` and `y` are the indices of sets in the input `sets`.
"""
if not isinstance(sets, list) or len(sets) == 0:
raise ValueError("Input parameter sets must be a non-empty list.")
if similarity_func_name not in _similarity_funcs:
raise ValueError("Similarity function {} is not supported.".format(
similarity_func_name))
Expand Down
12 changes: 6 additions & 6 deletions SetSimilaritySearch/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ class SearchIndex(object):
techniques.

Args:
sets (list): a list of sets, each entry is an iterable representing a
set.
sets (list or callable): a list of sets or a callable that returns an
iterator of sets. Each entry of the list or the returned iterator
is an iterable representing a set. Note that an iterator cannot
be accepted here because `sets` must be scanned twice.
similarity_func_name (str): the name of the similarity function used;
this function currently supports `"jaccard"`, `"cosine"`, and
`"containment"`.
Expand All @@ -24,8 +26,6 @@ class SearchIndex(object):

def __init__(self, sets, similarity_func_name="jaccard",
similarity_threshold=0.5):
if not isinstance(sets, list) or len(sets) == 0:
raise ValueError("Input parameter sets must be a non-empty list.")
if similarity_func_name not in _similarity_funcs:
raise ValueError("Similarity function {} is not supported.".format(
similarity_func_name))
Expand All @@ -39,7 +39,7 @@ def __init__(self, sets, similarity_func_name="jaccard",
self.overlap_index_threshold_func = \
_overlap_index_threshold_funcs[similarity_func_name]
self.position_filter_func = _position_filter_funcs[similarity_func_name]
logging.debug("Building SearchIndex on {} sets.".format(len(sets)))
logging.debug("Building SearchIndex...")
logging.debug("Start frequency transform.")
self.sets, self.order = _frequency_order_transform(sets)
logging.debug("Finish frequency transform, {} tokens in total.".format(
Expand All @@ -50,7 +50,7 @@ def __init__(self, sets, similarity_func_name="jaccard",
prefix = self._get_prefix_index(s)
for j, token in enumerate(prefix):
self.index[token].append((i, j))
logging.debug("Finished indexing sets.")
logging.debug("Finished indexing {} sets.".format(len(self.sets)))

def _get_prefix_index(self, s):
t = self.overlap_index_threshold_func(len(s), self.similarity_threshold)
Expand Down
17 changes: 13 additions & 4 deletions SetSimilaritySearch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ def _frequency_order_transform(sets):
in Data Cleaning" by Chaudhuri et al..

Args:
sets (list): a list of sets, each entry is an iterable representing a
set.
sets (list or callable): a list of sets or a callable that returns an
iterator of sets. Each entry of the list or the returned iterator
is an iterable representing a set. Note that an iterator cannot
be accepted here because `sets` must be scanned twice.

Returns:
sets (list): a list of sets, each entry is a sorted Numpy array with
Expand All @@ -106,9 +108,16 @@ def _frequency_order_transform(sets):
in the frequency order.
"""
logging.debug("Applying frequency order transform on tokens...")
counts = reversed(Counter(token for s in sets for token in s).most_common())
if isinstance(sets, list):
get_sets = lambda : sets
elif callable(sets):
get_sets = sets
else:
raise ValueError("sets must be a list or a callable.")
counts = reversed(
Counter(token for s in get_sets() for token in s).most_common())
order = dict((token, i) for i, (token, _) in enumerate(counts))
sets = [np.sort([order[token] for token in s]) for s in sets]
sets = [np.sort([order[token] for token in s]) for s in get_sets()]
logging.debug("Done applying frequency order.")
return sets, order

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# For a discussion on single-sourcing the version across setup.py and the
# project code, see
# https://packaging.python.org/en/latest/single_source_version.html
version='0.1.7', # Required
version='0.1.8', # Required

# This is a one-line description or tagline of what your project does. This
# corresponds to the "Summary" metadata field:
Expand Down
16 changes: 16 additions & 0 deletions tests/all_pairs_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@

class TestAllPairs(unittest.TestCase):

def test_empty(self):
sets = []
pairs = list(all_pairs(sets, similarity_func_name='jaccard',
similarity_threshold=0.1))
self.assertEqual(len(pairs), 0)

def test_jaccard(self):
sets = [[1,2,3], [3,4,5], [2,3,4], [5,6,7]]
correct_pairs = set([(1, 0, 0.2), (2, 0, 0.5), (2, 1, 0.5),
Expand All @@ -16,6 +22,16 @@ def test_jaccard(self):
self.assertTrue(pair in correct_pairs)
self.assertEqual(len(pairs), len(correct_pairs))

def test_callable_input(self):
sets = lambda : [[1,2,3], [3,4,5], [2,3,4], [5,6,7]]
correct_pairs = set([(1, 0, 0.2), (2, 0, 0.5), (2, 1, 0.5),
(3, 1, 0.2)])
pairs = list(all_pairs(sets, similarity_func_name='jaccard',
similarity_threshold=0.1))
for pair in pairs:
self.assertTrue(pair in correct_pairs)
self.assertEqual(len(pairs), len(correct_pairs))

def test_identity_matrix(self):
# Use all-pair to generate a lower-triangular identity matrix
nsets = 10
Expand Down
15 changes: 15 additions & 0 deletions tests/search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@

class TestSearchIndex(unittest.TestCase):

def test_empty(self):
sets = []
index = SearchIndex(sets, similarity_func_name="jaccard",
similarity_threshold=0.1)
results = index.query([3,5,4])
self.assertEqual(len(results), 0)

def test_jaccard(self):
sets = [[1,2,3], [3,4,5], [2,3,4], [5,6,7]]
index = SearchIndex(sets, similarity_func_name="jaccard",
Expand All @@ -12,6 +19,14 @@ def test_jaccard(self):
correct_results = set([(1, 1.0), (0, 0.2), (2, 0.5), (3, 0.2)])
self.assertEqual(set(results), correct_results)

def test_callable(self):
sets = lambda : [[1,2,3], [3,4,5], [2,3,4], [5,6,7]]
index = SearchIndex(sets, similarity_func_name="jaccard",
similarity_threshold=0.1)
results = index.query([3,5,4])
correct_results = set([(1, 1.0), (0, 0.2), (2, 0.5), (3, 0.2)])
self.assertEqual(set(results), correct_results)

def test_containment(self):
sets = [[1,2,3], [3,4,5], [2,3,4], [5,6,7]]
# Threshold 0.1
Expand Down