Skip to content

Commit

Permalink
Support save and restore of index (#2)
Browse files Browse the repository at this point in the history
* Normalize to speed up distance calc

* Support save and restore of index

* fix
  • Loading branch information
kelindar authored Oct 28, 2024
1 parent 2c06460 commit 84112d2
Show file tree
Hide file tree
Showing 9 changed files with 256 additions and 123 deletions.
75 changes: 0 additions & 75 deletions bruteforce_test.go

This file was deleted.

3 changes: 3 additions & 0 deletions dist/dataset.bin
Git LFS file not shown
3 changes: 0 additions & 3 deletions dist/dataset.gob

This file was deleted.

34 changes: 6 additions & 28 deletions example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"bufio"
"encoding/gob"
"fmt"
"math"
"os"
Expand All @@ -21,13 +20,7 @@ func main() {
defer m.Close()

// Load a pre-embedded dataset and create an exact search index
data, _ := loadDataset("../dist/dataset.gob")
index := search.NewIndex[string]()

// Embed the sentences and calculate similarities
for _, v := range data {
index.Add(v.Vector, v.Pair[0]) // use m.EmbedText() for real-time embedding
}
index := loadIndex("../dist/dataset.bin")

r := bufio.NewReader(os.Stdin)
for {
Expand Down Expand Up @@ -63,27 +56,12 @@ func main() {
}
}

type record struct {
Pair [2]string `gob:"pair"`
Rank float64 `gob:"rank"`
Label string `gob:"label"`
Vector []float32 `gob:"vector"`
}

func loadDataset(path string) ([]record, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()

var data []record
r := gob.NewDecoder(file)
if err := r.Decode(&data); err != nil {
return nil, err
func loadIndex(path string) *search.Index[string] {
index := search.NewIndex[string]()
if err := index.ReadFile(path); err != nil {
panic(err)
}

return data, nil
return index
}

/*
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.23

require (
github.com/ebitengine/purego v0.8.1
github.com/kelindar/iostream v1.4.0
github.com/klauspost/cpuid/v2 v2.2.8
github.com/stretchr/testify v1.9.0
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/ebitengine/purego v0.8.1 h1:sdRKd6plj7KYW33EH5As6YKfe8m9zbN9JMrOjNVF/BE=
github.com/ebitengine/purego v0.8.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/kelindar/iostream v1.4.0 h1:ELKlinnM/K3GbRp9pYhWuZOyBxMMlYAfsOP+gauvZaY=
github.com/kelindar/iostream v1.4.0/go.mod h1:MkjMuVb6zGdPQVdwLnFRO0xOTOdDvBWTztFmjRDQkXk=
github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM=
github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
Expand Down
40 changes: 23 additions & 17 deletions bruteforce.go → index.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ type entry[T any] struct {

// Result represents a search result.
type Result[T any] struct {
entry[T]
Relevance float64 // The relevance of the result
Value T // The value of the result
}

// Index represents a brute-force search index, returning exact results.
Expand All @@ -31,46 +31,52 @@ type Index[T any] struct {
// NewIndex creates a new exact search index.
func NewIndex[T any]() *Index[T] {
return &Index[T]{
arr: make([]entry[T], 0),
arr: make([]entry[T], 0, 512),
}
}

// Len returns the number of items in the index.
func (idx *Index[T]) Len() int {
return len(idx.arr)
}

// Add adds a new vector to the search index.
func (b *Index[T]) Add(vx Vector, item T) {
func (idx *Index[T]) Add(vx Vector, item T) {
normalize(vx)

b.arr = append(b.arr, entry[T]{
idx.arr = append(idx.arr, entry[T]{
Vector: vx,
Value: item,
})
}

// Search searches the index for the k-nearest neighbors of the query vector.
func (b *Index[T]) Search(query Vector, k int) []Result[T] {
func (idx *Index[T]) Search(query Vector, k int) []Result[T] {
if k <= 0 {
return nil
}

// Normalize and quantize the query vector
// Normalize the query vector
normalize(query)

var relevance float64
var r float64
dst := make(minheap[T], 0, k)
for _, v := range b.arr {
simd.DotProduct(&relevance, query, v.Vector)
result := Result[T]{
entry: v,
Relevance: relevance,
}
for _, v := range idx.arr {
simd.DotProduct(&r, query, v.Vector)

// If the heap is not full, add the result, otherwise replace
// the minimum element
switch {
case dst.Len() < k:
dst.Push(result)
case result.Relevance > dst[0].Relevance:
dst.Push(Result[T]{
Value: v.Value,
Relevance: r,
})
case r > dst[0].Relevance:
dst.Pop()
dst.Push(result)
dst.Push(Result[T]{
Value: v.Value,
Relevance: r,
})
}
}

Expand Down
126 changes: 126 additions & 0 deletions index_codec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for details.

package search

import (
"compress/flate"
"fmt"
"io"
"os"

"github.com/kelindar/iostream"
)

// WriteTo writes the index to a writer.
func (b *Index[T]) WriteTo(dst io.Writer) (int64, error) {
w := iostream.NewWriter(dst)
i := w.Offset()

// Write version
if err := w.WriteUint8(1); err != nil {
return 0, err
}

// Write the index
err := w.WriteRange(len(b.arr), func(i int, w *iostream.Writer) error {
if err := w.WriteFloat32s(b.arr[i].Vector); err != nil {
return err
}

// Write the value (optional)
switch v := any(b.arr[i].Value).(type) {
case string:
return w.WriteString(v)
case []byte:
return w.WriteBytes(v)
default:
return nil
}
})

return w.Offset() - i, err
}

// ReadFrom reads the index from a reader.
func (b *Index[T]) ReadFrom(src io.Reader) (int64, error) {
r := iostream.NewReader(src)
s := r.Offset()

// Read version
version, err := r.ReadUint8()
if err != nil {
return 0, err
}

if version != 1 {
return 0, fmt.Errorf("unsupported version: %d", version)
}

var length uint64
if length, err = r.ReadUvarint(); err != nil {
return r.Offset() - s, err
}

// Allocate space for the entries
b.arr = make([]entry[T], length)
for i := 0; i < int(length); i++ {

// Read the vector
if b.arr[i].Vector, err = r.ReadFloat32s(); err != nil {
return r.Offset() - s, err
}

// Read the value (optional)
switch any(b.arr[i].Value).(type) {
case string:
v, err := r.ReadString()
if err != nil {
return r.Offset() - s, err
}
b.arr[i].Value = any(v).(T)

case []byte:
v, err := r.ReadBytes()
if err != nil {
return r.Offset() - s, err
}
b.arr[i].Value = any(v).(T)
}
}

return r.Offset() - s, nil
}

// ---------------------------------- File ----------------------------------

// WriteFile writes the index into a flate-compressed binary file.
func (idx *Index[T]) WriteFile(filename string) error {
file, err := os.Create(filename)
if err != nil {
return err
}

defer file.Close()
writer, err := flate.NewWriter(file, flate.DefaultCompression)
if err != nil {
return err
}

// WriteTo the underlying writer
defer writer.Close()
_, err = idx.WriteTo(writer)
return err
}

// ReadFile reads the index from a flate-compressed binary file.
func (idx *Index[T]) ReadFile(filename string) error {
file, err := os.Open(filename)
if err != nil {
return err
}

defer file.Close()
_, err = idx.ReadFrom(flate.NewReader(file))
return err
}
Loading

0 comments on commit 84112d2

Please sign in to comment.