Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add es8 indexer&retriever #41

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
23 changes: 23 additions & 0 deletions components/indexer/es8/consts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package es8
BytePender marked this conversation as resolved.
Show resolved Hide resolved

// typ is the component type name reported by this package's indexer.
const typ = "ElasticSearch8"

const (
// defaultBatchSize is the embedding batch size used when the configured
// batch size is left at zero.
defaultBatchSize = 5
)
52 changes: 52 additions & 0 deletions components/indexer/es8/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
module github.com/cloudwego/eino-ext/components/indexer/es8

go 1.22

require (
github.com/bytedance/mockey v1.2.13
github.com/cloudwego/eino v0.3.6
github.com/elastic/go-elasticsearch/v8 v8.16.0
github.com/smartystreets/goconvey v1.8.1
)

require (
github.com/bytedance/sonic v1.12.2 // indirect
github.com/bytedance/sonic/loader v0.2.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/elastic/elastic-transport-go/v8 v8.6.0 // indirect
github.com/getkin/kin-openapi v0.118.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/swag v0.19.5 // indirect
github.com/goph/emperror v0.17.2 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/invopop/yaml v0.1.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/jtolds/gls v4.20.0+incompatible // indirect
github.com/klauspost/cpuid/v2 v2.0.9 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
github.com/nikolalohinski/gonja v1.5.3 // indirect
github.com/pelletier/go-toml/v2 v2.0.9 // indirect
github.com/perimeterx/marshmallow v1.1.4 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect
github.com/smarty/assertions v1.15.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/yargevad/filepathx v1.0.0 // indirect
go.opentelemetry.io/otel v1.28.0 // indirect
go.opentelemetry.io/otel/metric v1.28.0 // indirect
go.opentelemetry.io/otel/trace v1.28.0 // indirect
golang.org/x/arch v0.11.0 // indirect
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect
golang.org/x/sys v0.26.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
179 changes: 179 additions & 0 deletions components/indexer/es8/go.sum

Large diffs are not rendered by default.

264 changes: 264 additions & 0 deletions components/indexer/es8/indexer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package es8
BytePender marked this conversation as resolved.
Show resolved Hide resolved

import (
"bytes"
"context"
"encoding/json"
"fmt"

"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/embedding"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
"github.com/elastic/go-elasticsearch/v8"
"github.com/elastic/go-elasticsearch/v8/esutil"
)

// IndexerConfig holds the settings used to build an Indexer.
type IndexerConfig struct {
// ESConfig is passed to elasticsearch.NewClient to create the client.
ESConfig elasticsearch.Config `json:"es_config"`
// Index is the Elasticsearch index that documents are written to.
Index string `json:"index"`
// BatchSize is the maximum number of texts sent to the embedder in a single
// EmbedStrings call. When left at zero, NewIndexer replaces it with the
// package default. A single document may not produce more texts to embed
// than BatchSize.
BatchSize int `json:"batch_size"`
// DocumentToFields customizes the es fields produced from an eino document.
// For each key in field2Value, FieldValue.Value is saved under that key, and
// the vector of FieldValue.Value is additionally saved under
// FieldValue.EmbedKey if FieldValue.EmbedKey is not empty.
// It is required: NewIndexer returns an error when it is nil.
DocumentToFields func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error)
// Embedding is the vectorization method used as the base option of Store.
// Storing fails when a field sets FieldValue.EmbedKey and no embedder is
// available (presumably one can also be supplied through Store options;
// verify against indexer.GetCommonOptions).
Embedding embedding.Embedder
}

// FieldValue describes one es field produced from a document by
// IndexerConfig.DocumentToFields.
type FieldValue struct {
// Value is the original value, saved under the field's key.
Value any
// EmbedKey, if set, makes the indexer vectorize Value and save the vector
// to es under this key. It must differ from the field's own key.
// If Stringify is provided, the embedding input text is Stringify(Value).
// If Stringify is not set, the indexer asserts Value as a string and
// fails when the assertion does not hold.
EmbedKey string
// Stringify converts Value to the text that is embedded.
Stringify func(val any) (string, error)
}

// Indexer stores eino documents in Elasticsearch through the bulk API.
// Create it with NewIndexer.
type Indexer struct {
// client is the Elasticsearch client built from config.ESConfig.
client *elasticsearch.Client
// config is the configuration given to NewIndexer.
config *IndexerConfig
}

