Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Move GetResults to types #89

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 54 additions & 4 deletions chroma.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package chromago

Check failure on line 1 in chroma.go

View workflow job for this annotation

GitHub Actions / lint

: # github.com/amikos-tech/chroma-go

Check failure on line 1 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.4.8)

: # github.com/amikos-tech/chroma-go

Check failure on line 1 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.4.24)

: # github.com/amikos-tech/chroma-go

Check failure on line 1 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.5.0)

: # github.com/amikos-tech/chroma-go

Check failure on line 1 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.5.2)

: # github.com/amikos-tech/chroma-go

import (
"context"
Expand Down Expand Up @@ -385,10 +385,32 @@
}

// GetResults holds the data returned by a Get call on a Collection, plus the
// state needed to request neighbouring pages of the same query.
type GetResults struct {
	// Ids, Documents, Metadatas and Embeddings carry the returned record data.
	Ids        []string
	Documents  []string
	Metadatas  []map[string]interface{}
	Embeddings []*types.Embedding
	// collection is the Collection these results came from; NextPage and
	// PreviousPage issue their follow-up queries against it.
	collection *Collection
	// PageInfo describes the page these results represent. It is nil when the
	// originating Get did not specify a limit, in which case paging is
	// unavailable.
	PageInfo *types.PageInfo
	// RecordIterator is a channel of records associated with these results.
	// NOTE(review): the producer goroutine visible in GetWithOptions does not
	// send on this channel yet — confirm the intended semantics before relying
	// on it.
	RecordIterator chan *types.Record
}

// NextPage fetches the page that follows the one held in r.
//
// It replays the query options recorded in r.PageInfo with an additional
// offset option of Offset+Limit and returns the resulting GetResults. An error
// is returned when r carries no PageInfo (the originating Get had no limit) or
// when r is not bound to a collection.
func (r *GetResults) NextPage(ctx context.Context) (*GetResults, error) {
	if r.PageInfo == nil {
		return nil, fmt.Errorf("no page info. Your Get must contain limit option")
	}
	if r.collection == nil {
		return nil, fmt.Errorf("results are not associated with a collection")
	}
	// Build a fresh slice instead of appending to r.PageInfo.QueryOptions:
	// appending in place would grow the stored options on every call and could
	// overwrite elements of a backing array shared with the caller's options.
	opts := make([]types.CollectionQueryOption, 0, len(r.PageInfo.QueryOptions)+1)
	opts = append(opts, r.PageInfo.QueryOptions...)
	opts = append(opts, types.WithOffset(r.PageInfo.Offset+r.PageInfo.Limit))
	return r.collection.GetWithOptions(ctx, opts...)
}

// PreviousPage fetches the page that precedes the one held in r.
//
// It replays the query options recorded in r.PageInfo with an additional
// offset option of Offset-Limit and returns the resulting GetResults. An error
// is returned when r carries no PageInfo (the originating Get had no limit),
// when stepping back one page would produce a negative offset, or when r is
// not bound to a collection.
func (r *GetResults) PreviousPage(ctx context.Context) (*GetResults, error) {
	if r.PageInfo == nil {
		return nil, fmt.Errorf("no page info. Your Get must contain limit options")
	}
	// Equivalent to Offset-Limit < 0, written without the subtraction.
	if r.PageInfo.Offset < r.PageInfo.Limit {
		return nil, fmt.Errorf("cannot go to previous page. Offset is less than 0")
	}
	if r.collection == nil {
		return nil, fmt.Errorf("results are not associated with a collection")
	}
	// Build a fresh slice instead of appending to r.PageInfo.QueryOptions:
	// appending in place would grow the stored options on every call and could
	// overwrite elements of a backing array shared with the caller's options.
	opts := make([]types.CollectionQueryOption, 0, len(r.PageInfo.QueryOptions)+1)
	opts = append(opts, r.PageInfo.QueryOptions...)
	opts = append(opts, types.WithOffset(r.PageInfo.Offset-r.PageInfo.Limit))
	return r.collection.GetWithOptions(ctx, opts...)
}

type Collection struct {
Expand Down Expand Up @@ -562,7 +584,35 @@
Documents: cd.Documents,
Metadatas: cd.Metadatas,
Embeddings: APIEmbeddingsToEmbeddings(cd.Embeddings),
collection: c,
}
// only add PageInfo when a limit is provided (offset is assumed 0 if not provided)
// TODO can limit == collection count (that is an extra API call though)
if query.Limit != 0 {
results.PageInfo = &types.PageInfo{
Limit: query.Limit,
Offset: query.Offset,
QueryOptions: options,
}
}

