Skip to content

Commit

Permalink
feat: parent retriever/indexer
Browse files Browse the repository at this point in the history
  • Loading branch information
meguminnnnnnnnn committed Dec 26, 2024
1 parent cc39e05 commit 471c382
Show file tree
Hide file tree
Showing 4 changed files with 401 additions and 0 deletions.
113 changes: 113 additions & 0 deletions flow/indexer/parent/parent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parent

import (
"context"
"fmt"

"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
)

type Config struct {
// Indexer specifies the original indexer used to create document index.
Indexer indexer.Indexer
// Transformer specifies the processor before creating document index, typically a splitter.
Transformer document.Transformer
// ParentIDKey specifies the key in the metadata of the sub-documents generated by the transformer to store the parent document ID.
ParentIDKey string

// SubIDGenerator specifies the method for generating a specified number of sub-document IDs based on the parent document ID. Use UUID by default.
SubIDGenerator func(ctx context.Context, parentID string, num int) ([]string, error)
}

func NewIndexer(ctx context.Context, config *Config) (indexer.Indexer, error) {
if config.Indexer == nil {
return nil, fmt.Errorf("indexer is empty")
}
if config.Transformer == nil {
return nil, fmt.Errorf("transformer is empty")
}
if config.SubIDGenerator == nil {
return nil, fmt.Errorf("sub id generator is empty")
}

return &parentIndexer{
indexer: config.Indexer,
transformer: config.Transformer,
parentIDKey: config.ParentIDKey,
subIDGenerator: config.SubIDGenerator,
}, nil
}

type parentIndexer struct {
indexer indexer.Indexer
transformer document.Transformer
parentIDKey string
subIDGenerator func(ctx context.Context, parentID string, num int) ([]string, error)
}

func (p *parentIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) {
subDocs, err := p.transformer.Transform(ctx, docs)
if err != nil {
return nil, fmt.Errorf("transform docs fail: %w", err)
}
if len(subDocs) == 0 {
return nil, fmt.Errorf("doc transformer returned no documents")
}
currentID := subDocs[0].ID
startIdx := 0
for i, subDoc := range subDocs {
if subDoc.MetaData == nil {
subDoc.MetaData = make(map[string]interface{})
}
subDoc.MetaData[p.parentIDKey] = subDoc.ID

if subDoc.ID == currentID {
continue
}

// generate new doc id
subIDs, err := p.subIDGenerator(ctx, subDocs[startIdx].ID, i-startIdx)
if err != nil {
return nil, err
}
if len(subIDs) != i-startIdx {
return nil, fmt.Errorf("generated sub IDs' num is unexpected")
}
for j := startIdx; j < i; j++ {
subDocs[j].ID = subIDs[j-startIdx]
}
startIdx = i
currentID = subDoc.ID
}
// generate new doc id
subIDs, err := p.subIDGenerator(ctx, subDocs[startIdx].ID, len(subDocs)-startIdx)
if err != nil {
return nil, err
}
if len(subIDs) != len(subDocs)-startIdx {
return nil, fmt.Errorf("generated sub IDs' num is unexpected")
}
for j := startIdx; j < len(subDocs); j++ {
subDocs[j].ID = subIDs[j-startIdx]
}

return p.indexer.Store(ctx, subDocs, opts...)
}
121 changes: 121 additions & 0 deletions flow/indexer/parent/parent_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parent

import (
"context"
"fmt"
"reflect"
"strconv"
"strings"
"testing"

"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
)

type testIndexer struct{}

func (t *testIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
ret := make([]string, len(docs))
for i, d := range docs {
ret[i] = d.ID
if !strings.HasPrefix(d.ID, d.MetaData["parent"].(string)) {
return nil, fmt.Errorf("invalid parent key")
}
}
return ret, nil
}

type testTransformer struct {
}

func (t *testTransformer) Transform(ctx context.Context, src []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) {
var ret []*schema.Document
for _, d := range src {
ret = append(ret, &schema.Document{
ID: d.ID,
Content: d.Content[:len(d.Content)/2],
MetaData: deepCopyMap(d.MetaData),
}, &schema.Document{
ID: d.ID,
Content: d.Content[len(d.Content)/2:],
MetaData: deepCopyMap(d.MetaData),
})
}
return ret, nil
}

func TestParentIndexer(t *testing.T) {
tests := []struct {
name string
config *Config
input []*schema.Document
want []string
}{
{
name: "success",
config: &Config{
Indexer: &testIndexer{},
Transformer: &testTransformer{},
ParentIDKey: "parent",
SubIDGenerator: func(ctx context.Context, parentID string, num int) ([]string, error) {
ret := make([]string, num)
for i := range ret {
ret[i] = parentID + strconv.Itoa(i)
}
return ret, nil
},
},
input: []*schema.Document{{
ID: "id",
Content: "1234567890",
MetaData: map[string]interface{}{},
}, {
ID: "ID",
Content: "0987654321",
MetaData: map[string]interface{}{},
}},
want: []string{"id0", "id1", "ID0", "ID1"},
},
}
ctx := context.Background()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
index, err := NewIndexer(ctx, tt.config)
if err != nil {
t.Fatal(err)
}
ret, err := index.Store(ctx, tt.input)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(ret, tt.want) {
t.Errorf("NewHeaderSplitter() got = %v, want %v", ret, tt.want)
}
})
}
}

func deepCopyMap(in map[string]interface{}) map[string]interface{} {
out := make(map[string]interface{})
for k, v := range in {
out[k] = v
}
return out
}
79 changes: 79 additions & 0 deletions flow/retriever/parent/parent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parent

import (
"context"
"fmt"

"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
)

type Config struct {
// Retriever specifies the original retriever used to retrieve documents.
Retriever retriever.Retriever
// ParentIDKey specifies the key used in the sub-document metadata to store the parent document ID. Documents without this key will be removed from the recall results.
ParentIDKey string
// OrigDocGetter specifies the method for getting original documents by ids from the sub-document metadata.
OrigDocGetter func(ctx context.Context, ids []string) ([]*schema.Document, error)
}

func NewRetriever(ctx context.Context, config *Config) (retriever.Retriever, error) {
if config.Retriever == nil {
return nil, fmt.Errorf("retriever is required")
}
if config.OrigDocGetter == nil {
return nil, fmt.Errorf("orig doc getter is required")
}
return &parentRetriever{
retriever: config.Retriever,
parentIDKey: config.ParentIDKey,
origDocGetter: config.OrigDocGetter,
}, nil
}

type parentRetriever struct {
retriever retriever.Retriever
parentIDKey string
origDocGetter func(ctx context.Context, ids []string) ([]*schema.Document, error)
}

func (p *parentRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
subDocs, err := p.retriever.Retrieve(ctx, query, opts...)
if err != nil {
return nil, err
}
ids := make([]string, 0, len(subDocs))
for _, subDoc := range subDocs {
if k, ok := subDoc.MetaData[p.parentIDKey]; ok {
if s, okk := k.(string); okk && !inList(s, ids) {
ids = append(ids, s)
}
}
}
return p.origDocGetter(ctx, ids)
}

func inList(elem string, list []string) bool {
for _, v := range list {
if v == elem {
return true
}
}
return false
}
Loading

0 comments on commit 471c382

Please sign in to comment.