diff --git a/string_grouper/string_grouper.py b/string_grouper/string_grouper.py index d1612511..c670745e 100644 --- a/string_grouper/string_grouper.py +++ b/string_grouper/string_grouper.py @@ -194,6 +194,11 @@ class StringGrouperNotFitException(Exception): pass +class StringLengthException(Exception): + """Raised when vectoriser is fit on strings that are not of length greater than ngram size""" + pass + + class StringGrouper(object): def __init__(self, master: pd.Series, duplicates: Optional[pd.Series] = None, @@ -258,6 +263,13 @@ def n_grams(self, string: str) -> List[str]: def fit(self) -> 'StringGrouper': """Builds the _matches list which contains string matches indices and similarity""" + + # Validate match strings length + if not StringGrouper._strings_are_of_sufficient_length(self._master, self._config.ngram_size) or \ + (self._duplicates is not None + and not StringGrouper._strings_are_of_sufficient_length(self._duplicates, self._config.ngram_size)): + raise StringLengthException('Input string lengths are not all greater than n_gram length') + master_matrix, duplicate_matrix = self._get_tf_idf_matrices() # Calculate the matches using the cosine similarity @@ -697,6 +709,16 @@ def _is_series_of_strings(series_to_test: pd.Series) -> bool: return False return True + @staticmethod + def _strings_are_of_sufficient_length(series_to_test: pd.Series, ngram_size: int) -> bool: + if not isinstance(series_to_test, pd.Series): + return False + elif series_to_test.to_frame().applymap( + lambda x: not len(x) >= ngram_size + ).squeeze(axis=1).all(): + return False + return True + @staticmethod def _is_input_data_combination_valid(duplicates, master_id, duplicates_id) -> bool: if duplicates is None and (duplicates_id is not None) \ diff --git a/string_grouper/test/test_string_grouper.py b/string_grouper/test/test_string_grouper.py index f5f0aac8..8ebf6fcf 100644 --- a/string_grouper/test/test_string_grouper.py +++ b/string_grouper/test/test_string_grouper.py @@ -6,7 +6,7 @@ DEFAULT_REGEX, DEFAULT_NGRAM_SIZE, DEFAULT_N_PROCESSES, DEFAULT_IGNORE_CASE, \ StringGrouperConfig, StringGrouper, StringGrouperNotFitException, \ match_most_similar, group_similar_strings, match_strings, \ - compute_pairwise_similarities + compute_pairwise_similarities, StringLengthException from unittest.mock import patch @@ -822,6 +822,11 @@ def test_prior_matches_added(self): # All strings should now match to the same "master" string self.assertEqual(1, len(df.deduped.unique())) + def test_group_similar_strings_stopwords(self): + """StringGrouper shouldn't raise a ValueError if all strings are shorter than 3 characters""" + with self.assertRaises(StringLengthException): + StringGrouper(pd.Series(['zz', 'yy', 'xx'])).fit() + if __name__ == '__main__': unittest.main()