Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rangefunc support #2

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module github.com/koron-go/trietree

go 1.21
go 1.23

require github.com/google/go-cmp v0.6.0
36 changes: 36 additions & 0 deletions predict.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package trietree

import (
"iter"
"unicode/utf8"
)

Expand Down Expand Up @@ -134,3 +135,38 @@ func predictIter[T comparable](tree predictableTree[T], query string) func() *Pr
return p
}
}

// Predict returns an iterator which enumerates Prediction: key suggestions
// that match the query in the tree.
func (dt *DTree) Predict(query string) iter.Seq[Prediction] {
return predict[*DNode](dt, query)
}

// Predict returns an iterator which enumerates Prediction: key suggestions
// that match the query in the tree.
func (st *STree) Predict(query string) iter.Seq[Prediction] {
return predict[int](st, query)
}

func predict[T comparable](tree predictableTree[T], query string) iter.Seq[Prediction] {
var zero T
tr := newTraverser[T](tree, query)
return func(yield func(Prediction) bool) {
for {
node, end, valid := tr.next()
if !valid {
return
}
for node != zero {
if id := tree.nodeId(node); id > 0 {
st := trailingIndex(query[:end], tree.nodeLevel(node))
if !yield(Prediction{Start: st, End: end, ID: id}) {
tr.close()
return
}
}
node = tree.nodeFail(node)
}
}
}
}
103 changes: 103 additions & 0 deletions predict_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package trietree_test

import (
"iter"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -137,3 +138,105 @@ func TestSTree_PredictMultiple(t *testing.T) {
{Start: 2, End: 3, ID: 4, Key: "d"},
})
}

type predictor interface {
Predict(string) iter.Seq[trietree.Prediction]
}

func testPredict(t *testing.T, ptor predictor, q string, want []prediction) {
t.Helper()
got := make([]prediction, 0, 10)
for p := range ptor.Predict(q) {
got = append(got, prediction{
Start: p.Start,
End: p.End,
ID: p.ID,
Key: q[p.Start:p.End],
})
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("unexpected predictions: -want +got\n%s", d)
}
}

type predictorBuilder func(t *testing.T, keys ...string) predictor

func testPredictSingle(t *testing.T, build predictorBuilder) {
ptor := build(t, "1", "2", "3", "4", "5")
testPredict(t, ptor, "1", []prediction{
{Start: 0, End: 1, ID: 1, Key: "1"},
})
testPredict(t, ptor, "2", []prediction{
{Start: 0, End: 1, ID: 2, Key: "2"},
})
testPredict(t, ptor, "3", []prediction{
{Start: 0, End: 1, ID: 3, Key: "3"},
})
testPredict(t, ptor, "4", []prediction{
{Start: 0, End: 1, ID: 4, Key: "4"},
})
testPredict(t, ptor, "5", []prediction{
{Start: 0, End: 1, ID: 5, Key: "5"},
})
testPredict(t, ptor, "6", []prediction{})
}

func testPredictMultiple(t *testing.T, build predictorBuilder) {
ptor := build(t, "1", "2", "3", "4", "5")
testPredict(t, ptor, "1234567890", []prediction{
{Start: 0, End: 1, ID: 1, Key: "1"},
{Start: 1, End: 2, ID: 2, Key: "2"},
{Start: 2, End: 3, ID: 3, Key: "3"},
{Start: 3, End: 4, ID: 4, Key: "4"},
{Start: 4, End: 5, ID: 5, Key: "5"},
})
}

func testPredictBasic(t *testing.T, build predictorBuilder) {
ptor := build(t, "ab", "bc", "bab", "d", "abcde")
testPredict(t, ptor, "ab", []prediction{
{Start: 0, End: 2, ID: 1, Key: "ab"},
})
testPredict(t, ptor, "bc", []prediction{
{Start: 0, End: 2, ID: 2, Key: "bc"},
})
testPredict(t, ptor, "bab", []prediction{
{Start: 0, End: 3, ID: 3, Key: "bab"},
{Start: 1, End: 3, ID: 1, Key: "ab"},
})
testPredict(t, ptor, "d", []prediction{
{Start: 0, End: 1, ID: 4, Key: "d"},
})
testPredict(t, ptor, "abcde", []prediction{
{Start: 0, End: 2, ID: 1, Key: "ab"},
{Start: 1, End: 3, ID: 2, Key: "bc"},
{Start: 3, End: 4, ID: 4, Key: "d"},
{Start: 0, End: 5, ID: 5, Key: "abcde"},
})
}