results.RecordIterator = make(chan *types.Record)
go func() {
count, err := c.Count(ctx)
if err != nil {
return
}
start := 0
end := count
lim := 10
if query.Limit != 0 {
lim = query.Limit

Check failure on line 609 in chroma.go

View workflow job for this annotation

GitHub Actions / lint

cannot use query.Limit (variable of type int32) as int value in assignment

Check failure on line 609 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.4.8)

cannot use query.Limit (variable of type int32) as int value in assignment

Check failure on line 609 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.4.24)

cannot use query.Limit (variable of type int32) as int value in assignment

Check failure on line 609 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.5.0)

cannot use query.Limit (variable of type int32) as int value in assignment

Check failure on line 609 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.5.2)

cannot use query.Limit (variable of type int32) as int value in assignment
}
for i := start; i < end; i += lim {

Check failure on line 611 in chroma.go

View workflow job for this annotation

GitHub Actions / lint

invalid operation: i < end (mismatched types int and int32)

Check failure on line 611 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.4.8)

invalid operation: i < end (mismatched types int and int32)

Check failure on line 611 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.4.24)

invalid operation: i < end (mismatched types int and int32)

Check failure on line 611 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.5.0)

invalid operation: i < end (mismatched types int and int32)

Check failure on line 611 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.5.2)

invalid operation: i < end (mismatched types int and int32)
c.GetWithOptions(ctx, types.WithLimit(10), types.WithOffset(i))

Check failure on line 612 in chroma.go

View workflow job for this annotation

GitHub Actions / lint

cannot use i (variable of type int) as int32 value in argument to types.WithOffset

Check failure on line 612 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.4.8)

cannot use i (variable of type int) as int32 value in argument to types.WithOffset

Check failure on line 612 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.4.24)

cannot use i (variable of type int) as int32 value in argument to types.WithOffset

Check failure on line 612 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.5.0)

cannot use i (variable of type int) as int32 value in argument to types.WithOffset

Check failure on line 612 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.5.2)

cannot use i (variable of type int) as int32 value in argument to types.WithOffset
}
close(ch)

Check failure on line 614 in chroma.go

View workflow job for this annotation

GitHub Actions / lint

undefined: ch (typecheck)

Check failure on line 614 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.4.8)

undefined: ch (typecheck)

Check failure on line 614 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.4.24)

undefined: ch (typecheck)

Check failure on line 614 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.5.0)

undefined: ch (typecheck)

Check failure on line 614 in chroma.go

View workflow job for this annotation

GitHub Actions / build (0.5.2)

undefined: ch (typecheck)
}()
return results, nil
}

Expand Down
120 changes: 120 additions & 0 deletions test/chroma_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1123,4 +1123,124 @@ func Test_chroma_client(t *testing.T) {
require.Equal(t, "test-collection", resp.Name, "Collection name should be test-collection")
require.NotNil(t, resp.ID, "Collection id should not be nil")
})

t.Run("Test Get with PageInfo", func(t *testing.T) {
collectionName := "test-collection"
metadata := map[string]interface{}{}
embeddingFunction := types.NewConsistentHashEmbeddingFunction()
_, errRest := client.Reset(context.Background())
require.NoError(t, errRest)
newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, embeddingFunction, types.L2)
require.NoError(t, err)
require.NotNil(t, newCollection)
require.Equal(t, collectionName, newCollection.Name)
require.Equal(t, 2, len(newCollection.Metadata))
// assert the metadata contains key embedding_function
require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"])
docs := make([]string, 0, 100)
ids := make([]string, 0, 100)
for i := 1; i <= 100; i++ {
doc := fmt.Sprintf("Document %d content", i)
docs = append(docs, doc)
ids = append(ids, fmt.Sprintf("ID%d", i))
}
col, addError := newCollection.Add(context.Background(), nil, nil, docs, ids)
require.NoError(t, addError)
results, err := col.GetWithOptions(context.Background(), types.WithOffset(0), types.WithLimit(5))
require.NoError(t, err)
require.Len(t, results.Ids, 5)
newResults, err := results.NextPage(context.Background())
require.NoError(t, err)
require.Len(t, newResults.Ids, 5)
})

