generated from cloudwego/.github
-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cc39e05
commit 471c382
Showing
4 changed files
with
401 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
/* | ||
* Copyright 2024 CloudWeGo Authors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package parent | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
|
||
"github.com/cloudwego/eino/components/document" | ||
"github.com/cloudwego/eino/components/indexer" | ||
"github.com/cloudwego/eino/schema" | ||
) | ||
|
||
// Config holds the dependencies and settings consumed by NewIndexer.
type Config struct {
	// Indexer specifies the original indexer used to create document index.
	// Required: NewIndexer returns an error when it is nil.
	Indexer indexer.Indexer
	// Transformer specifies the processor before creating document index, typically a splitter.
	// Required: NewIndexer returns an error when it is nil.
	Transformer document.Transformer
	// ParentIDKey specifies the key in the metadata of the sub-documents generated by the transformer to store the parent document ID.
	ParentIDKey string

	// SubIDGenerator specifies the method for generating a specified number of sub-document IDs based on the parent document ID.
	// Required: NewIndexer returns an error when it is nil; no default (such as a UUID generator) is installed.
	SubIDGenerator func(ctx context.Context, parentID string, num int) ([]string, error)
}
|
||
func NewIndexer(ctx context.Context, config *Config) (indexer.Indexer, error) { | ||
if config.Indexer == nil { | ||
return nil, fmt.Errorf("indexer is empty") | ||
} | ||
if config.Transformer == nil { | ||
return nil, fmt.Errorf("transformer is empty") | ||
} | ||
if config.SubIDGenerator == nil { | ||
return nil, fmt.Errorf("sub id generator is empty") | ||
} | ||
|
||
return &parentIndexer{ | ||
indexer: config.Indexer, | ||
transformer: config.Transformer, | ||
parentIDKey: config.ParentIDKey, | ||
subIDGenerator: config.SubIDGenerator, | ||
}, nil | ||
} | ||
|
||
// parentIndexer is the indexer.Indexer implementation returned by NewIndexer.
// Its fields are copied one-to-one from Config.
type parentIndexer struct {
	// indexer receives the transformed sub-documents in Store.
	indexer indexer.Indexer
	// transformer turns the input documents into sub-documents.
	transformer document.Transformer
	// parentIDKey is the metadata key under which the parent document ID is stored.
	parentIDKey string
	// subIDGenerator produces num IDs for the sub-documents of parentID.
	subIDGenerator func(ctx context.Context, parentID string, num int) ([]string, error)
}
|
||
func (p *parentIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) { | ||
subDocs, err := p.transformer.Transform(ctx, docs) | ||
if err != nil { | ||
return nil, fmt.Errorf("transform docs fail: %w", err) | ||
} | ||
if len(subDocs) == 0 { | ||
return nil, fmt.Errorf("doc transformer returned no documents") | ||
} | ||
currentID := subDocs[0].ID | ||
startIdx := 0 | ||
for i, subDoc := range subDocs { | ||
if subDoc.MetaData == nil { | ||
subDoc.MetaData = make(map[string]interface{}) | ||
} | ||
subDoc.MetaData[p.parentIDKey] = subDoc.ID | ||
|
||
if subDoc.ID == currentID { | ||
continue | ||
} | ||
|
||
// generate new doc id | ||
subIDs, err := p.subIDGenerator(ctx, subDocs[startIdx].ID, i-startIdx) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if len(subIDs) != i-startIdx { | ||
return nil, fmt.Errorf("generated sub IDs' num is unexpected") | ||
} | ||
for j := startIdx; j < i; j++ { | ||
subDocs[j].ID = subIDs[j-startIdx] | ||
} | ||
startIdx = i | ||
currentID = subDoc.ID | ||
} | ||
// generate new doc id | ||
subIDs, err := p.subIDGenerator(ctx, subDocs[startIdx].ID, len(subDocs)-startIdx) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if len(subIDs) != len(subDocs)-startIdx { | ||
return nil, fmt.Errorf("generated sub IDs' num is unexpected") | ||
} | ||
for j := startIdx; j < len(subDocs); j++ { | ||
subDocs[j].ID = subIDs[j-startIdx] | ||
} | ||
|
||
return p.indexer.Store(ctx, subDocs, opts...) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
/* | ||
* Copyright 2024 CloudWeGo Authors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package parent | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"reflect" | ||
"strconv" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/cloudwego/eino/components/document" | ||
"github.com/cloudwego/eino/components/indexer" | ||
"github.com/cloudwego/eino/schema" | ||
) | ||
|
||
// testIndexer is a stateless test double whose Store method checks the
// "parent" metadata of each document and echoes the document IDs back.
type testIndexer struct{}
|
||
func (t *testIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) { | ||
ret := make([]string, len(docs)) | ||
for i, d := range docs { | ||
ret[i] = d.ID | ||
if !strings.HasPrefix(d.ID, d.MetaData["parent"].(string)) { | ||
return nil, fmt.Errorf("invalid parent key") | ||
} | ||
} | ||
return ret, nil | ||
} | ||
|
||
// testTransformer is a stateless test double whose Transform method splits
// every input document into two halves.
type testTransformer struct {
}
|
||
func (t *testTransformer) Transform(ctx context.Context, src []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) { | ||
var ret []*schema.Document | ||
for _, d := range src { | ||
ret = append(ret, &schema.Document{ | ||
ID: d.ID, | ||
Content: d.Content[:len(d.Content)/2], | ||
MetaData: deepCopyMap(d.MetaData), | ||
}, &schema.Document{ | ||
ID: d.ID, | ||
Content: d.Content[len(d.Content)/2:], | ||
MetaData: deepCopyMap(d.MetaData), | ||
}) | ||
} | ||
return ret, nil | ||
} | ||
|
||
func TestParentIndexer(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
config *Config | ||
input []*schema.Document | ||
want []string | ||
}{ | ||
{ | ||
name: "success", | ||
config: &Config{ | ||
Indexer: &testIndexer{}, | ||
Transformer: &testTransformer{}, | ||
ParentIDKey: "parent", | ||
SubIDGenerator: func(ctx context.Context, parentID string, num int) ([]string, error) { | ||
ret := make([]string, num) | ||
for i := range ret { | ||
ret[i] = parentID + strconv.Itoa(i) | ||
} | ||
return ret, nil | ||
}, | ||
}, | ||
input: []*schema.Document{{ | ||
ID: "id", | ||
Content: "1234567890", | ||
MetaData: map[string]interface{}{}, | ||
}, { | ||
ID: "ID", | ||
Content: "0987654321", | ||
MetaData: map[string]interface{}{}, | ||
}}, | ||
want: []string{"id0", "id1", "ID0", "ID1"}, | ||
}, | ||
} | ||
ctx := context.Background() | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
index, err := NewIndexer(ctx, tt.config) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
ret, err := index.Store(ctx, tt.input) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if !reflect.DeepEqual(ret, tt.want) { | ||
t.Errorf("NewHeaderSplitter() got = %v, want %v", ret, tt.want) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
// deepCopyMap returns a new map holding the same key/value pairs as in.
// Despite the name, the copy is shallow: values are assigned as-is, so nested
// maps or slices remain shared with in. A nil input yields an empty, non-nil
// map.
func deepCopyMap(in map[string]interface{}) map[string]interface{} {
	dup := make(map[string]interface{}, len(in))
	for key := range in {
		dup[key] = in[key]
	}
	return dup
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/* | ||
* Copyright 2024 CloudWeGo Authors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package parent | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
|
||
"github.com/cloudwego/eino/components/retriever" | ||
"github.com/cloudwego/eino/schema" | ||
) | ||
|
||
// Config holds the dependencies and settings consumed by NewRetriever.
type Config struct {
	// Retriever specifies the original retriever used to retrieve documents.
	// Required: NewRetriever returns an error when it is nil.
	Retriever retriever.Retriever
	// ParentIDKey specifies the key used in the sub-document metadata to store the parent document ID. Documents without this key will be removed from the recall results.
	ParentIDKey string
	// OrigDocGetter specifies the method for getting original documents by ids from the sub-document metadata.
	// Required: NewRetriever returns an error when it is nil.
	OrigDocGetter func(ctx context.Context, ids []string) ([]*schema.Document, error)
}
|
||
func NewRetriever(ctx context.Context, config *Config) (retriever.Retriever, error) { | ||
if config.Retriever == nil { | ||
return nil, fmt.Errorf("retriever is required") | ||
} | ||
if config.OrigDocGetter == nil { | ||
return nil, fmt.Errorf("orig doc getter is required") | ||
} | ||
return &parentRetriever{ | ||
retriever: config.Retriever, | ||
parentIDKey: config.ParentIDKey, | ||
origDocGetter: config.OrigDocGetter, | ||
}, nil | ||
} | ||
|
||
// parentRetriever is the retriever.Retriever implementation returned by
// NewRetriever. Its fields are copied one-to-one from Config.
type parentRetriever struct {
	// retriever performs the underlying sub-document retrieval.
	retriever retriever.Retriever
	// parentIDKey is the metadata key from which the parent document ID is read.
	parentIDKey string
	// origDocGetter resolves the collected parent IDs to documents.
	origDocGetter func(ctx context.Context, ids []string) ([]*schema.Document, error)
}
|
||
func (p *parentRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { | ||
subDocs, err := p.retriever.Retrieve(ctx, query, opts...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
ids := make([]string, 0, len(subDocs)) | ||
for _, subDoc := range subDocs { | ||
if k, ok := subDoc.MetaData[p.parentIDKey]; ok { | ||
if s, okk := k.(string); okk && !inList(s, ids) { | ||
ids = append(ids, s) | ||
} | ||
} | ||
} | ||
return p.origDocGetter(ctx, ids) | ||
} | ||
|
||
// inList reports whether elem equals at least one entry of list.
// A nil or empty list always yields false.
func inList(elem string, list []string) bool {
	for i := 0; i < len(list); i++ {
		if list[i] == elem {
			return true
		}
	}
	return false
}
Oops, something went wrong.