func testPredictAll(t *testing.T, builder predictorBuilder) {
t.Run("single", func(t *testing.T) {
testPredictSingle(t, builder)
})
t.Run("multiple", func(t *testing.T) {
testPredictMultiple(t, builder)
})
t.Run("basic", func(t *testing.T) {
testPredictBasic(t, builder)
})
}

func TestPredictSeq(t *testing.T) {
t.Run("dynamic", func(t *testing.T) {
testPredictAll(t, func(t *testing.T, keys ...string) predictor {
return testDTreePut(t, &trietree.DTree{}, keys...)
})
})
t.Run("static", func(t *testing.T) {
testPredictAll(t, func(t *testing.T, keys ...string) predictor {
dt := testDTreePut(t, &trietree.DTree{}, keys...)
return trietree.Freeze(dt)
})
})
}
31 changes: 30 additions & 1 deletion trie2/predict.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package trie2

import "github.com/koron-go/trietree"
import (
"iter"

"github.com/koron-go/trietree"
)

type Prediction[T any] struct {
Start int // Start is the start index of Key in the query.
Expand Down Expand Up @@ -34,3 +38,28 @@ func (dt *DTrie[T]) PredictIter(query string) PredictionIter[T] {
func (st *STrie[T]) PredictIter(query string) PredictionIter[T] {
return predictIter(query, st.tree.PredictIter(query), st.values)
}

func predict[T any](query string, iter iter.Seq[trietree.Prediction], values []T) iter.Seq[Prediction[T]] {
return func(yield func(Prediction[T]) bool) {
for p := range iter {
if !yield(Prediction[T]{
Start: p.Start,
End: p.End,
Key: query[p.Start:p.End],
Value: values[p.ID-1],
}) {
break
}
}
}
}

// Predict returns an iterator which enumerates Prediction.
func (dt *DTrie[T]) Predict(query string) iter.Seq[Prediction[T]] {
return predict[T](query, dt.tree.Predict(query), dt.values)
}

// Predict returns an iterator which enumerates Prediction.
func (st *STrie[T]) Predict(query string) iter.Seq[Prediction[T]] {
return predict[T](query, st.tree.Predict(query), st.values)
}
43 changes: 42 additions & 1 deletion trie2/predict_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package trie2

import (
"fmt"
"iter"
"slices"
"testing"

"github.com/google/go-cmp/cmp"
Expand All @@ -26,7 +28,7 @@ func testPredictIter[T any](t *testing.T, ptor predictIterator[T], q string, wan
}
}

func TestPredict(t *testing.T) {
func TestPredictIter(t *testing.T) {
dt := &DTrie[Data]{}
dt.Put("a", Data{111, "aaa"})
dt.Put("ab", Data{222, "bbb"})
Expand All @@ -53,3 +55,42 @@ func TestPredict(t *testing.T) {
})
}
}

type predictor[T any] interface {
Predict(string) iter.Seq[Prediction[T]]
}

func testPredict[T any](t *testing.T, ptor predictor[T], q string, want []Prediction[T]) {
got := slices.Collect[Prediction[T]](ptor.Predict(q))
if d := cmp.Diff(want, got); d != "" {
t.Errorf("unexpected predictions: -want +got\n%s", d)
}
}

func TestPredict(t *testing.T) {
dt := &DTrie[Data]{}
dt.Put("a", Data{111, "aaa"})
dt.Put("ab", Data{222, "bbb"})
dt.Put("abc", Data{333, "ccc"})
dt.Put("d", Data{444, "ddd"})
dt.Put("de", Data{555, "eee"})
dt.FillFailure()
st := dt.Freeze(false)

for i, c := range []struct {
q string
want []Prediction[Data]
}{
{"azd", []Prediction[Data]{
{Start: 0, End: 1, Key: "a", Value: Data{111, "aaa"}},
{Start: 2, End: 3, Key: "d", Value: Data{444, "ddd"}},
}},
} {
t.Run(fmt.Sprintf("DTrie i:%d q:%s", i, c.q), func(t *testing.T) {
testPredict(t, dt, c.q, c.want)
})
t.Run(fmt.Sprintf("STrie i:%d q:%s", i, c.q), func(t *testing.T) {
testPredict(t, st, c.q, c.want)
})
}
}
Loading