diff --git a/src/dawg.pyx b/src/dawg.pyx index 133c7fc..de0d7f5 100644 --- a/src/dawg.pyx +++ b/src/dawg.pyx @@ -86,7 +86,7 @@ cdef class DAWG: cpdef bint b_has_key(self, bytes key) except -1: return self.dct.Contains(key, len(key)) - cpdef bytes tobytes(self): + cpdef bytes tobytes(self) except +: """ Return raw DAWG content as bytes. """ @@ -249,6 +249,24 @@ cdef class DAWG: return res + def longest_prefix(self, unicode key): + cdef BaseType index = self.dct.root() + cdef int pos = 1 + cdef int lastpos = 0 + cdef CharType ch + + for ch in key: + if not self.dct.Follow(ch, &index): + break + if self.dct.has_value(index): + lastpos = pos + pos += 1 + + if lastpos: + return key[:lastpos] + else: + raise KeyError("No prefix found") + def iterprefixes(self, unicode key): ''' Return a generator with keys of this DAWG that are prefixes of the ``key``. @@ -798,7 +816,28 @@ cdef class BytesDAWG(CompletionDAWG): """ return self._similar_item_values(0, key, self.dct.root(), replaces) + def longest_prefix(self, unicode key): + cdef BaseType index = self.dct.root() + cdef BaseType tmp + cdef BaseType lastindex + cdef int pos = 1 + cdef int lastpos = 0 + cdef CharType ch + + for ch in key: + if not self.dct.Follow(ch, &index): + break + + tmp = index + if self.dct.Follow(self._c_payload_separator, &tmp): + lastpos = pos + lastindex = tmp + pos += 1 + if lastpos: + return key[:lastpos], self._value_for_index(lastindex) + else: + raise KeyError("No prefix found") cdef class RecordDAWG(BytesDAWG): """ @@ -900,6 +939,26 @@ cdef class IntDAWG(DAWG): cpdef int b_get_value(self, bytes key): return self.dct.Find(key) + def longest_prefix(self, unicode key): + cdef BaseType index = self.dct.root() + cdef BaseType lastindex + cdef int pos = 1 + cdef int lastpos = 0 + cdef CharType ch + + for ch in key: + if not self.dct.Follow(ch, &index): + break + + if self.dct.has_value(index): + lastpos = pos + lastindex = index + pos += 1 + + if lastpos: + return key[:lastpos], self.dct.value(lastindex) + else: + raise KeyError("No prefix found") # FIXME: code duplication. cdef class IntCompletionDAWG(CompletionDAWG): diff --git a/tests/test_dawg.py b/tests/test_dawg.py index 26f6627..92baa4b 100644 --- a/tests/test_dawg.py +++ b/tests/test_dawg.py @@ -83,7 +83,12 @@ def test_unicode_sorting(self): # if data is sorted according to unicode rules. dawg.DAWG([key1, key2]) - + def test_longest_prefix(self): + d = dawg.DAWG(["a", "as", "asdf"]) + assert d.longest_prefix("a") == "a" + assert d.longest_prefix("as") == "as" + assert d.longest_prefix("asd") == "as" + assert d.longest_prefix("asdf") == "asdf" class TestIntDAWG(object): @@ -148,6 +153,13 @@ def test_int_value_ranges(self): with pytest.raises(OverflowError): self.IntDAWG({'f': 2**32-1}) + def test_longest_prefix(self): + d = dawg.IntDAWG([("a", 1), ("as", 2), ("asdf", 3)]) + assert d.longest_prefix("a") == ("a", 1) + assert d.longest_prefix("as") == ("as", 2) + assert d.longest_prefix("asd") == ("as", 2) + assert d.longest_prefix("asdf") == ("asdf", 3) + class TestIntCompletionDAWG(TestIntDAWG): IntDAWG = dawg.IntCompletionDAWG # checks that all tests for IntDAWG pass diff --git a/tests/test_payload_dawg.py b/tests/test_payload_dawg.py index 305ac3f..a8b2b1b 100644 --- a/tests/test_payload_dawg.py +++ b/tests/test_payload_dawg.py @@ -83,7 +83,12 @@ def test_build_error(self): with pytest.raises(dawg.Error): self.dawg(payload_separator=b'f') - + def test_longest_prefix(self): + d = dawg.BytesDAWG([("a", b"a1"), ("a", b"a2"), ("as", b"as"), ("asdf", b"asdf")]) + assert d.longest_prefix("a") == ("a", [b"a1", b"a2"]) + assert d.longest_prefix("as") == ("as", [b"as"]) + assert d.longest_prefix("asd") == ("as", [b"as"]) + assert d.longest_prefix("asdf") == ("asdf", [b"asdf"]) class TestRecordDAWG(object):