Skip to content

Commit

Permalink
refactor: es indexer field mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
N3kox committed Jan 14, 2025
1 parent 3a78abf commit 0092596
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 53 deletions.
76 changes: 56 additions & 20 deletions components/indexer/es8/indexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,41 @@ import (
"encoding/json"
"fmt"

"github.com/elastic/go-elasticsearch/v8"
"github.com/elastic/go-elasticsearch/v8/esutil"

"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/embedding"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
"github.com/elastic/go-elasticsearch/v8"
"github.com/elastic/go-elasticsearch/v8/esutil"
)

type IndexerConfig struct {
ESConfig elasticsearch.Config `json:"es_config"`
Index string `json:"index"`
// BatchSize controls max texts size for embedding
BatchSize int `json:"batch_size"`
// FieldMapping supports customize es fields from eino document, returns:
// needEmbeddingFields will be embedded by Embedding firstly, then join fields with its keys,
// and joined fields will be saved as bulk item.
FieldMapping func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error)
// FieldMapping supports customize es fields from eino document.
// Each key - FieldValue.Value from field2Value will be saved, and
// vector of FieldValue.Value will be saved if FieldValue.EmbedKey is not empty.
DocumentToFields func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error)
// Embedding vectorization method, must provide in two cases
// 1. VectorFields contains fields except doc Content
// 2. VectorFields contains doc Content and vector not provided in doc extra (see Document.Vector method)
Embedding embedding.Embedder
}

type FieldValue struct {
// Value original Value
Value any
// EmbedKey if set, Value will be vectorized and saved to es.
// If Stringify method is provided, Embedding input text will be Stringify(Value).
// If Stringify method not set, retriever will try to assert Value as string.
EmbedKey string
// Stringify converts Value to string
Stringify func(val any) (string, error)
}