t.Run("Test Get with PageInfo overflow", func(t *testing.T) {
collectionName := "test-collection"
metadata := map[string]interface{}{}
embeddingFunction := types.NewConsistentHashEmbeddingFunction()
_, errRest := client.Reset(context.Background())
require.NoError(t, errRest)
newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, embeddingFunction, types.L2)
require.NoError(t, err)
require.NotNil(t, newCollection)
require.Equal(t, collectionName, newCollection.Name)
require.Equal(t, 2, len(newCollection.Metadata))
// assert the metadata contains key embedding_function
require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"])
docs := make([]string, 0, 5)
ids := make([]string, 0, 5)
for i := 1; i <= 5; i++ {
doc := fmt.Sprintf("Document %d content", i)
docs = append(docs, doc)
ids = append(ids, fmt.Sprintf("ID%d", i))
}
col, addError := newCollection.Add(context.Background(), nil, nil, docs, ids)
require.NoError(t, addError)
results, err := col.GetWithOptions(context.Background(), types.WithOffset(0), types.WithLimit(5))
require.NoError(t, err)
require.Len(t, results.Ids, 5)
newResults, err := results.NextPage(context.Background())
require.NoError(t, err)
require.Len(t, newResults.Ids, 0)
})

t.Run("Test Get with PageInfo negative offset", func(t *testing.T) {
collectionName := "test-collection"
metadata := map[string]interface{}{}
embeddingFunction := types.NewConsistentHashEmbeddingFunction()
_, errRest := client.Reset(context.Background())
require.NoError(t, errRest)
newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, embeddingFunction, types.L2)
require.NoError(t, err)
require.NotNil(t, newCollection)
require.Equal(t, collectionName, newCollection.Name)
require.Equal(t, 2, len(newCollection.Metadata))
// assert the metadata contains key embedding_function
require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"])
docs := make([]string, 0, 5)
ids := make([]string, 0, 5)
for i := 1; i <= 5; i++ {
doc := fmt.Sprintf("Document %d content", i)
docs = append(docs, doc)
ids = append(ids, fmt.Sprintf("ID%d", i))
}
col, addError := newCollection.Add(context.Background(), nil, nil, docs, ids)
require.NoError(t, addError)
results, err := col.GetWithOptions(context.Background(), types.WithOffset(0), types.WithLimit(5))
require.NoError(t, err)
require.Len(t, results.Ids, 5)
_, err = results.PreviousPage(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "cannot go to previous page")
})

t.Run("Test Get with PageInfo nil without limit/offset", func(t *testing.T) {
collectionName := "test-collection"
metadata := map[string]interface{}{}
embeddingFunction := types.NewConsistentHashEmbeddingFunction()
_, errRest := client.Reset(context.Background())
require.NoError(t, errRest)
newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, embeddingFunction, types.L2)
require.NoError(t, err)
require.NotNil(t, newCollection)
require.Equal(t, collectionName, newCollection.Name)
require.Equal(t, 2, len(newCollection.Metadata))
// assert the metadata contains key embedding_function
require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"])
docs := make([]string, 0, 5)
ids := make([]string, 0, 5)
for i := 1; i <= 5; i++ {
doc := fmt.Sprintf("Document %d content", i)
docs = append(docs, doc)
ids = append(ids, fmt.Sprintf("ID%d", i))
}
col, addError := newCollection.Add(context.Background(), nil, nil, docs, ids)
require.NoError(t, addError)
results, err := col.GetWithOptions(context.Background())
require.NoError(t, err)
require.Len(t, results.Ids, 5)
_, err = results.PreviousPage(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "no page info")
})
}
6 changes: 6 additions & 0 deletions types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,3 +594,9 @@ func (t *TokenAuthCredentialsProvider) Authenticate(config *openapi.Configuratio
return fmt.Errorf("unsupported token header: %v", t.Header)
}
}

// PageInfo records the paging state of a Get query so that a neighbouring
// page can be requested later.
type PageInfo struct {
	// Limit is the page size taken from the originating query.
	Limit int32
	// Offset is the offset of the originating query.
	Offset int32
	// QueryOptions are the options the originating query was issued with; they
	// are replayed, with an extra offset option, to fetch another page.
	QueryOptions []CollectionQueryOption
}
Loading