// NewIndexer validates conf and builds an Indexer backed by a new
// Elasticsearch client created from conf.ESConfig.
//
// conf must be non-nil and conf.DocumentToFields must be set; otherwise an
// error is returned before any client is created. A non-positive
// conf.BatchSize is replaced, in place, with defaultBatchSize. The returned
// Indexer keeps a reference to conf rather than a copy.
func NewIndexer(_ context.Context, conf *IndexerConfig) (*Indexer, error) {
	if conf == nil {
		return nil, fmt.Errorf("[NewIndexer] config not provided")
	}

	if conf.DocumentToFields == nil {
		return nil, fmt.Errorf("[NewIndexer] DocumentToFields method not provided")
	}

	client, err := elasticsearch.NewClient(conf.ESConfig)
	if err != nil {
		return nil, fmt.Errorf("[NewIndexer] new es client failed, %w", err)
	}

	// A negative batch size would make every document fail the per-document
	// size check in bulkAdd, so treat it like "unset".
	if conf.BatchSize <= 0 {
		conf.BatchSize = defaultBatchSize
	}

	return &Indexer{
		client: client,
		config: conf,
	}, nil
}

// Store writes docs to Elasticsearch and returns their IDs, collected from
// each document's ID field through the package helper iter.
//
// The callbacks OnStart and OnEnd are invoked around the write, and OnError is
// invoked with the returned error when the call fails. The configured
// embedder is used as the base option that opts are applied on top of.
func (i *Indexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
	// Report a failure exactly once, on the way out; ctx is read at exit so
	// the handlers attached by OnStart are the ones notified.
	defer func() {
		if err == nil {
			return
		}
		callbacks.OnError(ctx, err)
	}()

	ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})

	base := &indexer.Options{Embedding: i.config.Embedding}
	options := indexer.GetCommonOptions(base, opts...)

	err = i.bulkAdd(ctx, docs, options)
	if err != nil {
		return nil, err
	}

	ids = iter(docs, func(d *schema.Document) string {
		return d.ID
	})

	callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids})

	return ids, nil
}

