From 18056b66e0e0d3a834a3bd4af9a7f3835c2512c2 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Wed, 26 Jun 2024 14:11:44 +0200 Subject: [PATCH 1/5] feat: Move GetResults to types --- chroma.go | 13 +++---------- types/types.go | 7 +++++++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/chroma.go b/chroma.go index f06b8bb..787ed71 100644 --- a/chroma.go +++ b/chroma.go @@ -384,13 +384,6 @@ func (c *Client) Version(ctx context.Context) (string, error) { return version, err } -type GetResults struct { - Ids []string - Documents []string - Metadatas []map[string]interface{} - Embeddings []*types.Embedding -} - type Collection struct { Name string EmbeddingFunction types.EmbeddingFunction @@ -525,7 +518,7 @@ func (c *Collection) Modify(ctx context.Context, embeddings []*types.Embedding, return c, nil } -func (c *Collection) GetWithOptions(ctx context.Context, options ...types.CollectionQueryOption) (*GetResults, error) { +func (c *Collection) GetWithOptions(ctx context.Context, options ...types.CollectionQueryOption) (*types.GetResults, error) { query := &types.CollectionQueryBuilder{} for _, opt := range options { err := opt(query) @@ -557,7 +550,7 @@ func (c *Collection) GetWithOptions(ctx context.Context, options ...types.Collec return nil, err } - results := &GetResults{ + results := &types.GetResults{ Ids: cd.Ids, Documents: cd.Documents, Metadatas: cd.Metadatas, @@ -566,7 +559,7 @@ func (c *Collection) GetWithOptions(ctx context.Context, options ...types.Collec return results, nil } -func (c *Collection) Get(ctx context.Context, where map[string]interface{}, whereDocuments map[string]interface{}, ids []string, include []types.QueryEnum) (*GetResults, error) { +func (c *Collection) Get(ctx context.Context, where map[string]interface{}, whereDocuments map[string]interface{}, ids []string, include []types.QueryEnum) (*types.GetResults, error) { return c.GetWithOptions(ctx, types.WithWhereMap(where), types.WithWhereDocumentMap(whereDocuments), 
types.WithIds(ids), types.WithInclude(include...)) } diff --git a/types/types.go b/types/types.go index a0727bb..3571bfe 100644 --- a/types/types.go +++ b/types/types.go @@ -594,3 +594,10 @@ func (t *TokenAuthCredentialsProvider) Authenticate(config *openapi.Configuratio return fmt.Errorf("unsupported token header: %v", t.Header) } } + +type GetResults struct { + Ids []string + Documents []string + Metadatas []map[string]interface{} + Embeddings []*Embedding +} From 01b3ff5a9e2d7d3fec8acf422b9a28410b7b9750 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Wed, 26 Jun 2024 15:53:52 +0200 Subject: [PATCH 2/5] feat: Page info --- chroma.go | 43 ++++++++++++++++-- test/chroma_client_test.go | 90 ++++++++++++++++++++++++++++++++++++++ types/types.go | 9 ++-- 3 files changed, 134 insertions(+), 8 deletions(-) diff --git a/chroma.go b/chroma.go index 787ed71..5277491 100644 --- a/chroma.go +++ b/chroma.go @@ -384,6 +384,34 @@ func (c *Client) Version(ctx context.Context) (string, error) { return version, err } +type GetResults struct { + Ids []string + Documents []string + Metadatas []map[string]interface{} + Embeddings []*types.Embedding + collection *Collection + PageInfo *types.PageInfo +} + +func (r *GetResults) NextPage(ctx context.Context) (*GetResults, error) { + if r.PageInfo == nil { + return nil, fmt.Errorf("no page info. Your Get must contain limit and offset options") + } + r.PageInfo.QueryOptions = append(r.PageInfo.QueryOptions, types.WithOffset(r.PageInfo.Offset+r.PageInfo.Limit)) + return r.collection.GetWithOptions(ctx, r.PageInfo.QueryOptions...) +} + +func (r *GetResults) PreviousPage(ctx context.Context) (*GetResults, error) { + if r.PageInfo == nil { + return nil, fmt.Errorf("no page info. Your Get must contain limit and offset options") + } + if r.PageInfo.Offset-r.PageInfo.Limit < 0 { + return nil, fmt.Errorf("cannot go to previous page. 
Offset is less than 0") + } + r.PageInfo.QueryOptions = append(r.PageInfo.QueryOptions, types.WithOffset(r.PageInfo.Offset-r.PageInfo.Limit)) + return r.collection.GetWithOptions(ctx, r.PageInfo.QueryOptions...) +} + type Collection struct { Name string EmbeddingFunction types.EmbeddingFunction @@ -518,7 +546,7 @@ func (c *Collection) Modify(ctx context.Context, embeddings []*types.Embedding, return c, nil } -func (c *Collection) GetWithOptions(ctx context.Context, options ...types.CollectionQueryOption) (*types.GetResults, error) { +func (c *Collection) GetWithOptions(ctx context.Context, options ...types.CollectionQueryOption) (*GetResults, error) { query := &types.CollectionQueryBuilder{} for _, opt := range options { err := opt(query) @@ -550,16 +578,25 @@ func (c *Collection) GetWithOptions(ctx context.Context, options ...types.Collec return nil, err } - results := &types.GetResults{ + results := &GetResults{ Ids: cd.Ids, Documents: cd.Documents, Metadatas: cd.Metadatas, Embeddings: APIEmbeddingsToEmbeddings(cd.Embeddings), + collection: c, + } + // only add PageInfo when both limit and offset are set + if &query.Limit != nil && &query.Offset != nil { + results.PageInfo = &types.PageInfo{ + Limit: query.Limit, + Offset: query.Offset, + QueryOptions: options, + } } return results, nil } -func (c *Collection) Get(ctx context.Context, where map[string]interface{}, whereDocuments map[string]interface{}, ids []string, include []types.QueryEnum) (*types.GetResults, error) { +func (c *Collection) Get(ctx context.Context, where map[string]interface{}, whereDocuments map[string]interface{}, ids []string, include []types.QueryEnum) (*GetResults, error) { return c.GetWithOptions(ctx, types.WithWhereMap(where), types.WithWhereDocumentMap(whereDocuments), types.WithIds(ids), types.WithInclude(include...)) } diff --git a/test/chroma_client_test.go b/test/chroma_client_test.go index 036db30..634f777 100644 --- a/test/chroma_client_test.go +++ b/test/chroma_client_test.go @@ 
-1123,4 +1123,94 @@ func Test_chroma_client(t *testing.T) { require.Equal(t, "test-collection", resp.Name, "Collection name should be test-collection") require.NotNil(t, resp.ID, "Collection id should not be nil") }) + + t.Run("Test Get with PageInfo", func(t *testing.T) { + collectionName := "test-collection" + metadata := map[string]interface{}{} + embeddingFunction := types.NewConsistentHashEmbeddingFunction() + _, errRest := client.Reset(context.Background()) + require.NoError(t, errRest) + newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, embeddingFunction, types.L2) + require.NoError(t, err) + require.NotNil(t, newCollection) + require.Equal(t, collectionName, newCollection.Name) + require.Equal(t, 2, len(newCollection.Metadata)) + // assert the metadata contains key embedding_function + require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"]) + docs := make([]string, 0, 100) + ids := make([]string, 0, 100) + for i := 1; i <= 100; i++ { + doc := fmt.Sprintf("Document %d content", i) + docs = append(docs, doc) + ids = append(ids, fmt.Sprintf("ID%d", i)) + } + col, addError := newCollection.Add(context.Background(), nil, nil, docs, ids) + require.NoError(t, addError) + results, err := col.GetWithOptions(context.Background(), types.WithOffset(0), types.WithLimit(5)) + require.NoError(t, err) + require.Len(t, results.Ids, 5) + newResults, err := results.NextPage(context.Background()) + require.NoError(t, err) + require.Len(t, newResults.Ids, 5) + }) + + t.Run("Test Get with PageInfo overflow", func(t *testing.T) { + collectionName := "test-collection" + metadata := map[string]interface{}{} + embeddingFunction := types.NewConsistentHashEmbeddingFunction() + _, errRest := client.Reset(context.Background()) + require.NoError(t, errRest) + newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, 
embeddingFunction, types.L2) + require.NoError(t, err) + require.NotNil(t, newCollection) + require.Equal(t, collectionName, newCollection.Name) + require.Equal(t, 2, len(newCollection.Metadata)) + // assert the metadata contains key embedding_function + require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"]) + docs := make([]string, 0, 5) + ids := make([]string, 0, 5) + for i := 1; i <= 5; i++ { + doc := fmt.Sprintf("Document %d content", i) + docs = append(docs, doc) + ids = append(ids, fmt.Sprintf("ID%d", i)) + } + col, addError := newCollection.Add(context.Background(), nil, nil, docs, ids) + require.NoError(t, addError) + results, err := col.GetWithOptions(context.Background(), types.WithOffset(0), types.WithLimit(5)) + require.NoError(t, err) + require.Len(t, results.Ids, 5) + newResults, err := results.NextPage(context.Background()) + require.NoError(t, err) + require.Len(t, newResults.Ids, 0) + }) + + t.Run("Test Get with PageInfo negative offset", func(t *testing.T) { + collectionName := "test-collection" + metadata := map[string]interface{}{} + embeddingFunction := types.NewConsistentHashEmbeddingFunction() + _, errRest := client.Reset(context.Background()) + require.NoError(t, errRest) + newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, embeddingFunction, types.L2) + require.NoError(t, err) + require.NotNil(t, newCollection) + require.Equal(t, collectionName, newCollection.Name) + require.Equal(t, 2, len(newCollection.Metadata)) + // assert the metadata contains key embedding_function + require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"]) + docs := make([]string, 0, 5) + ids := make([]string, 0, 5) + for i := 1; i <= 5; i++ { + doc := fmt.Sprintf("Document %d content", i) + docs = append(docs, doc) + ids = append(ids, fmt.Sprintf("ID%d", i)) + } + col, addError := 
newCollection.Add(context.Background(), nil, nil, docs, ids) + require.NoError(t, addError) + results, err := col.GetWithOptions(context.Background(), types.WithOffset(0), types.WithLimit(5)) + require.NoError(t, err) + require.Len(t, results.Ids, 5) + _, err = results.PreviousPage(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot go to previous page") + }) } diff --git a/types/types.go b/types/types.go index 3571bfe..c507fcc 100644 --- a/types/types.go +++ b/types/types.go @@ -595,9 +595,8 @@ func (t *TokenAuthCredentialsProvider) Authenticate(config *openapi.Configuratio } } -type GetResults struct { - Ids []string - Documents []string - Metadatas []map[string]interface{} - Embeddings []*Embedding +type PageInfo struct { + Limit int32 + Offset int32 + QueryOptions []CollectionQueryOption } From 21fa07046d43a6bc4831f0f41178ce0224ef6fda Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Thu, 27 Jun 2024 09:55:14 +0200 Subject: [PATCH 3/5] fix: Fixed a bug + boundary test --- chroma.go | 2 +- test/chroma_client_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/chroma.go b/chroma.go index 5277491..d82309c 100644 --- a/chroma.go +++ b/chroma.go @@ -586,7 +586,7 @@ func (c *Collection) GetWithOptions(ctx context.Context, options ...types.Collec collection: c, } // only add PageInfo when both limit and offset are set - if &query.Limit != nil && &query.Offset != nil { + if query.Limit != 0 && query.Offset != 0 { results.PageInfo = &types.PageInfo{ Limit: query.Limit, Offset: query.Offset, diff --git a/test/chroma_client_test.go b/test/chroma_client_test.go index 634f777..dca262d 100644 --- a/test/chroma_client_test.go +++ b/test/chroma_client_test.go @@ -1213,4 +1213,34 @@ func Test_chroma_client(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "cannot go to previous page") }) + + t.Run("Test Get with PageInfo nil without limit/offset", func(t *testing.T) 
{ + collectionName := "test-collection" + metadata := map[string]interface{}{} + embeddingFunction := types.NewConsistentHashEmbeddingFunction() + _, errRest := client.Reset(context.Background()) + require.NoError(t, errRest) + newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, embeddingFunction, types.L2) + require.NoError(t, err) + require.NotNil(t, newCollection) + require.Equal(t, collectionName, newCollection.Name) + require.Equal(t, 2, len(newCollection.Metadata)) + // assert the metadata contains key embedding_function + require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"]) + docs := make([]string, 0, 5) + ids := make([]string, 0, 5) + for i := 1; i <= 5; i++ { + doc := fmt.Sprintf("Document %d content", i) + docs = append(docs, doc) + ids = append(ids, fmt.Sprintf("ID%d", i)) + } + col, addError := newCollection.Add(context.Background(), nil, nil, docs, ids) + require.NoError(t, addError) + results, err := col.GetWithOptions(context.Background()) + require.NoError(t, err) + require.Len(t, results.Ids, 5) + _, err = results.PreviousPage(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "no page info") + }) } From b63359847fb47a628d2b09d8c5edb175ab90da86 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Thu, 27 Jun 2024 10:08:11 +0200 Subject: [PATCH 4/5] fix: Fixed boundary condition - only limit is required for pageInfo --- chroma.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/chroma.go b/chroma.go index d82309c..c019fda 100644 --- a/chroma.go +++ b/chroma.go @@ -395,7 +395,7 @@ type GetResults struct { func (r *GetResults) NextPage(ctx context.Context) (*GetResults, error) { if r.PageInfo == nil { - return nil, fmt.Errorf("no page info. Your Get must contain limit and offset options") + return nil, fmt.Errorf("no page info. 
Your Get must contain limit option") } r.PageInfo.QueryOptions = append(r.PageInfo.QueryOptions, types.WithOffset(r.PageInfo.Offset+r.PageInfo.Limit)) return r.collection.GetWithOptions(ctx, r.PageInfo.QueryOptions...) @@ -403,7 +403,7 @@ func (r *GetResults) NextPage(ctx context.Context) (*GetResults, error) { func (r *GetResults) PreviousPage(ctx context.Context) (*GetResults, error) { if r.PageInfo == nil { - return nil, fmt.Errorf("no page info. Your Get must contain limit and offset options") + return nil, fmt.Errorf("no page info. Your Get must contain limit options") } if r.PageInfo.Offset-r.PageInfo.Limit < 0 { return nil, fmt.Errorf("cannot go to previous page. Offset is less than 0") @@ -585,8 +585,9 @@ func (c *Collection) GetWithOptions(ctx context.Context, options ...types.Collec Embeddings: APIEmbeddingsToEmbeddings(cd.Embeddings), collection: c, } - // only add PageInfo when both limit and offset are set - if query.Limit != 0 && query.Offset != 0 { + // only add PageInfo when limit is set (offset is assumed 0 if not provided) + // TODO can limit == collection count (that is an extra API call though) + if query.Limit != 0 { results.PageInfo = &types.PageInfo{ Limit: query.Limit, Offset: query.Offset, From 51d43b4a5c52c9285afb6410d74d9f12ae940317 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Wed, 10 Jul 2024 08:45:00 +0300 Subject: [PATCH 5/5] chore: Wip on result iterator --- chroma.go | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/chroma.go b/chroma.go index c019fda..006710f 100644 --- a/chroma.go +++ b/chroma.go @@ -385,12 +385,13 @@ func (c *Client) Version(ctx context.Context) (string, error) { } type GetResults struct { - Ids []string - Documents []string - Metadatas []map[string]interface{} - Embeddings []*types.Embedding - collection *Collection - PageInfo *types.PageInfo + Ids []string + Documents []string + Metadatas []map[string]interface{} + Embeddings []*types.Embedding + collection 
*Collection + PageInfo *types.PageInfo + RecordIterator chan *types.Record } func (r *GetResults) NextPage(ctx context.Context) (*GetResults, error) { @@ -594,6 +595,24 @@ func (c *Collection) GetWithOptions(ctx context.Context, options ...types.Collec QueryOptions: options, } } + + results.RecordIterator = make(chan *types.Record) + go func() { + count, err := c.Count(ctx) + if err != nil { + return + } + start := 0 + end := count + lim := 10 + if query.Limit != 0 { + lim = query.Limit + } + for i := start; i < end; i += lim { + c.GetWithOptions(ctx, types.WithLimit(10), types.WithOffset(i)) + } + close(ch) + }() return results, nil }