From 471c382d1cdf3e86ef834b90b319ee967f0fb3c3 Mon Sep 17 00:00:00 2001
From: Megumin
Date: Mon, 23 Dec 2024 18:05:59 +0800
Subject: [PATCH] feat: parent retriever/indexer

---
 flow/indexer/parent/parent.go        | 121 +++++++++++++++++++++++++
 flow/indexer/parent/parent_test.go   | 133 +++++++++++++++++++++++++++
 flow/retriever/parent/parent.go      |  92 +++++++++++++++++++
 flow/retriever/parent/parent_test.go |  95 +++++++++++++++++++
 4 files changed, 441 insertions(+)
 create mode 100644 flow/indexer/parent/parent.go
 create mode 100644 flow/indexer/parent/parent_test.go
 create mode 100644 flow/retriever/parent/parent.go
 create mode 100644 flow/retriever/parent/parent_test.go

diff --git a/flow/indexer/parent/parent.go b/flow/indexer/parent/parent.go
new file mode 100644
index 0000000..d50da06
--- /dev/null
+++ b/flow/indexer/parent/parent.go
@@ -0,0 +1,121 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Package parent provides an indexer that stores sub-documents together with the ID of their parent document.
+package parent
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/cloudwego/eino/components/document"
+	"github.com/cloudwego/eino/components/indexer"
+	"github.com/cloudwego/eino/schema"
+)
+
+// Config holds the components and settings used by NewIndexer.
+type Config struct {
+	// Indexer specifies the original indexer used to create document index.
+	Indexer indexer.Indexer
+	// Transformer specifies the processor before creating document index, typically a splitter.
+	Transformer document.Transformer
+	// ParentIDKey specifies the key in the metadata of the sub-documents generated by the transformer to store the parent document ID.
+	ParentIDKey string
+
+	// SubIDGenerator specifies the method for generating a specified number of sub-document IDs based on the parent document ID.
+	// It is required: NewIndexer returns an error when it is nil.
+	SubIDGenerator func(ctx context.Context, parentID string, num int) ([]string, error)
+}
+
+// NewIndexer creates an indexer that splits documents with config.Transformer, records each
+// sub-document's parent ID under config.ParentIDKey, assigns new sub-document IDs with
+// config.SubIDGenerator and stores the result through config.Indexer.
+// It returns an error if config, Indexer, Transformer or SubIDGenerator is nil.
+func NewIndexer(ctx context.Context, config *Config) (indexer.Indexer, error) {
+	if config == nil {
+		return nil, fmt.Errorf("config is empty")
+	}
+	if config.Indexer == nil {
+		return nil, fmt.Errorf("indexer is empty")
+	}
+	if config.Transformer == nil {
+		return nil, fmt.Errorf("transformer is empty")
+	}
+	if config.SubIDGenerator == nil {
+		return nil, fmt.Errorf("sub id generator is empty")
+	}
+
+	return &parentIndexer{
+		indexer:        config.Indexer,
+		transformer:    config.Transformer,
+		parentIDKey:    config.ParentIDKey,
+		subIDGenerator: config.SubIDGenerator,
+	}, nil
+}
+
+// parentIndexer is the indexer.Indexer implementation returned by NewIndexer.
+type parentIndexer struct {
+	indexer        indexer.Indexer
+	transformer    document.Transformer
+	parentIDKey    string
+	subIDGenerator func(ctx context.Context, parentID string, num int) ([]string, error)
+}
+
+// Store transforms docs into sub-documents, writes each sub-document's parent ID into its
+// metadata under parentIDKey, replaces the sub-document IDs with generated ones and stores
+// the sub-documents with the wrapped indexer, returning that indexer's result.
+func (p *parentIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) {
+	subDocs, err := p.transformer.Transform(ctx, docs)
+	if err != nil {
+		return nil, fmt.Errorf("transform docs fail: %w", err)
+	}
+	if len(subDocs) == 0 {
+		return nil, fmt.Errorf("doc transformer returned no documents")
+	}
+	for i, subDoc := range subDocs {
+		if subDoc == nil {
+			return nil, fmt.Errorf("doc transformer returned a nil document at index %d", i)
+		}
+	}
+
+	// Every sub-document still carries its parent's ID at this point, so a run of consecutive
+	// sub-documents sharing one ID is treated as the output of a single parent document.
+	for start := 0; start < len(subDocs); {
+		parentID := subDocs[start].ID
+		end := start
+		for end < len(subDocs) && subDocs[end].ID == parentID {
+			if subDocs[end].MetaData == nil {
+				subDocs[end].MetaData = make(map[string]interface{})
+			}
+			subDocs[end].MetaData[p.parentIDKey] = parentID
+			end++
+		}
+
+		subIDs, genErr := p.subIDGenerator(ctx, parentID, end-start)
+		if genErr != nil {
+			return nil, fmt.Errorf("generate sub IDs fail: %w", genErr)
+		}
+		if len(subIDs) != end-start {
+			return nil, fmt.Errorf("generated sub IDs' num is unexpected")
+		}
+		for j, id := range subIDs {
+			subDocs[start+j].ID = id
+		}
+		start = end
+	}
+
+	return p.indexer.Store(ctx, subDocs, opts...)
+}
diff --git a/flow/indexer/parent/parent_test.go b/flow/indexer/parent/parent_test.go
new file mode 100644
index 0000000..82b3621
--- /dev/null
+++ b/flow/indexer/parent/parent_test.go
@@ -0,0 +1,133 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package parent
+
+import (
+	"context"
+	"fmt"
+	"reflect"
+	"strconv"
+	"strings"
+	"testing"
+
+	"github.com/cloudwego/eino/components/document"
+	"github.com/cloudwego/eino/components/indexer"
+	"github.com/cloudwego/eino/schema"
+)
+
+// testIndexer returns the IDs of the documents it is given and fails when a
+// document ID does not start with the parent ID stored under the "parent" key.
+type testIndexer struct{}
+
+func (t *testIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
+	ret := make([]string, len(docs))
+	for i, d := range docs {
+		ret[i] = d.ID
+		// The checked assertion reports a missing or non-string key as an error instead of panicking.
+		parentID, ok := d.MetaData["parent"].(string)
+		if !ok || !strings.HasPrefix(d.ID, parentID) {
+			return nil, fmt.Errorf("invalid parent key")
+		}
+	}
+	return ret, nil
+}
+
+// testTransformer splits every document into two halves that keep the source ID.
+type testTransformer struct {
+}
+
+func (t *testTransformer) Transform(ctx context.Context, src []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) {
+	var ret []*schema.Document
+	for _, d := range src {
+		ret = append(ret, &schema.Document{
+			ID:       d.ID,
+			Content:  d.Content[:len(d.Content)/2],
+			MetaData: deepCopyMap(d.MetaData),
+		}, &schema.Document{
+			ID:       d.ID,
+			Content:  d.Content[len(d.Content)/2:],
+			MetaData: deepCopyMap(d.MetaData),
+		})
+	}
+	return ret, nil
+}
+
+func TestParentIndexer(t *testing.T) {
+	tests := []struct {
+		name   string
+		config *Config
+		input  []*schema.Document
+		want   []string
+	}{
+		{
+			name: "success",
+			config: &Config{
+				Indexer:     &testIndexer{},
+				Transformer: &testTransformer{},
+				ParentIDKey: "parent",
+				SubIDGenerator: func(ctx context.Context, parentID string, num int) ([]string, error) {
+					ret := make([]string, num)
+					for i := range ret {
+						ret[i] = parentID + strconv.Itoa(i)
+					}
+					return ret, nil
+				},
+			},
+			input: []*schema.Document{{
+				ID:       "id",
+				Content:  "1234567890",
+				MetaData: map[string]interface{}{},
+			}, {
+				ID:       "ID",
+				Content:  "0987654321",
+				MetaData: map[string]interface{}{},
+			}},
+			want: []string{"id0", "id1", "ID0", "ID1"},
+		},
+	}
+	ctx := context.Background()
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			index, err := NewIndexer(ctx, tt.config)
+			if err != nil {
+				t.Fatal(err)
+			}
+			ret, err := index.Store(ctx, tt.input)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !reflect.DeepEqual(ret, tt.want) {
+				t.Errorf("Store() got = %v, want %v", ret, tt.want)
+			}
+		})
+	}
+}
+
+func TestNewIndexerNilConfig(t *testing.T) {
+	if _, err := NewIndexer(context.Background(), nil); err == nil {
+		t.Error("NewIndexer() with nil config: got nil error, want non-nil")
+	}
+}
+
+// deepCopyMap returns a new map holding the same top-level entries as in; the values themselves are not cloned.
+func deepCopyMap(in map[string]interface{}) map[string]interface{} {
+	out := make(map[string]interface{})
+	for k, v := range in {
+		out[k] = v
+	}
+	return out
+}
diff --git a/flow/retriever/parent/parent.go b/flow/retriever/parent/parent.go
new file mode 100644
index 0000000..34f2e86
--- /dev/null
+++ b/flow/retriever/parent/parent.go
@@ -0,0 +1,92 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Package parent provides a retriever that maps retrieved sub-documents back to their parent documents.
+package parent
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/cloudwego/eino/components/retriever"
+	"github.com/cloudwego/eino/schema"
+)
+
+// Config holds the components and settings used by NewRetriever.
+type Config struct {
+	// Retriever specifies the original retriever used to retrieve documents.
+	Retriever retriever.Retriever
+	// ParentIDKey specifies the key used in the sub-document metadata to store the parent document ID. Documents without this key will be removed from the recall results.
+	ParentIDKey string
+	// OrigDocGetter specifies the method for getting original documents by ids from the sub-document metadata.
+	OrigDocGetter func(ctx context.Context, ids []string) ([]*schema.Document, error)
+}
+
+// NewRetriever creates a retriever that retrieves sub-documents with config.Retriever, collects
+// the distinct parent IDs stored under config.ParentIDKey and returns the documents that
+// config.OrigDocGetter yields for them.
+// It returns an error if config, Retriever or OrigDocGetter is nil.
+func NewRetriever(ctx context.Context, config *Config) (retriever.Retriever, error) {
+	if config == nil {
+		return nil, fmt.Errorf("config is required")
+	}
+	if config.Retriever == nil {
+		return nil, fmt.Errorf("retriever is required")
+	}
+	if config.OrigDocGetter == nil {
+		return nil, fmt.Errorf("orig doc getter is required")
+	}
+	return &parentRetriever{
+		retriever:     config.Retriever,
+		parentIDKey:   config.ParentIDKey,
+		origDocGetter: config.OrigDocGetter,
+	}, nil
+}
+
+// parentRetriever is the retriever.Retriever implementation returned by NewRetriever.
+type parentRetriever struct {
+	retriever     retriever.Retriever
+	parentIDKey   string
+	origDocGetter func(ctx context.Context, ids []string) ([]*schema.Document, error)
+}
+
+// Retrieve queries the wrapped retriever, gathers the parent IDs of the returned sub-documents in
+// first-seen order without duplicates, and passes them to origDocGetter. Sub-documents whose
+// metadata has no string value under parentIDKey are skipped.
+func (p *parentRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
+	subDocs, err := p.retriever.Retrieve(ctx, query, opts...)
+	if err != nil {
+		return nil, err
+	}
+	ids := make([]string, 0, len(subDocs))
+	// seen makes the duplicate check constant time per sub-document.
+	seen := make(map[string]struct{}, len(subDocs))
+	for _, subDoc := range subDocs {
+		if subDoc == nil {
+			continue
+		}
+		id, ok := subDoc.MetaData[p.parentIDKey].(string)
+		if !ok {
+			continue
+		}
+		if _, dup := seen[id]; dup {
+			continue
+		}
+		seen[id] = struct{}{}
+		ids = append(ids, id)
+	}
+	return p.origDocGetter(ctx, ids)
+}
diff --git a/flow/retriever/parent/parent_test.go b/flow/retriever/parent/parent_test.go
new file mode 100644
index 0000000..6487e25
--- /dev/null
+++ b/flow/retriever/parent/parent_test.go
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package parent
+
+import (
+	"context"
+	"reflect"
+	"testing"
+
+	"github.com/cloudwego/eino/components/retriever"
+	"github.com/cloudwego/eino/schema"
+)
+
+// testRetriever returns one sub-document per character of an ASCII query, storing that character under the "parent" metadata key.
+type testRetriever struct{}
+
+func (t *testRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
+	ret := make([]*schema.Document, 0)
+	for i := range query {
+		ret = append(ret, &schema.Document{
+			ID:      "",
+			Content: "",
+			MetaData: map[string]interface{}{
+				"parent": query[i : i+1],
+			},
+		})
+	}
+	return ret, nil
+}
+
+func TestParentRetriever(t *testing.T) {
+	tests := []struct {
+		name   string
+		config *Config
+		input  string
+		want   []*schema.Document
+	}{
+		{
+			name: "success",
+			config: &Config{
+				Retriever:   &testRetriever{},
+				ParentIDKey: "parent",
+				OrigDocGetter: func(ctx context.Context, ids []string) ([]*schema.Document, error) {
+					var ret []*schema.Document
+					for i := range ids {
+						ret = append(ret, &schema.Document{ID: ids[i]})
+					}
+					return ret, nil
+				},
+			},
+			input: "123233",
+			want: []*schema.Document{
+				{ID: "1"},
+				{ID: "2"},
+				{ID: "3"},
+			},
+		},
+	}
+	ctx := context.Background()
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			r, err := NewRetriever(ctx, tt.config)
+			if err != nil {
+				t.Fatal(err)
+			}
+			ret, err := r.Retrieve(ctx, tt.input)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !reflect.DeepEqual(ret, tt.want) {
+				t.Errorf("got %v, want %v", ret, tt.want)
+			}
+		})
+	}
+}
+
+func TestNewRetrieverNilConfig(t *testing.T) {
+	if _, err := NewRetriever(context.Background(), nil); err == nil {
+		t.Error("NewRetriever() with nil config: got nil error, want non-nil")
+	}
+}