// bulkAdd converts docs into Elasticsearch documents and writes them to
// i.config.Index through an esutil.BulkIndexer.
//
// For every document, i.config.DocumentToFields yields the fields to store.
// Each field's Value is stored under its key; fields with a non-empty EmbedKey
// are additionally embedded with options.Embedding and the resulting vector is
// stored under EmbedKey. Texts are collected across documents and embedded in
// batches of at most i.config.BatchSize texts; the texts of one document are
// never split across two embedding calls.
//
// It returns an error when field mapping, key validation, embedding,
// marshalling or enqueueing fails, when closing the bulk indexer fails, or
// when the bulk indexer reports failed items. The bulk indexer is closed on
// every return path.
func (i *Indexer) bulkAdd(ctx context.Context, docs []*schema.Document, options *indexer.Options) error {
	emb := options.Embedding
	bi, err := esutil.NewBulkIndexer(esutil.BulkIndexerConfig{
		Index:  i.config.Index,
		Client: i.client,
	})
	if err != nil {
		return err
	}

	// The bulk indexer owns background workers that are only released by
	// Close, so make sure error paths close it as well.
	closed := false
	defer func() {
		if !closed {
			// The original error is the one worth reporting; a close failure
			// on an already failing path adds nothing for the caller.
			_ = bi.Close(ctx)
		}
	}()

	var (
		tuples []tuple
		texts  []string
	)

	// embAndAdd embeds the pending texts, attaches the vectors to the pending
	// tuples, enqueues the tuples and resets both buffers.
	embAndAdd := func() error {
		var vectors [][]float64

		if len(texts) > 0 {
			if emb == nil {
				return fmt.Errorf("[bulkAdd] embedding method not provided")
			}

			var embErr error
			vectors, embErr = emb.EmbedStrings(i.makeEmbeddingCtx(ctx, emb), texts)
			if embErr != nil {
				return fmt.Errorf("[bulkAdd] embedding failed, %w", embErr)
			}

			if len(vectors) != len(texts) {
				return fmt.Errorf("[bulkAdd] invalid vector length, expected=%d, got=%d", len(texts), len(vectors))
			}
		}

		for _, t := range tuples {
			fields := t.fields
			for k, idx := range t.key2Idx {
				fields[k] = vectors[idx]
			}

			b, err := json.Marshal(fields)
			if err != nil {
				return fmt.Errorf("[bulkAdd] marshal bulk item failed, %w", err)
			}

			if err = bi.Add(ctx, esutil.BulkIndexerItem{
				Index:      i.config.Index,
				Action:     "index",
				DocumentID: t.id,
				Body:       bytes.NewReader(b),
			}); err != nil {
				return err
			}
		}

		// Keep the backing arrays for the next batch.
		tuples = tuples[:0]
		texts = texts[:0]

		return nil
	}

	for idx := range docs {
		doc := docs[idx]
		fields, err := i.config.DocumentToFields(ctx, doc)
		if err != nil {
			return fmt.Errorf("[bulkAdd] FieldMapping failed, %w", err)
		}

		rawFields := make(map[string]any, len(fields))
		embSize := 0
		for k, v := range fields {
			rawFields[k] = v.Value
			if v.EmbedKey != "" {
				embSize++
			}
		}

		// A document's texts must fit into one embedding call.
		if embSize > i.config.BatchSize {
			return fmt.Errorf("[bulkAdd] needEmbeddingFields length over batch size, batch size=%d, got size=%d",
				i.config.BatchSize, embSize)
		}

		// Flush first when this document would overflow the current batch.
		if len(texts)+embSize > i.config.BatchSize {
			if err = embAndAdd(); err != nil {
				return err
			}
		}

		key2Idx := make(map[string]int, embSize)
		for k, v := range fields {
			if v.EmbedKey == "" {
				continue
			}

			// The vector must not land on any raw field (its own or another
			// one), otherwise it would silently replace that field's value.
			if _, found := fields[v.EmbedKey]; found {
				return fmt.Errorf("[bulkAdd] duplicate key for value and vector, field=%s", v.EmbedKey)
			}

			// Two fields sharing one vector key would overwrite each other
			// in map-iteration order.
			if _, found := key2Idx[v.EmbedKey]; found {
				return fmt.Errorf("[bulkAdd] duplicate vector key, key=%s, emb_key=%s", k, v.EmbedKey)
			}

			var text string
			if v.Stringify != nil {
				text, err = v.Stringify(v.Value)
				if err != nil {
					return err
				}
			} else {
				var ok bool
				text, ok = v.Value.(string)
				if !ok {
					return fmt.Errorf("[bulkAdd] assert value as string failed, key=%s, emb_key=%s", k, v.EmbedKey)
				}
			}

			// Remember where this field's vector will be in the batch result.
			key2Idx[v.EmbedKey] = len(texts)
			texts = append(texts, text)
		}

		tuples = append(tuples, tuple{
			id:      doc.ID,
			fields:  rawFields,
			key2Idx: key2Idx,
		})
	}

	if len(tuples) > 0 {
		if err = embAndAdd(); err != nil {
			return err
		}
	}

	closed = true
	if err = bi.Close(ctx); err != nil {
		return err
	}

	// Items are sent asynchronously and no per-item failure handler is
	// registered, so the statistics are the only place rejected items show up.
	if stats := bi.Stats(); stats.NumFailed > 0 {
		return fmt.Errorf("[bulkAdd] bulk index failed, failed=%d, added=%d", stats.NumFailed, stats.NumAdded)
	}

	return nil
}

// makeEmbeddingCtx returns a context for calling emb that reuses the callback
// handlers of ctx under a RunInfo describing the embedding component.
//
// The RunInfo type is the one reported by components.GetType for emb, or empty
// when none is reported; its name is that type followed by the component kind.
func (i *Indexer) makeEmbeddingCtx(ctx context.Context, emb embedding.Embedder) context.Context {
	typeName, ok := components.GetType(emb)
	if !ok {
		typeName = ""
	}

	return callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{
		Name:      typeName + string(components.ComponentOfEmbedding),
		Type:      typeName,
		Component: components.ComponentOfEmbedding,
	})
}

// GetType returns the component type name of this indexer, which is the
// package constant typ.
func (i *Indexer) GetType() string {
	return typ
}

// IsCallbacksEnabled always reports true: Store invokes the OnStart, OnEnd and
// OnError callbacks itself.
func (i *Indexer) IsCallbacksEnabled() bool {
	return true
}

// tuple is one document waiting in bulkAdd for its vectors before it is
// handed to the bulk indexer.
type tuple struct {
// id is the document ID, used as the es document ID.
id string
// fields holds the raw field values; vectors are added to it before it is
// marshalled as the bulk item body.
fields map[string]any
// key2Idx maps each vector field key (FieldValue.EmbedKey) to the position
// of its text in the pending embedding batch, and thus of its vector in
// the embedding result.
key2Idx map[string]int
}
Loading
Loading