type Indexer struct {
client *elasticsearch.Client
config *IndexerConfig
Expand All @@ -58,8 +68,8 @@ func NewIndexer(_ context.Context, conf *IndexerConfig) (*Indexer, error) {
return nil, fmt.Errorf("[NewIndexer] new es client failed, %w", err)
}

if conf.FieldMapping == nil {
return nil, fmt.Errorf("[NewIndexer] field mapping method not provided")
if conf.DocumentToFields == nil {
return nil, fmt.Errorf("[NewIndexer] DocumentToFields method not provided")
}

if conf.BatchSize == 0 {
Expand Down Expand Up @@ -158,34 +168,60 @@ func (i *Indexer) bulkAdd(ctx context.Context, docs []*schema.Document, options

for idx := range docs {
doc := docs[idx]
fields, needEmbeddingFields, err := i.config.FieldMapping(ctx, doc)
fields, err := i.config.DocumentToFields(ctx, doc)
if err != nil {
return fmt.Errorf("[bulkAdd] FieldMapping failed, %w", err)
}
if fields == nil {
fields = make(map[string]any)

rawFields := make(map[string]any)
embSize := 0
for k, v := range fields {
rawFields[k] = v.Value
if v.EmbedKey != "" {
embSize++
}
}

if len(needEmbeddingFields) > i.config.BatchSize {
if embSize > i.config.BatchSize {
return fmt.Errorf("[bulkAdd] needEmbeddingFields length over batch size, batch size=%d, got size=%d",
i.config.BatchSize, len(needEmbeddingFields))
i.config.BatchSize, embSize)
}

if len(texts)+len(needEmbeddingFields) > i.config.BatchSize {
if len(texts)+embSize > i.config.BatchSize {
if err = embAndAdd(); err != nil {
return err
}
}

key2Idx := make(map[string]int, len(needEmbeddingFields))
for k, text := range needEmbeddingFields {
key2Idx[k] = len(texts)
texts = append(texts, text)
key2Idx := make(map[string]int, embSize)
for k, v := range fields {
if v.EmbedKey != "" {
if v.EmbedKey == k {
return fmt.Errorf("[bulkAdd] duplicate key for value and vector, field=%s", k)
}

var text string
if v.Stringify != nil {
text, err = v.Stringify(v.Value)
if err != nil {
return err
}
} else {
var ok bool
text, ok = v.Value.(string)
if !ok {
return fmt.Errorf("[bulkAdd] assert value as string failed, key=%s, emb_key=%s", k, v.EmbedKey)
}
}

key2Idx[v.EmbedKey] = len(texts)
texts = append(texts, text)
}
}

tuples = append(tuples, tuple{
id: doc.ID,
fields: fields,
fields: rawFields,
key2Idx: key2Idx,
})
}
Expand Down
79 changes: 46 additions & 33 deletions components/indexer/es8/indexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func TestBulkAdd(t *testing.T) {
i := &Indexer{
config: &IndexerConfig{
Index: "mock_index",
FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) {
return nil, nil, nil
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) {
return nil, nil
},
},
}
Expand All @@ -65,8 +65,8 @@ func TestBulkAdd(t *testing.T) {
i := &Indexer{
config: &IndexerConfig{
Index: "mock_index",
FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) {
return nil, nil, mockErr
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) {
return nil, mockErr
},
},
}
Expand All @@ -82,9 +82,10 @@ func TestBulkAdd(t *testing.T) {
config: &IndexerConfig{
Index: "mock_index",
BatchSize: 1,
FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) {
return nil, map[string]string{
"k1": "v1", "k2": "v2",
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) {
return map[string]FieldValue{
"k1": {Value: "v1", EmbedKey: "k"},
"k2": {Value: "v2", EmbedKey: "kk"},
}, nil
},
},
Expand All @@ -101,12 +102,15 @@ func TestBulkAdd(t *testing.T) {
config: &IndexerConfig{
Index: "mock_index",
BatchSize: 2,
FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) {
return map[string]any{
"k0": "v0", "k1": "v1", "k3": 123,
}, map[string]string{
"k1": "v1", "k2": "v2",
}, nil
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) {
return map[string]FieldValue{
"k0": {Value: "v0"},
"k1": {Value: "v1", EmbedKey: "vk1"},
"k2": {Value: 222, EmbedKey: "vk2", Stringify: func(val any) (string, error) {
return "222", nil
}},
"k3": {Value: 123},
}, nil
},
},
}
Expand All @@ -123,12 +127,15 @@ func TestBulkAdd(t *testing.T) {
config: &IndexerConfig{
Index: "mock_index",
BatchSize: 2,
FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) {
return map[string]any{
"k0": "v0", "k1": "v1", "k3": 123,
}, map[string]string{
"k1": "v1", "k2": "v2",
}, nil
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) {
return map[string]FieldValue{
"k0": {Value: "v0"},
"k1": {Value: "v1", EmbedKey: "vk1"},
"k2": {Value: 222, EmbedKey: "vk2", Stringify: func(val any) (string, error) {
return "222", nil
}},
"k3": {Value: 123},
}, nil
},
},
}
Expand All @@ -144,12 +151,15 @@ func TestBulkAdd(t *testing.T) {
config: &IndexerConfig{
Index: "mock_index",
BatchSize: 2,
FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) {
return map[string]any{
"k0": "v0", "k1": "v1", "k3": 123,
}, map[string]string{
"k1": "v1", "k2": "v2",
}, nil
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) {
return map[string]FieldValue{
"k0": {Value: "v0"},
"k1": {Value: "v1", EmbedKey: "vk1"},
"k2": {Value: 222, EmbedKey: "vk2", Stringify: func(val any) (string, error) {
return "222", nil
}},
"k3": {Value: 123},
}, nil
},
},
}
Expand All @@ -172,12 +182,13 @@ func TestBulkAdd(t *testing.T) {
config: &IndexerConfig{
Index: "mock_index",
BatchSize: 2,
FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) {
return map[string]any{
"k0": doc.Content, "k1": "v1", "k3": 123,
}, map[string]string{
"k1": "v1", "k2": "v2",
}, nil
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) {
return map[string]FieldValue{
"k0": {Value: doc.Content},
"k1": {Value: "v1", EmbedKey: "vk1"},
"k2": {Value: 222, EmbedKey: "vk2", Stringify: func(val any) (string, error) { return "222", nil }},
"k3": {Value: 123},
}, nil
},
},
}
Expand All @@ -194,9 +205,11 @@ func TestBulkAdd(t *testing.T) {
var mp map[string]interface{}
convey.So(json.Unmarshal(b, &mp), convey.ShouldBeNil)
convey.So(mp["k0"], convey.ShouldEqual, doc.Content)
convey.So(mp["k1"], convey.ShouldEqual, []any{2.1})
convey.So(mp["k2"], convey.ShouldEqual, []any{2.1})
convey.So(mp["k1"], convey.ShouldEqual, "v1")
convey.So(mp["k2"], convey.ShouldEqual, 222)
convey.So(mp["k3"], convey.ShouldEqual, 123)
convey.So(mp["vk1"], convey.ShouldEqual, []any{2.1})
convey.So(mp["vk2"], convey.ShouldEqual, []any{2.1})
}
})
})
Expand Down

0 comments on commit 0092596

Please sign in to comment.