diff --git a/chroma.go b/chroma.go index f06b8bb..006710f 100644 --- a/chroma.go +++ b/chroma.go @@ -385,10 +385,32 @@ func (c *Client) Version(ctx context.Context) (string, error) { } type GetResults struct { - Ids []string - Documents []string - Metadatas []map[string]interface{} - Embeddings []*types.Embedding + Ids []string + Documents []string + Metadatas []map[string]interface{} + Embeddings []*types.Embedding + collection *Collection + PageInfo *types.PageInfo + RecordIterator chan *types.Record +} + +func (r *GetResults) NextPage(ctx context.Context) (*GetResults, error) { + if r.PageInfo == nil { + return nil, fmt.Errorf("no page info. Your Get must contain limit option") + } + r.PageInfo.QueryOptions = append(r.PageInfo.QueryOptions, types.WithOffset(r.PageInfo.Offset+r.PageInfo.Limit)) + return r.collection.GetWithOptions(ctx, r.PageInfo.QueryOptions...) +} + +func (r *GetResults) PreviousPage(ctx context.Context) (*GetResults, error) { + if r.PageInfo == nil { + return nil, fmt.Errorf("no page info. Your Get must contain limit options") + } + if r.PageInfo.Offset-r.PageInfo.Limit < 0 { + return nil, fmt.Errorf("cannot go to previous page. Offset is less than 0") + } + r.PageInfo.QueryOptions = append(r.PageInfo.QueryOptions, types.WithOffset(r.PageInfo.Offset-r.PageInfo.Limit)) + return r.collection.GetWithOptions(ctx, r.PageInfo.QueryOptions...) } type Collection struct { @@ -562,7 +584,35 @@ func (c *Collection) GetWithOptions(ctx context.Context, options ...types.Collec Documents: cd.Documents, Metadatas: cd.Metadatas, Embeddings: APIEmbeddingsToEmbeddings(cd.Embeddings), + collection: c, + } + // only add PageInfo when both limit (offset is assumed 0 if not provided) + // TODO can limit == collection count (that is an extra API call though) + if query.Limit != 0 { + results.PageInfo = &types.PageInfo{ + Limit: query.Limit, + Offset: query.Offset, + QueryOptions: options, + } } + + results.RecordIterator = make(chan *types.Record) + go func() { + count, err := c.Count(ctx) + if err != nil { + return + } + start := 0 + end := count + lim := 10 + if query.Limit != 0 { + lim = query.Limit + } + for i := start; i < end; i += lim { + c.GetWithOptions(ctx, types.WithLimit(10), types.WithOffset(i)) + } + close(ch) + }() return results, nil } diff --git a/test/chroma_client_test.go b/test/chroma_client_test.go index 036db30..dca262d 100644 --- a/test/chroma_client_test.go +++ b/test/chroma_client_test.go @@ -1123,4 +1123,124 @@ func Test_chroma_client(t *testing.T) { require.Equal(t, "test-collection", resp.Name, "Collection name should be test-collection") require.NotNil(t, resp.ID, "Collection id should not be nil") }) + + t.Run("Test Get with PageInfo", func(t *testing.T) { + collectionName := "test-collection" + metadata := map[string]interface{}{} + embeddingFunction := types.NewConsistentHashEmbeddingFunction() + _, errRest := client.Reset(context.Background()) + require.NoError(t, errRest) + newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, embeddingFunction, types.L2) + require.NoError(t, err) + require.NotNil(t, newCollection) + require.Equal(t, collectionName, newCollection.Name) + require.Equal(t, 2, len(newCollection.Metadata)) + // assert the metadata contains key embedding_function + require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"]) + docs := make([]string, 0, 100) + ids := make([]string, 0, 100) + for i := 1; i <= 100; i++ { + doc := fmt.Sprintf("Document %d content", i) + docs = append(docs, doc) + ids = append(ids, fmt.Sprintf("ID%d", i)) + } + col, addError := newCollection.Add(context.Background(), nil, nil, docs, ids) + require.NoError(t, addError) + results, err := col.GetWithOptions(context.Background(), types.WithOffset(0), types.WithLimit(5)) + require.NoError(t, err) + require.Len(t, results.Ids, 5) + newResults, err := results.NextPage(context.Background()) + require.NoError(t, err) + require.Len(t, newResults.Ids, 5) + }) + + t.Run("Test Get with PageInfo overflow", func(t *testing.T) { + collectionName := "test-collection" + metadata := map[string]interface{}{} + embeddingFunction := types.NewConsistentHashEmbeddingFunction() + _, errRest := client.Reset(context.Background()) + require.NoError(t, errRest) + newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, embeddingFunction, types.L2) + require.NoError(t, err) + require.NotNil(t, newCollection) + require.Equal(t, collectionName, newCollection.Name) + require.Equal(t, 2, len(newCollection.Metadata)) + // assert the metadata contains key embedding_function + require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"]) + docs := make([]string, 0, 5) + ids := make([]string, 0, 5) + for i := 1; i <= 5; i++ { + doc := fmt.Sprintf("Document %d content", i) + docs = append(docs, doc) + ids = append(ids, fmt.Sprintf("ID%d", i)) + } + col, addError := newCollection.Add(context.Background(), nil, nil, docs, ids) + require.NoError(t, addError) + results, err := col.GetWithOptions(context.Background(), types.WithOffset(0), types.WithLimit(5)) + require.NoError(t, err) + require.Len(t, results.Ids, 5) + newResults, err := results.NextPage(context.Background()) + require.NoError(t, err) + require.Len(t, newResults.Ids, 0) + }) + + t.Run("Test Get with PageInfo negative offset", func(t *testing.T) { + collectionName := "test-collection" + metadata := map[string]interface{}{} + embeddingFunction := types.NewConsistentHashEmbeddingFunction() + _, errRest := client.Reset(context.Background()) + require.NoError(t, errRest) + newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, embeddingFunction, types.L2) + require.NoError(t, err) + require.NotNil(t, newCollection) + require.Equal(t, collectionName, newCollection.Name) + require.Equal(t, 2, len(newCollection.Metadata)) + // assert the metadata contains key embedding_function + require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"]) + docs := make([]string, 0, 5) + ids := make([]string, 0, 5) + for i := 1; i <= 5; i++ { + doc := fmt.Sprintf("Document %d content", i) + docs = append(docs, doc) + ids = append(ids, fmt.Sprintf("ID%d", i)) + } + col, addError := newCollection.Add(context.Background(), nil, nil, docs, ids) + require.NoError(t, addError) + results, err := col.GetWithOptions(context.Background(), types.WithOffset(0), types.WithLimit(5)) + require.NoError(t, err) + require.Len(t, results.Ids, 5) + _, err = results.PreviousPage(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot go to previous page") + }) + + t.Run("Test Get with PageInfo nil without limit/offset", func(t *testing.T) { + collectionName := "test-collection" + metadata := map[string]interface{}{} + embeddingFunction := types.NewConsistentHashEmbeddingFunction() + _, errRest := client.Reset(context.Background()) + require.NoError(t, errRest) + newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, embeddingFunction, types.L2) + require.NoError(t, err) + require.NotNil(t, newCollection) + require.Equal(t, collectionName, newCollection.Name) + require.Equal(t, 2, len(newCollection.Metadata)) + // assert the metadata contains key embedding_function + require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"]) + docs := make([]string, 0, 5) + ids := make([]string, 0, 5) + for i := 1; i <= 5; i++ { + doc := fmt.Sprintf("Document %d content", i) + docs = append(docs, doc) + ids = append(ids, fmt.Sprintf("ID%d", i)) + } + col, addError := newCollection.Add(context.Background(), nil, nil, docs, ids) + require.NoError(t, addError) + results, err := col.GetWithOptions(context.Background()) + require.NoError(t, err) + require.Len(t, results.Ids, 5) + _, err = results.PreviousPage(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "no page info") + }) } diff --git a/types/types.go b/types/types.go index a0727bb..c507fcc 100644 --- a/types/types.go +++ b/types/types.go @@ -594,3 +594,9 @@ func (t *TokenAuthCredentialsProvider) Authenticate(config *openapi.Configuratio return fmt.Errorf("unsupported token header: %v", t.Header) } } + +type PageInfo struct { + Limit int32 + Offset int32 + QueryOptions []CollectionQueryOption +}