-
Notifications
You must be signed in to change notification settings - Fork 3
/
trie.go
274 lines (223 loc) · 7.24 KB
/
trie.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
// Package trie contains a primitive implementation of the Trie data structure.
//
// Copyright 2017 Aleksandr Bezobchuk.
package trie
import (
"container/list"
"sync"
)
// Bytes is a named byte-slice type (a distinct defined type, not a Go type
// alias) used for the keys and values that the Trie accepts and returns.
type Bytes []byte
// trieNode is a single node of the Trie. Every node except the root carries
// one symbol (byte) of a key, and its children are indexed by symbol in a
// hashmap rather than in the fixed-size, alphabet-wide array a trie
// traditionally uses. A node whose value is non-nil marks the end of a stored
// key and holds the value for that key.
//
// TODO: Handle the case where the value given is a dummy value which can be
// nil. Perhaps it's best to not store values at all.
type trieNode struct {
// children maps the next key symbol to the node that continues the key.
children map[byte]*trieNode
// symbol is the key byte this node represents; the root leaves it at zero.
symbol byte
// value is the payload of the key that ends at this node. The code in this
// file treats a nil value as "no key ends here".
value []byte
// root is set to true on the node created by NewTrie; newNode leaves it
// false.
root bool
}
// Trie implements a search tree that stores byte key value pairs, one key
// symbol per node, and allows for efficient queries. It is intended to be safe
// for concurrent use: rw is the lock its methods use to guard root and size.
// Because it contains a sync.RWMutex, a Trie must be used through a pointer
// and not copied.
type Trie struct {
// rw is the reader/writer lock used to synchronize access to root and size.
rw sync.RWMutex
// root is the entry node of the tree; it carries no key symbol of its own.
root *trieNode
// size is the counter reported by Size; NewTrie starts it at one for the
// root node.
size int
}
// NewTrie builds an empty Trie that is ready for use. The trie starts with a
// single root node (flagged as root, with an allocated but empty child map)
// and a size of one, which accounts for that root.
func NewTrie() *Trie {
	rootNode := &trieNode{
		children: map[byte]*trieNode{},
		root:     true,
	}

	t := new(Trie)
	t.root = rootNode
	t.size = 1
	return t
}
// newNode allocates a non-root node for the given key symbol. The node starts
// with an empty, writable child map and no value.
func newNode(symbol byte) *trieNode {
	n := new(trieNode)
	n.symbol = symbol
	n.children = map[byte]*trieNode{}
	return n
}
// Size reports the trie's size counter, read under the read lock. The counter
// starts at one for the root node and grows by one each time Insert stores a
// value on a node that had none, so it reflects the root plus the number of
// stored keys rather than every intermediate node.
func (t *Trie) Size() int {
	t.rw.RLock()
	current := t.size
	t.rw.RUnlock()

	return current
}
// Insert stores value under key, creating any missing nodes along the way. If
// the key already exists, its value is replaced.
//
// The walk starts at the root and follows (or creates) one child per key
// symbol; the node reached once the key is exhausted is the terminating node
// for that key value pair. An empty key terminates at the root itself.
//
// A node counts as holding a key only while its value is non-nil, so a nil
// value is stored as an empty, non-nil slice. Without that normalization every
// nil insert would bump the size again while the node still looked empty, and
// overwriting an existing key with nil would silently remove it without
// adjusting the size.
//
// The value slice is stored as given; it is not copied.
func (t *Trie) Insert(key, value Bytes) {
	t.rw.Lock()
	defer t.rw.Unlock()

	if value == nil {
		value = Bytes{}
	}

	currNode := t.root
	for _, symbol := range key {
		child := currNode.children[symbol]
		if child == nil {
			child = newNode(symbol)
			currNode.children[symbol] = child
		}
		currNode = child
	}

	// Only a node that held no value before represents a new key; anything
	// else is an update and leaves the size untouched.
	if currNode.value == nil {
		t.size++
	}
	currNode.value = value
}
// Search looks up the value stored under key. If the key exists, its value is
// returned together with true; otherwise nil and false are returned.
//
// Reaching a node is not enough for a hit: intermediate nodes on the way to a
// longer key (for example "ab" when only "abc" is stored) hold no value and
// are reported as not found, matching how the key listing methods decide
// which nodes are keys.
//
// The returned slice is the stored value itself, not a copy.
func (t *Trie) Search(key Bytes) (Bytes, bool) {
	t.rw.RLock()
	defer t.rw.RUnlock()

	currNode := t.root
	for _, symbol := range key {
		next := currNode.children[symbol]
		if next == nil {
			return nil, false
		}
		currNode = next
	}

	// A nil value means no key terminates at this node.
	if currNode.value == nil {
		return nil, false
	}
	return currNode.value, true
}
// GetAllKeys returns every key stored in the trie. Keys are gathered by a
// depth-first walk that tracks the path (key) taken from the root; whenever
// the walk reaches a node that holds a value, a copy of the current path is
// added to the result. The order of the returned keys is unspecified because
// children are visited in map iteration order. An empty trie yields an empty,
// non-nil slice.
//
// The walk runs under the read lock so it cannot race with Insert modifying
// the child maps.
func (t *Trie) GetAllKeys() []Bytes {
	t.rw.RLock()
	defer t.rw.RUnlock()

	keys := []Bytes{}

	var walk func(n *trieNode, path Bytes)
	walk = func(n *trieNode, path Bytes) {
		if n == nil {
			return
		}
		if n.value != nil {
			// Sibling branches reuse path's backing array, so the key must
			// be copied before it is handed out.
			key := make(Bytes, len(path))
			copy(key, path)
			keys = append(keys, key)
		}
		for _, child := range n.children {
			if child != nil {
				walk(child, append(path, child.symbol))
			}
		}
	}

	// The root contributes no symbol, so the walk starts with an empty path.
	walk(t.root, Bytes{})
	return keys
}
// GetAllValues returns every value stored in the trie. Values are gathered by
// a breadth-first walk: nodes are taken from the front of a queue, a node's
// value is appended to the result when it has one, and its children are
// queued behind it. The order among siblings is unspecified because children
// are visited in map iteration order. An empty trie yields an empty, non-nil
// slice.
//
// The walk runs under the read lock so it cannot race with Insert modifying
// the child maps. The returned slices are the stored values themselves, not
// copies.
func (t *Trie) GetAllValues() []Bytes {
	t.rw.RLock()
	defer t.rw.RUnlock()

	values := []Bytes{}

	queue := list.New()
	queue.PushBack(t.root)

	for queue.Len() > 0 {
		// Remove hands back the value held by the removed element.
		node := queue.Remove(queue.Front()).(*trieNode)
		if node.value != nil {
			values = append(values, node.value)
		}
		for _, child := range node.children {
			if child != nil {
				queue.PushBack(child)
			}
		}
	}
	return values
}
// GetPrefixKeys returns every stored key that starts with prefix. The trie is
// first descended along the prefix, one symbol per level; if the prefix path
// does not exist, no key can match. From the node that ends the prefix, a
// depth-first walk collects a copy of the path to every node that holds a
// value, including the prefix node itself when the prefix is a stored key.
//
// Keys that are shorter than the prefix are never returned, even when they
// lie on the prefix path (for prefix "abc" the key "ab" does not qualify).
// An empty prefix yields an empty result. The result is always non-nil and
// its order is unspecified because children are visited in map iteration
// order.
//
// The traversal runs under the read lock so it cannot race with Insert
// modifying the child maps.
func (t *Trie) GetPrefixKeys(prefix Bytes) []Bytes {
	keys := []Bytes{}
	if len(prefix) == 0 {
		return keys
	}

	t.rw.RLock()
	defer t.rw.RUnlock()

	// Locate the node at which the prefix ends.
	start := t.root
	for _, symbol := range prefix {
		start = start.children[symbol]
		if start == nil {
			return keys
		}
	}

	var collect func(n *trieNode, path Bytes)
	collect = func(n *trieNode, path Bytes) {
		if n == nil {
			return
		}
		if n.value != nil {
			// Sibling branches reuse path's backing array, so the key must
			// be copied before it is handed out.
			key := make(Bytes, len(path))
			copy(key, path)
			keys = append(keys, key)
		}
		for _, child := range n.children {
			if child != nil {
				collect(child, append(path, child.symbol))
			}
		}
	}

	// Walk on a private copy so appends never write into the caller's prefix.
	path := make(Bytes, len(prefix))
	copy(path, prefix)
	collect(start, path)
	return keys
}
// GetPrefixValues returns the value of every stored key that starts with
// prefix. The trie is first descended along the prefix, one symbol per level;
// if the prefix path does not exist, no key can match. From the node that
// ends the prefix, a depth-first walk appends the value of every node that
// holds one, including the prefix node itself when the prefix is a stored
// key.
//
// Values of keys shorter than the prefix are never returned, even when those
// keys lie on the prefix path (for prefix "abc" the value of "ab" does not
// qualify). An empty prefix yields an empty result. The result is always
// non-nil and its order is unspecified because children are visited in map
// iteration order.
//
// The traversal runs under the read lock so it cannot race with Insert
// modifying the child maps. The returned slices are the stored values
// themselves, not copies.
func (t *Trie) GetPrefixValues(prefix Bytes) []Bytes {
	values := []Bytes{}
	if len(prefix) == 0 {
		return values
	}

	t.rw.RLock()
	defer t.rw.RUnlock()

	// Locate the node at which the prefix ends.
	start := t.root
	for _, symbol := range prefix {
		start = start.children[symbol]
		if start == nil {
			return values
		}
	}

	var collect func(n *trieNode)
	collect = func(n *trieNode) {
		if n == nil {
			return
		}
		if n.value != nil {
			values = append(values, n.value)
		}
		for _, child := range n.children {
			collect(child)
		}
	}
	collect(start)
	return values
}