diff --git a/components/indexer/es8/field_mapping/consts.go b/components/indexer/es8/field_mapping/consts.go deleted file mode 100644 index abe46ff..0000000 --- a/components/indexer/es8/field_mapping/consts.go +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package field_mapping - -const DocFieldNameContent = "eino_doc_content" diff --git a/components/indexer/es8/field_mapping/field_mapping.go b/components/indexer/es8/field_mapping/field_mapping.go deleted file mode 100644 index 259e7bd..0000000 --- a/components/indexer/es8/field_mapping/field_mapping.go +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package field_mapping - -import ( - "fmt" - - "github.com/cloudwego/eino/schema" - - "github.com/cloudwego/eino-ext/components/indexer/es8/internal" -) - -// SetExtraDataFields set data fields for es -func SetExtraDataFields(doc *schema.Document, fields map[string]interface{}) { - if doc == nil { - return - } - - if doc.MetaData == nil { - doc.MetaData = make(map[string]any) - } - - doc.MetaData[internal.DocExtraKeyEsFields] = fields -} - -// GetExtraDataFields get data fields from *schema.Document -func GetExtraDataFields(doc *schema.Document) (fields map[string]interface{}, ok bool) { - if doc == nil || doc.MetaData == nil { - return nil, false - } - - fields, ok = doc.MetaData[internal.DocExtraKeyEsFields].(map[string]interface{}) - - return fields, ok -} - -// DefaultFieldKV build default names by fieldName -// docFieldName should be DocFieldNameContent or key got from GetExtraDataFields -func DefaultFieldKV(docFieldName FieldName) FieldKV { - return FieldKV{ - FieldNameVector: FieldName(fmt.Sprintf("vector_%s", docFieldName)), - FieldName: docFieldName, - } -} - -type FieldKV struct { - // FieldNameVector vector field name (if needed) - FieldNameVector FieldName `json:"field_name_vector,omitempty"` - // FieldName field name - FieldName FieldName `json:"field_name,omitempty"` -} - -type FieldName string - -func (v FieldName) Find(doc *schema.Document) (string, bool) { - if v == DocFieldNameContent { - return doc.Content, true - } - - kvs, ok := GetExtraDataFields(doc) - if !ok { - return "", false - } - - s, ok := kvs[string(v)].(string) - return s, ok -} diff --git a/components/indexer/es8/go.mod b/components/indexer/es8/go.mod index a009e88..a9547e7 100644 --- a/components/indexer/es8/go.mod +++ b/components/indexer/es8/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bytedance/mockey v1.2.13 - github.com/cloudwego/eino v0.3.5 + github.com/cloudwego/eino v0.3.6 github.com/elastic/go-elasticsearch/v8 v8.16.0 github.com/smartystreets/goconvey v1.8.1 
) diff --git a/components/indexer/es8/go.sum b/components/indexer/es8/go.sum index e701f60..889b85b 100644 --- a/components/indexer/es8/go.sum +++ b/components/indexer/es8/go.sum @@ -13,8 +13,8 @@ github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4 github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= -github.com/cloudwego/eino v0.3.5 h1:9PkAOX/phFifrGXkfl4L9rdecxOQJBJY1FtZqF4bz3c= -github.com/cloudwego/eino v0.3.5/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo= +github.com/cloudwego/eino v0.3.6 h1:3yfdKKxMVWefdOyGXHuqUMM5cc9iioijj2mpPsDZKIg= +github.com/cloudwego/eino v0.3.6/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/components/indexer/es8/indexer.go b/components/indexer/es8/indexer.go index a826bfe..ddfebb0 100644 --- a/components/indexer/es8/indexer.go +++ b/components/indexer/es8/indexer.go @@ -30,8 +30,6 @@ import ( "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" - - "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" ) type IndexerConfig struct { @@ -39,8 +37,10 @@ type IndexerConfig struct { Index string `json:"index"` BatchSize int `json:"batch_size"` - // VectorFields dense_vector field mappings - VectorFields []field_mapping.FieldKV `json:"vector_fields"` + // FieldMapping customizes the es fields built from an eino document. It returns fields and needEmbeddingFields: + // texts in needEmbeddingFields are embedded by Embedding first, then the resulting vectors are
merged into fields under the same keys, + // and the merged fields are saved as the bulk item. + FieldMapping func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) // Embedding vectorization method, must provide in two cases // 1. VectorFields contains fields except doc Content // 2. VectorFields contains doc Content and vector not provided in doc extra (see Document.Vector method) @@ -58,13 +58,8 @@ func NewIndexer(_ context.Context, conf *IndexerConfig) (*Indexer, error) { return nil, fmt.Errorf("[NewIndexer] new es client failed, %w", err) } - if conf.Embedding == nil { - for _, kv := range conf.VectorFields { - if kv.FieldName != field_mapping.DocFieldNameContent { - return nil, fmt.Errorf("[NewIndexer] Embedding not provided in config, but field kv[%s]-[%s] requires", - kv.FieldNameVector, kv.FieldName) - } - } + if conf.FieldMapping == nil { + return nil, fmt.Errorf("[NewIndexer] field mapping method not provided") } if conf.BatchSize == 0 { @@ -99,13 +94,7 @@ func (i *Indexer) Store(ctx context.Context, docs []*schema.Document, opts ...in } for _, slice := range chunk(docs, i.config.BatchSize) { - var items []esutil.BulkIndexerItem - - if len(i.config.VectorFields) == 0 { - items, err = i.defaultQueryItems(ctx, slice, options) - } else { - items, err = i.vectorQueryItems(ctx, slice, options) - } + items, err := i.makeBulkItems(ctx, slice, options) if err != nil { return nil, err } @@ -128,73 +117,42 @@ func (i *Indexer) Store(ctx context.Context, docs []*schema.Document, opts ...in return ids, nil } -func (i *Indexer) defaultQueryItems(_ context.Context, docs []*schema.Document, _ *indexer.Options) (items []esutil.BulkIndexerItem, err error) { - items, err = iterWithErr(docs, func(doc *schema.Document) (item esutil.BulkIndexerItem, err error) { - b, err := json.Marshal(toESDoc(doc)) - if err != nil { - return item, err - } - - return esutil.BulkIndexerItem{ - Index: i.config.Index, - Action: "index", - DocumentID: 
doc.ID, - Body: bytes.NewReader(b), - }, nil - }) - - if err != nil { - return nil, err - } - - return items, nil -} - -func (i *Indexer) vectorQueryItems(ctx context.Context, docs []*schema.Document, options *indexer.Options) (items []esutil.BulkIndexerItem, err error) { +func (i *Indexer) makeBulkItems(ctx context.Context, docs []*schema.Document, options *indexer.Options) (items []esutil.BulkIndexerItem, err error) { emb := options.Embedding items, err = iterWithErr(docs, func(doc *schema.Document) (item esutil.BulkIndexerItem, err error) { - mp := toESDoc(doc) - texts := make([]string, 0, len(i.config.VectorFields)) - for _, kv := range i.config.VectorFields { - str, ok := kv.FieldName.Find(doc) - if !ok { - return item, fmt.Errorf("[vectorQueryItems] field name not found or type incorrect, name=%s, doc=%v", kv.FieldName, doc) - } - - if kv.FieldName == field_mapping.DocFieldNameContent && len(doc.DenseVector()) > 0 { - mp[string(kv.FieldNameVector)] = doc.DenseVector() - } else { - texts = append(texts, str) - } + fields, needEmbeddingFields, err := i.config.FieldMapping(ctx, doc) + if err != nil { + return item, fmt.Errorf("[makeBulkItems] FieldMapping failed, %w", err) } - if len(texts) > 0 { + if len(needEmbeddingFields) > 0 { if emb == nil { - return item, fmt.Errorf("[vectorQueryItems] embedding not provided") + return item, fmt.Errorf("[makeBulkItems] embedding method not provided") + } + + tuples := make([]tuple[string, int], 0, len(fields)) + texts := make([]string, 0, len(fields)) + for k, text := range needEmbeddingFields { + tuples = append(tuples, tuple[string, int]{k, len(texts)}) + texts = append(texts, text) } vectors, err := emb.EmbedStrings(i.makeEmbeddingCtx(ctx, emb), texts) if err != nil { - return item, fmt.Errorf("[vectorQueryItems] embedding failed, %w", err) + return item, fmt.Errorf("[makeBulkItems] embedding failed, %w", err) } if len(vectors) != len(texts) { - return item, fmt.Errorf("[vectorQueryItems] invalid vector length, 
expected=%d, got=%d", len(texts), len(vectors)) + return item, fmt.Errorf("[makeBulkItems] invalid vector length, expected=%d, got=%d", len(texts), len(vectors)) } - vIdx := 0 - for _, kv := range i.config.VectorFields { - if kv.FieldName == field_mapping.DocFieldNameContent && len(doc.DenseVector()) > 0 { - continue - } - - mp[string(kv.FieldNameVector)] = vectors[vIdx] - vIdx++ + for _, t := range tuples { + fields[t.A] = vectors[t.B] } } - b, err := json.Marshal(mp) + b, err := json.Marshal(fields) if err != nil { return item, err } diff --git a/components/indexer/es8/indexer_test.go b/components/indexer/es8/indexer_test.go index 7f1356e..a1535a4 100644 --- a/components/indexer/es8/indexer_test.go +++ b/components/indexer/es8/indexer_test.go @@ -28,99 +28,85 @@ import ( "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" - - "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" ) func TestVectorQueryItems(t *testing.T) { - PatchConvey("test vectorQueryItems", t, func() { + PatchConvey("test makeBulkItems", t, func() { ctx := context.Background() extField := "extra_field" - d1 := &schema.Document{ID: "123", Content: "asd"} - d1.WithDenseVector([]float64{2.3, 4.4}) - field_mapping.SetExtraDataFields(d1, map[string]interface{}{extField: "ext_1"}) - - d2 := &schema.Document{ID: "456", Content: "qwe"} - field_mapping.SetExtraDataFields(d2, map[string]interface{}{extField: "ext_2"}) - + d1 := &schema.Document{ID: "123", Content: "asd", MetaData: map[string]any{extField: "ext_1"}} + d2 := &schema.Document{ID: "456", Content: "qwe", MetaData: map[string]any{extField: "ext_2"}} docs := []*schema.Document{d1, d2} - PatchConvey("test field not found", func() { + PatchConvey("test FieldMapping error", func() { + mockErr := fmt.Errorf("test err") i := &Indexer{ config: &IndexerConfig{ Index: "mock_index", - VectorFields: []field_mapping.FieldKV{ - 
field_mapping.DefaultFieldKV("not_found_field"), + FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { + return nil, nil, mockErr }, }, } - bulks, err := i.vectorQueryItems(ctx, docs, &indexer.Options{ + bulks, err := i.makeBulkItems(ctx, docs, &indexer.Options{ Embedding: &mockEmbedding{size: []int{1}, mockVector: []float64{2.1}}, }) - convey.So(err, convey.ShouldBeError, fmt.Sprintf("[vectorQueryItems] field name not found or type incorrect, name=not_found_field, doc=%v", d1)) + convey.So(err, convey.ShouldBeError, fmt.Errorf("[makeBulkItems] FieldMapping failed, %w", mockErr)) convey.So(len(bulks), convey.ShouldEqual, 0) }) PatchConvey("test emb not provided", func() { i := &Indexer{ config: &IndexerConfig{ - Index: "mock_index", - VectorFields: []field_mapping.FieldKV{ - field_mapping.DefaultFieldKV(field_mapping.DocFieldNameContent), - field_mapping.DefaultFieldKV(field_mapping.FieldName(extField)), - }, + Index: "mock_index", + FieldMapping: defaultFieldMapping, }, } - bulks, err := i.vectorQueryItems(ctx, docs, &indexer.Options{Embedding: nil}) - convey.So(err, convey.ShouldBeError, "[vectorQueryItems] embedding not provided") + bulks, err := i.makeBulkItems(ctx, docs, &indexer.Options{Embedding: nil}) + convey.So(err, convey.ShouldBeError, "[makeBulkItems] embedding method not provided") convey.So(len(bulks), convey.ShouldEqual, 0) }) PatchConvey("test vector size invalid", func() { i := &Indexer{ config: &IndexerConfig{ - Index: "mock_index", - VectorFields: []field_mapping.FieldKV{ - field_mapping.DefaultFieldKV(field_mapping.DocFieldNameContent), - field_mapping.DefaultFieldKV(field_mapping.FieldName(extField)), - }, + Index: "mock_index", + FieldMapping: defaultFieldMapping, }, } - bulks, err := i.vectorQueryItems(ctx, docs, &indexer.Options{ + bulks, err := i.makeBulkItems(ctx, docs, &indexer.Options{ Embedding: &mockEmbedding{size: []int{2, 2}, mockVector: 
[]float64{2.1}}, }) - convey.So(err, convey.ShouldBeError, "[vectorQueryItems] invalid vector length, expected=1, got=2") + convey.So(err, convey.ShouldBeError, "[makeBulkItems] invalid vector length, expected=1, got=2") convey.So(len(bulks), convey.ShouldEqual, 0) }) PatchConvey("test success", func() { i := &Indexer{ config: &IndexerConfig{ - Index: "mock_index", - VectorFields: []field_mapping.FieldKV{ - field_mapping.DefaultFieldKV(field_mapping.DocFieldNameContent), - field_mapping.DefaultFieldKV(field_mapping.FieldName(extField)), - }, + Index: "mock_index", + FieldMapping: defaultFieldMapping, }, } - bulks, err := i.vectorQueryItems(ctx, docs, &indexer.Options{ - Embedding: &mockEmbedding{size: []int{1, 2}, mockVector: []float64{2.1}}, + bulks, err := i.makeBulkItems(ctx, docs, &indexer.Options{ + Embedding: &mockEmbedding{size: []int{1, 1}, mockVector: []float64{2.1}}, }) convey.So(err, convey.ShouldBeNil) convey.So(len(bulks), convey.ShouldEqual, 2) exp := []string{ - `{"eino_doc_content":"asd","extra_field":"ext_1","vector_eino_doc_content":[2.3,4.4],"vector_extra_field":[2.1]}`, - `{"eino_doc_content":"qwe","extra_field":"ext_2","vector_eino_doc_content":[2.1],"vector_extra_field":[2.1]}`, + `{"content":"asd","meta_data":{"extra_field":"ext_1"},"vector_content":[2.1]}`, + `{"content":"qwe","meta_data":{"extra_field":"ext_2"},"vector_content":[2.1]}`, } for idx, item := range bulks { convey.So(item.Index, convey.ShouldEqual, i.config.Index) b, err := io.ReadAll(item.Body) + fmt.Println(string(b)) convey.So(err, convey.ShouldBeNil) convey.So(string(b), convey.ShouldEqual, exp[idx]) } @@ -147,3 +133,18 @@ func (m *mockEmbedding) EmbedStrings(ctx context.Context, texts []string, opts . 
return resp, nil } + +func defaultFieldMapping(ctx context.Context, doc *schema.Document) ( + fields map[string]any, needEmbeddingFields map[string]string, err error) { + + fields = map[string]any{ + "content": doc.Content, + "meta_data": doc.MetaData, + } + + needEmbeddingFields = map[string]string{ + "vector_content": doc.Content, + } + + return fields, needEmbeddingFields, nil +} diff --git a/components/indexer/es8/internal/consts.go b/components/indexer/es8/internal/consts.go deleted file mode 100644 index c515981..0000000 --- a/components/indexer/es8/internal/consts.go +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package internal - -const ( - DocExtraKeyEsFields = "_es_fields" // *schema.Document.MetaData key of es fields except content -) diff --git a/components/indexer/es8/utils.go b/components/indexer/es8/utils.go index a5f8952..e669079 100644 --- a/components/indexer/es8/utils.go +++ b/components/indexer/es8/utils.go @@ -16,27 +16,13 @@ package es8 -import ( - "github.com/cloudwego/eino/schema" - - "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" -) - func GetType() string { return typ } -func toESDoc(doc *schema.Document) map[string]any { - mp := make(map[string]any) - if kvs, ok := field_mapping.GetExtraDataFields(doc); ok { - for k, v := range kvs { - mp[k] = v - } - } - - mp[field_mapping.DocFieldNameContent] = doc.Content - - return mp +type tuple[A, B any] struct { + A A + B B } func chunk[T any](slice []T, size int) [][]T { diff --git a/components/retriever/es8/consts.go b/components/retriever/es8/consts.go index 3f1ffcd..7f5da11 100644 --- a/components/retriever/es8/consts.go +++ b/components/retriever/es8/consts.go @@ -17,3 +17,7 @@ package es8 const typ = "ElasticSearch8" + +func GetType() string { + return typ +} diff --git a/components/retriever/es8/field_mapping/consts.go b/components/retriever/es8/field_mapping/consts.go deleted file mode 100644 index abe46ff..0000000 --- a/components/retriever/es8/field_mapping/consts.go +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package field_mapping - -const DocFieldNameContent = "eino_doc_content" diff --git a/components/retriever/es8/field_mapping/mapping.go b/components/retriever/es8/field_mapping/mapping.go deleted file mode 100644 index 6cc35e4..0000000 --- a/components/retriever/es8/field_mapping/mapping.go +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package field_mapping - -import ( - "fmt" - - "github.com/cloudwego/eino-ext/components/retriever/es8/internal" - "github.com/cloudwego/eino/schema" -) - -// GetDefaultVectorFieldKeyContent get default es key for Document.Content -func GetDefaultVectorFieldKeyContent() FieldName { - return defaultVectorFieldKeyContent -} - -// GetDefaultVectorFieldKey generate default vector field name from its field name -func GetDefaultVectorFieldKey(fieldName string) FieldName { - return FieldName(fmt.Sprintf("vector_%s", fieldName)) -} - -// GetExtraDataFields get data fields from *schema.Document -func GetExtraDataFields(doc *schema.Document) (fields map[string]interface{}, ok bool) { - if doc == nil || doc.MetaData == nil { - return nil, false - } - - fields, ok = doc.MetaData[internal.DocExtraKeyEsFields].(map[string]interface{}) - - return fields, ok -} - -type FieldKV struct { - // FieldNameVector vector field name (if needed) - FieldNameVector FieldName `json:"field_name_vector,omitempty"` - // FieldName field name - FieldName FieldName `json:"field_name,omitempty"` - // Value original value - Value string `json:"value,omitempty"` -} - -type FieldName string - -var defaultVectorFieldKeyContent = GetDefaultVectorFieldKey(DocFieldNameContent) diff --git a/components/retriever/es8/go.mod b/components/retriever/es8/go.mod index 5d6f722..d62c84e 100644 --- a/components/retriever/es8/go.mod +++ b/components/retriever/es8/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bytedance/mockey v1.2.13 - github.com/cloudwego/eino v0.3.5 + github.com/cloudwego/eino v0.3.6 github.com/elastic/go-elasticsearch/v8 v8.16.0 github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.9.0 diff --git a/components/retriever/es8/go.sum b/components/retriever/es8/go.sum index e701f60..7b9198f 100644 --- a/components/retriever/es8/go.sum +++ b/components/retriever/es8/go.sum @@ -15,6 +15,8 @@ github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/ 
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/eino v0.3.5 h1:9PkAOX/phFifrGXkfl4L9rdecxOQJBJY1FtZqF4bz3c= github.com/cloudwego/eino v0.3.5/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo= +github.com/cloudwego/eino v0.3.6 h1:3yfdKKxMVWefdOyGXHuqUMM5cc9iioijj2mpPsDZKIg= +github.com/cloudwego/eino v0.3.6/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/components/retriever/es8/internal/consts.go b/components/retriever/es8/option.go similarity index 67% rename from components/retriever/es8/internal/consts.go rename to components/retriever/es8/option.go index bf3b10d..2dd3f21 100644 --- a/components/retriever/es8/internal/consts.go +++ b/components/retriever/es8/option.go @@ -14,9 +14,14 @@ * limitations under the License. */ -package internal +package es8 -const ( - DocExtraKeyEsFields = "_es_fields" // *schema.Document.MetaData key of es fields except content - DslFilterField = "_dsl_filter_functions" +import ( + "github.com/elastic/go-elasticsearch/v8/typedapi/types" ) + +// ESImplOptions es specified options +// Use retriever.GetImplSpecificOptions[ESImplOptions] to get ESImplOptions from options. 
+type ESImplOptions struct { + Filters []types.Query `json:"filters,omitempty"` +} diff --git a/components/retriever/es8/retriever.go b/components/retriever/es8/retriever.go index b441f08..b26c519 100644 --- a/components/retriever/es8/retriever.go +++ b/components/retriever/es8/retriever.go @@ -18,19 +18,16 @@ package es8 import ( "context" - "encoding/json" "fmt" "github.com/elastic/go-elasticsearch/v8" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/schema" - - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" - "github.com/cloudwego/eino-ext/components/retriever/es8/internal" ) type RetrieverConfig struct { @@ -47,11 +44,17 @@ type RetrieverConfig struct { // use search_mode.SearchModeSparseVectorTextExpansion with search_mode.SparseVectorTextExpansionQuery // use search_mode.SearchModeRawStringRequest with json search request SearchMode SearchMode `json:"search_mode"` + // ResultParser parses a document from an es search hit. + // ResultParser is required: NewRetriever returns an error when it is not provided. + ResultParser func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) // Embedding vectorization method, must provide when SearchMode needed Embedding embedding.Embedder } type SearchMode interface { + // BuildRequest generates the search request from config, query and options. + // Additionally, es specific options (like filters for the query) are passed in opts; + // use retriever.GetImplSpecificOptions[es8.ESImplOptions] to get them. 
BuildRequest(ctx context.Context, conf *RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) } @@ -65,6 +68,10 @@ func NewRetriever(_ context.Context, conf *RetrieverConfig) (*Retriever, error) return nil, fmt.Errorf("[NewRetriever] search mode not provided") } + if conf.ResultParser == nil { + return nil, fmt.Errorf("[NewRetriever] result parser not provided") + } + client, err := elasticsearch.NewTypedClient(conf.ESConfig) if err != nil { return nil, fmt.Errorf("[NewRetriever] new es client failed, %w", err) @@ -109,7 +116,7 @@ func (r *Retriever) Retrieve(ctx context.Context, query string, opts ...retrieve return nil, err } - docs, err = r.parseSearchResult(resp) + docs, err = r.parseSearchResult(ctx, resp) if err != nil { return nil, err } @@ -119,40 +126,13 @@ func (r *Retriever) Retrieve(ctx context.Context, query string, opts ...retrieve return docs, nil } -func (r *Retriever) parseSearchResult(resp *search.Response) (docs []*schema.Document, err error) { +func (r *Retriever) parseSearchResult(ctx context.Context, resp *search.Response) (docs []*schema.Document, err error) { docs = make([]*schema.Document, 0, len(resp.Hits.Hits)) for _, hit := range resp.Hits.Hits { - var raw map[string]any - if err = json.Unmarshal(hit.Source_, &raw); err != nil { - return nil, fmt.Errorf("[parseSearchResult] unexpected hit source type, source=%v", string(hit.Source_)) - } - - var id string - if hit.Id_ != nil { - id = *hit.Id_ - } - - content, ok := raw[field_mapping.DocFieldNameContent].(string) - if !ok { - return nil, fmt.Errorf("[parseSearchResult] content type not string, raw=%v", raw) - } - - expMap := make(map[string]any, len(raw)-1) - for k, v := range raw { - if k != internal.DocExtraKeyEsFields { - expMap[k] = v - } - } - - doc := &schema.Document{ - ID: id, - Content: content, - MetaData: map[string]any{internal.DocExtraKeyEsFields: expMap}, - } - - if hit.Score_ != nil { - doc.WithScore(float64(*hit.Score_)) + doc, err := 
r.config.ResultParser(ctx, hit) + if err != nil { + return nil, err } docs = append(docs, doc) diff --git a/components/retriever/es8/retriever_test.go b/components/retriever/es8/retriever_test.go index f7b2e60..38f9164 100644 --- a/components/retriever/es8/retriever_test.go +++ b/components/retriever/es8/retriever_test.go @@ -19,10 +19,12 @@ package es8 import ( "context" "encoding/json" + "fmt" "testing" "github.com/bytedance/mockey" "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" "github.com/elastic/go-elasticsearch/v8" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" @@ -34,9 +36,31 @@ func TestNewRetriever(t *testing.T) { t.Run("retrieve_documents", func(t *testing.T) { r, err := NewRetriever(ctx, &RetrieverConfig{ - ESConfig: elasticsearch.Config{}, - Index: "eino_ut", - TopK: 10, + ESConfig: elasticsearch.Config{}, + Index: "eino_ut", + TopK: 10, + ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) { + var mp map[string]any + if err := json.Unmarshal(hit.Source_, &mp); err != nil { + return nil, err + } + + var id string + if hit.Id_ != nil { + id = *hit.Id_ + } + + content, ok := mp["eino_doc_content"].(string) + if !ok { + return nil, fmt.Errorf("content not found") + } + + return &schema.Document{ + ID: id, + Content: content, + MetaData: nil, + }, nil + }, SearchMode: &mockSearchMode{}, }) assert.NoError(t, err) diff --git a/components/retriever/es8/search_mode/approximate.go b/components/retriever/es8/search_mode/approximate.go index ee1825d..87e6e31 100644 --- a/components/retriever/es8/search_mode/approximate.go +++ b/components/retriever/es8/search_mode/approximate.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -18,16 +18,12 @@ package search_mode import ( "context" - "encoding/json" "fmt" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - - "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeApproximate retrieve with multiple approximate strategy (filter+knn+rrf) @@ -38,48 +34,34 @@ func SearchModeApproximate(config *ApproximateConfig) es8.SearchMode { } type ApproximateConfig struct { + // QueryFieldName the name of query field, required when using Hybrid + QueryFieldName string + // VectorFieldName the name of the vector field to search against, required + VectorFieldName string // Hybrid if true, add filters and rff to knn query Hybrid bool - // Rrf is a method for combining multiple result sets, is used to + // RRF (Reciprocal Rank Fusion) is a method for combining multiple result sets, is used to // even the score from the knn query and text query - Rrf bool - // RrfRankConstant determines how much influence documents in + RRF bool + // RRFRankConstant determines how much influence documents in // individual result sets per query have over the final ranked result set - RrfRankConstant *int64 - // RrfWindowSize determines the size ptrWithoutZero the individual result sets per query - RrfWindowSize *int64 -} - -type ApproximateQuery struct { - // FieldKV es field info, QueryVectorBuilderModelID will be used if embedding not provided in config, - // and Embedding will be used if QueryVectorBuilderModelID is nil - FieldKV field_mapping.FieldKV `json:"field_kv"` + RRFRankConstant *int64 + // 
RRFWindowSize determines the size of the individual result sets per query + RRFWindowSize *int64 // QueryVectorBuilderModelID the query vector builder model id // see: https://www.elastic.co/guide/en/machine-learning/8.16/ml-nlp-text-emb-vector-search-example.html - QueryVectorBuilderModelID *string `json:"query_vector_builder_model_id,omitempty"` + QueryVectorBuilderModelID *string // Boost Floating point number used to decrease or increase the relevance scores ptrWithoutZero the query. // Boost values are relative to the default value ptrWithoutZero 1.0. // A boost value between 0 and 1.0 decreases the relevance score. // A value greater than 1.0 increases the relevance score. - Boost *float32 `json:"boost,omitempty"` - // Filters for the kNN search query - Filters []types.Query `json:"filters,omitempty"` + Boost *float32 // K The final number ptrWithoutZero nearest neighbors to return as top hits - K *int `json:"k,omitempty"` + K *int // NumCandidates The number ptrWithoutZero nearest neighbor candidates to consider per shard - NumCandidates *int `json:"num_candidates,omitempty"` + NumCandidates *int // Similarity The minimum similarity for a vector to be considered a match - Similarity *float32 `json:"similarity,omitempty"` -} - -// ToRetrieverQuery convert approximate query to string query -func (a *ApproximateQuery) ToRetrieverQuery() (string, error) { - b, err := json.Marshal(a) - if err != nil { - return "", fmt.Errorf("[ToRetrieverQuery] convert query failed, %w", err) - } - - return string(b), nil + Similarity *float32 } type approximate struct { @@ -88,41 +70,38 @@ type approximate struct { func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { - options := retriever.GetCommonOptions(&retriever.Options{ + co := retriever.GetCommonOptions(&retriever.Options{ Index: ptrWithoutZero(conf.Index), TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: 
conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) - var appReq ApproximateQuery - if err := json.Unmarshal([]byte(query), &appReq); err != nil { - return nil, fmt.Errorf("[BuildRequest][SearchModeApproximate] parse query failed, %w", err) - } + io := retriever.GetImplSpecificOptions[es8.ESImplOptions](nil, opts...) knn := types.KnnSearch{ - Boost: appReq.Boost, - Field: string(appReq.FieldKV.FieldNameVector), - Filter: appReq.Filters, - K: appReq.K, - NumCandidates: appReq.NumCandidates, + Boost: a.config.Boost, + Field: a.config.VectorFieldName, + Filter: io.Filters, + K: a.config.K, + NumCandidates: a.config.NumCandidates, QueryVector: nil, QueryVectorBuilder: nil, - Similarity: appReq.Similarity, + Similarity: a.config.Similarity, } - if appReq.QueryVectorBuilderModelID != nil { + if a.config.QueryVectorBuilderModelID != nil { knn.QueryVectorBuilder = &types.QueryVectorBuilder{TextEmbedding: &types.TextEmbedding{ - ModelId: *appReq.QueryVectorBuilderModelID, - ModelText: appReq.FieldKV.Value, + ModelId: *a.config.QueryVectorBuilderModelID, + ModelText: query, }} } else { - emb := options.Embedding + emb := co.Embedding if emb == nil { return nil, fmt.Errorf("[BuildRequest][SearchModeApproximate] embedding not provided") } - vector, err := emb.EmbedStrings(makeEmbeddingCtx(ctx, emb), []string{appReq.FieldKV.Value}) + vector, err := emb.EmbedStrings(makeEmbeddingCtx(ctx, emb), []string{query}) if err != nil { return nil, fmt.Errorf("[BuildRequest][SearchModeApproximate] embedding failed, %w", err) } @@ -134,32 +113,32 @@ func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfi knn.QueryVector = f64To32(vector[0]) } - req := &search.Request{Knn: []types.KnnSearch{knn}, Size: options.TopK} + req := &search.Request{Knn: []types.KnnSearch{knn}, Size: co.TopK} if a.config.Hybrid { req.Query = &types.Query{ Bool: &types.BoolQuery{ - Filter: appReq.Filters, + Filter: io.Filters, Must: []types.Query{ { Match: map[string]types.MatchQuery{ 
- string(appReq.FieldKV.FieldName): {Query: appReq.FieldKV.Value}, + a.config.QueryFieldName: {Query: query}, }, }, }, }, } - if a.config.Rrf { + if a.config.RRF { req.Rank = &types.RankContainer{Rrf: &types.RrfRank{ - RankConstant: a.config.RrfRankConstant, - RankWindowSize: a.config.RrfWindowSize, + RankConstant: a.config.RRFRankConstant, + RankWindowSize: a.config.RRFWindowSize, }} } } - if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) + if co.ScoreThreshold != nil { + req.MinScore = (*types.Float64)(ptrWithoutZero(*co.ScoreThreshold)) } return req, nil diff --git a/components/retriever/es8/search_mode/approximate_test.go b/components/retriever/es8/search_mode/approximate_test.go index 36e3afb..897dd94 100644 --- a/components/retriever/es8/search_mode/approximate_test.go +++ b/components/retriever/es8/search_mode/approximate_test.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -22,65 +22,44 @@ import ( "testing" . 
"github.com/bytedance/mockey" - "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "github.com/smartystreets/goconvey/convey" - + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/smartystreets/goconvey/convey" ) func TestSearchModeApproximate(t *testing.T) { PatchConvey("test SearchModeApproximate", t, func() { - PatchConvey("test ToRetrieverQuery", func() { - aq := &ApproximateQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - Boost: ptrWithoutZero(float32(1.0)), - K: ptrWithoutZero(10), - NumCandidates: ptrWithoutZero(100), - Similarity: ptrWithoutZero(float32(0.5)), - } - - sq, err := aq.ToRetrieverQuery() - convey.So(err, convey.ShouldBeNil) - convey.So(sq, convey.ShouldEqual, `{"field_kv":{"field_name_vector":"vector_eino_doc_content","field_name":"eino_doc_content","value":"content"},"boost":1,"filters":[{"match":{"label":{"query":"good"}}}],"k":10,"num_candidates":100,"similarity":0.5}`) - }) - PatchConvey("test BuildRequest", func() { ctx := context.Background() + queryFieldName := "eino_doc_content" + vectorFieldName := "vector_eino_doc_content" + query := "content" PatchConvey("test QueryVectorBuilderModelID", func() { - a := &approximate{config: &ApproximateConfig{}} - aq := &ApproximateQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, + a := &approximate{config: &ApproximateConfig{ + 
QueryFieldName: queryFieldName, + VectorFieldName: vectorFieldName, + Hybrid: false, + RRF: false, + RRFRankConstant: nil, + RRFWindowSize: nil, QueryVectorBuilderModelID: ptrWithoutZero("mock_model"), - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - Boost: ptrWithoutZero(float32(1.0)), - K: ptrWithoutZero(10), - NumCandidates: ptrWithoutZero(100), - Similarity: ptrWithoutZero(float32(0.5)), - } - - sq, err := aq.ToRetrieverQuery() - convey.So(err, convey.ShouldBeNil) + Boost: ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), + }} conf := &es8.RetrieverConfig{} - req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(nil)) + req, err := a.BuildRequest(ctx, conf, query, + retriever.WithEmbedding(nil), + retriever.WrapImplSpecificOptFn[es8.ESImplOptions](func(o *es8.ESImplOptions) { + o.Filters = []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + } + })) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) @@ -88,26 +67,28 @@ func TestSearchModeApproximate(t *testing.T) { }) PatchConvey("test embedding", func() { - a := &approximate{config: &ApproximateConfig{}} - aq := &ApproximateQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - Boost: ptrWithoutZero(float32(1.0)), - K: ptrWithoutZero(10), - NumCandidates: ptrWithoutZero(100), - Similarity: ptrWithoutZero(float32(0.5)), - } + a := &approximate{config: &ApproximateConfig{ + QueryFieldName: queryFieldName, + VectorFieldName: vectorFieldName, + Hybrid: false, + RRF: false, + RRFRankConstant: nil, + RRFWindowSize: nil, + QueryVectorBuilderModelID: nil, + Boost: 
ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), + }} - sq, err := aq.ToRetrieverQuery() - convey.So(err, convey.ShouldBeNil) conf := &es8.RetrieverConfig{} - req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}})) + req, err := a.BuildRequest(ctx, conf, query, + retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), + retriever.WrapImplSpecificOptFn[es8.ESImplOptions](func(o *es8.ESImplOptions) { + o.Filters = []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + } + })) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) @@ -116,34 +97,29 @@ func TestSearchModeApproximate(t *testing.T) { PatchConvey("test hybrid with rrf", func() { a := &approximate{config: &ApproximateConfig{ - Hybrid: true, - Rrf: true, - RrfRankConstant: ptrWithoutZero(int64(10)), - RrfWindowSize: ptrWithoutZero(int64(5)), + QueryFieldName: queryFieldName, + VectorFieldName: vectorFieldName, + Hybrid: true, + RRF: true, + RRFRankConstant: ptrWithoutZero(int64(10)), + RRFWindowSize: ptrWithoutZero(int64(5)), + QueryVectorBuilderModelID: nil, + Boost: ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), }} - aq := &ApproximateQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - Boost: ptrWithoutZero(float32(1.0)), - K: ptrWithoutZero(10), - NumCandidates: ptrWithoutZero(100), - Similarity: ptrWithoutZero(float32(0.5)), - } - - sq, err := aq.ToRetrieverQuery() - convey.So(err, convey.ShouldBeNil) - conf := &es8.RetrieverConfig{} - req, 
err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), + req, err := a.BuildRequest(ctx, conf, query, + retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), retriever.WithTopK(10), - retriever.WithScoreThreshold(1.1)) + retriever.WithScoreThreshold(1.1), + retriever.WrapImplSpecificOptFn[es8.ESImplOptions](func(o *es8.ESImplOptions) { + o.Filters = []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + } + })) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) diff --git a/components/retriever/es8/search_mode/dense_vector_similarity.go b/components/retriever/es8/search_mode/dense_vector_similarity.go index 18e58dd..cd4df4d 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -21,34 +21,16 @@ import ( "encoding/json" "fmt" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - - "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeDenseVectorSimilarity calculate embedding similarity between dense_vector field and query // see: https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html#vector-functions -func SearchModeDenseVectorSimilarity(typ DenseVectorSimilarityType) es8.SearchMode { - return &denseVectorSimilarity{script: denseVectorScriptMap[typ]} -} - -type DenseVectorSimilarityQuery struct { - FieldKV field_mapping.FieldKV `json:"field_kv"` - Filters []types.Query `json:"filters,omitempty"` -} - -// ToRetrieverQuery convert approximate query to string query -func (d *DenseVectorSimilarityQuery) ToRetrieverQuery() (string, error) { - b, err := json.Marshal(d) - if err != nil { - return "", fmt.Errorf("[ToRetrieverQuery] convert query failed, %w", err) - } - - return string(b), nil +func SearchModeDenseVectorSimilarity(typ DenseVectorSimilarityType, vectorFieldName string) es8.SearchMode { + return &denseVectorSimilarity{fmt.Sprintf(denseVectorScriptMap[typ], vectorFieldName)} } type denseVectorSimilarity struct { @@ -58,24 +40,21 @@ type denseVectorSimilarity struct { func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { - options := retriever.GetCommonOptions(&retriever.Options{ + co := 
retriever.GetCommonOptions(&retriever.Options{ Index: ptrWithoutZero(conf.Index), TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) - var dq DenseVectorSimilarityQuery - if err := json.Unmarshal([]byte(query), &dq); err != nil { - return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] parse query failed, %w", err) - } + io := retriever.GetImplSpecificOptions[es8.ESImplOptions](nil, opts...) - emb := options.Embedding + emb := co.Embedding if emb == nil { return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] embedding not provided") } - vector, err := emb.EmbedStrings(makeEmbeddingCtx(ctx, emb), []string{dq.FieldKV.Value}) + vector, err := emb.EmbedStrings(makeEmbeddingCtx(ctx, emb), []string{query}) if err != nil { return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] embedding failed, %w", err) } @@ -92,15 +71,15 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr q := &types.Query{ ScriptScore: &types.ScriptScoreQuery{ Script: types.Script{ - Source: ptrWithoutZero(fmt.Sprintf(d.script, dq.FieldKV.FieldNameVector)), + Source: ptrWithoutZero(d.script), Params: map[string]json.RawMessage{"embedding": vb}, }, }, } - if len(dq.Filters) > 0 { + if len(io.Filters) > 0 { q.ScriptScore.Query = &types.Query{ - Bool: &types.BoolQuery{Filter: dq.Filters}, + Bool: &types.BoolQuery{Filter: io.Filters}, } } else { q.ScriptScore.Query = &types.Query{ @@ -108,9 +87,9 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr } } - req := &search.Request{Query: q, Size: options.TopK} - if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) + req := &search.Request{Query: q, Size: co.TopK} + if co.ScoreThreshold != nil { + req.MinScore = (*types.Float64)(ptrWithoutZero(*co.ScoreThreshold)) } return req, nil @@ -127,10 +106,10 @@ const ( var denseVectorScriptMap 
= map[DenseVectorSimilarityType]string{ DenseVectorSimilarityTypeCosineSimilarity: `cosineSimilarity(params.embedding, '%s') + 1.0`, - DenseVectorSimilarityTypeDotProduct: `"" - double value = dotProduct(params.embedding, '%s'); - return sigmoid(1, Math.E, -value); - ""`, + DenseVectorSimilarityTypeDotProduct: ` + double value = dotProduct(params.query_vector, '%s'); + return sigmoid(1, Math.E, -value); + `, DenseVectorSimilarityTypeL1Norm: `1 / (1 + l1norm(params.embedding, '%s'))`, DenseVectorSimilarityTypeL2Norm: `1 / (1 + l2norm(params.embedding, '%s'))`, } diff --git a/components/retriever/es8/search_mode/dense_vector_similarity_test.go b/components/retriever/es8/search_mode/dense_vector_similarity_test.go index a68c333..1a7f548 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity_test.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity_test.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -23,60 +23,30 @@ import ( "testing" . 
"github.com/bytedance/mockey" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/smartystreets/goconvey/convey" - - "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) func TestSearchModeDenseVectorSimilarity(t *testing.T) { PatchConvey("test SearchModeDenseVectorSimilarity", t, func() { - PatchConvey("test ToRetrieverQuery", func() { - dq := &DenseVectorSimilarityQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - } - - sq, err := dq.ToRetrieverQuery() - convey.So(err, convey.ShouldBeNil) - convey.So(sq, convey.ShouldEqual, `{"field_kv":{"field_name_vector":"vector_eino_doc_content","field_name":"eino_doc_content","value":"content"},"filters":[{"match":{"label":{"query":"good"}}}]}`) - }) - PatchConvey("test BuildRequest", func() { ctx := context.Background() - d := &denseVectorSimilarity{script: denseVectorScriptMap[DenseVectorSimilarityTypeCosineSimilarity]} - dq := &DenseVectorSimilarityQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - } - sq, _ := dq.ToRetrieverQuery() + vectorFieldName := "vector_eino_doc_content" + d := SearchModeDenseVectorSimilarity(DenseVectorSimilarityTypeCosineSimilarity, vectorFieldName) + query := "content" PatchConvey("test embedding not provided", func() { - conf := &es8.RetrieverConfig{} - req, err := d.BuildRequest(ctx, 
conf, sq, retriever.WithEmbedding(nil)) + req, err := d.BuildRequest(ctx, conf, query, retriever.WithEmbedding(nil)) convey.So(err, convey.ShouldBeError, "[BuildRequest][SearchModeDenseVectorSimilarity] embedding not provided") convey.So(req, convey.ShouldBeNil) }) PatchConvey("test vector size invalid", func() { conf := &es8.RetrieverConfig{} - req, err := d.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(mockEmbedding{size: 2, mockVector: []float64{1.1, 1.2}})) + req, err := d.BuildRequest(ctx, conf, query, retriever.WithEmbedding(mockEmbedding{size: 2, mockVector: []float64{1.1, 1.2}})) convey.So(err, convey.ShouldBeError, "[BuildRequest][SearchModeDenseVectorSimilarity] vector size invalid, expect=1, got=2") convey.So(req, convey.ShouldBeNil) }) @@ -84,18 +54,23 @@ func TestSearchModeDenseVectorSimilarity(t *testing.T) { PatchConvey("test success", func() { typ2Exp := map[DenseVectorSimilarityType]string{ DenseVectorSimilarityTypeCosineSimilarity: `{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"cosineSimilarity(params.embedding, 'vector_eino_doc_content') + 1.0"}}},"size":10}`, - DenseVectorSimilarityTypeDotProduct: `{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"\"\"\n double value = dotProduct(params.embedding, 'vector_eino_doc_content');\n return sigmoid(1, Math.E, -value); \n \"\""}}},"size":10}`, + DenseVectorSimilarityTypeDotProduct: `{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"\n double value = dotProduct(params.query_vector, 'vector_eino_doc_content');\n return sigmoid(1, Math.E, -value);\n "}}},"size":10}`, DenseVectorSimilarityTypeL1Norm: 
`{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"1 / (1 + l1norm(params.embedding, 'vector_eino_doc_content'))"}}},"size":10}`, DenseVectorSimilarityTypeL2Norm: `{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"1 / (1 + l2norm(params.embedding, 'vector_eino_doc_content'))"}}},"size":10}`, } for typ, exp := range typ2Exp { - similarity := &denseVectorSimilarity{script: denseVectorScriptMap[typ]} + similarity := SearchModeDenseVectorSimilarity(typ, vectorFieldName) conf := &es8.RetrieverConfig{} - req, err := similarity.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), + req, err := similarity.BuildRequest(ctx, conf, query, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), retriever.WithTopK(10), - retriever.WithScoreThreshold(1.1)) + retriever.WithScoreThreshold(1.1), + retriever.WrapImplSpecificOptFn[es8.ESImplOptions](func(o *es8.ESImplOptions) { + o.Filters = []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + } + })) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) diff --git a/components/retriever/es8/search_mode/exact_match.go b/components/retriever/es8/search_mode/exact_match.go index 81c7324..0282eef 100644 --- a/components/retriever/es8/search_mode/exact_match.go +++ b/components/retriever/es8/search_mode/exact_match.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -19,20 +19,19 @@ package search_mode import ( "context" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - - "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) -func SearchModeExactMatch() es8.SearchMode { - return &exactMatch{} +func SearchModeExactMatch(queryFieldName string) es8.SearchMode { + return &exactMatch{queryFieldName} } -type exactMatch struct{} +type exactMatch struct { + name string +} func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { @@ -46,7 +45,7 @@ func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, q := &types.Query{ Match: map[string]types.MatchQuery{ - field_mapping.DocFieldNameContent: {Query: query}, + e.name: {Query: query}, }, } diff --git a/components/retriever/es8/search_mode/exact_match_test.go b/components/retriever/es8/search_mode/exact_match_test.go new file mode 100644 index 0000000..fc90250 --- /dev/null +++ b/components/retriever/es8/search_mode/exact_match_test.go @@ -0,0 +1,41 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package search_mode + +import ( + "context" + "encoding/json" + "testing" + + . "github.com/bytedance/mockey" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/smartystreets/goconvey/convey" +) + +func TestSearchModeExactMatch(t *testing.T) { + PatchConvey("test SearchModeExactMatch", t, func() { + ctx := context.Background() + conf := &es8.RetrieverConfig{} + searchMode := SearchModeExactMatch("test_field") + req, err := searchMode.BuildRequest(ctx, conf, "test_query") + convey.So(err, convey.ShouldBeNil) + b, err := json.Marshal(req) + convey.So(err, convey.ShouldBeNil) + convey.So(string(b), convey.ShouldEqual, `{"query":{"match":{"test_field":{"query":"test_query"}}}}`) + }) + +} diff --git a/components/retriever/es8/search_mode/raw_string.go b/components/retriever/es8/search_mode/raw_string.go index 82a7aa4..01eccd9 100644 --- a/components/retriever/es8/search_mode/raw_string.go +++ b/components/retriever/es8/search_mode/raw_string.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -19,11 +19,9 @@ package search_mode import ( "context" - "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" - - "github.com/cloudwego/eino/components/retriever" - "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" + "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" ) func SearchModeRawStringRequest() es8.SearchMode { diff --git a/components/retriever/es8/search_mode/raw_string_test.go b/components/retriever/es8/search_mode/raw_string_test.go new file mode 100644 index 0000000..75d2619 --- /dev/null +++ b/components/retriever/es8/search_mode/raw_string_test.go @@ -0,0 +1,48 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package search_mode + +import ( + "context" + "testing" + + . 
"github.com/bytedance/mockey" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/smartystreets/goconvey/convey" +) + +func TestSearchModeRawStringRequest(t *testing.T) { + PatchConvey("test SearchModeRawStringRequest", t, func() { + ctx := context.Background() + conf := &es8.RetrieverConfig{} + searchMode := SearchModeRawStringRequest() + + PatchConvey("test from json error", func() { + r, err := searchMode.BuildRequest(ctx, conf, "test_query") + convey.So(err, convey.ShouldNotBeNil) + convey.So(r, convey.ShouldBeNil) + }) + + PatchConvey("test success", func() { + q := `{"query":{"match":{"test_field":{"query":"test_query"}}}}` + r, err := searchMode.BuildRequest(ctx, conf, q) + convey.So(err, convey.ShouldBeNil) + convey.So(r, convey.ShouldNotBeNil) + convey.So(r.Query.Match["test_field"].Query, convey.ShouldEqual, "test_query") + }) + }) +} diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go index bfa75be..0b9a999 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -18,63 +18,42 @@ package search_mode import ( "context" - "encoding/json" "fmt" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - - "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeSparseVectorTextExpansion convert the query text into a list ptrWithoutZero token-weight pairs, // which are then used in a query against a sparse vector // see: https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-text-expansion-query.html -func SearchModeSparseVectorTextExpansion(modelID string) es8.SearchMode { - return &sparseVectorTextExpansion{modelID} -} - -type SparseVectorTextExpansionQuery struct { - FieldKV field_mapping.FieldKV `json:"field_kv"` - Filters []types.Query `json:"filters,omitempty"` -} - -// ToRetrieverQuery convert approximate query to string query -func (s *SparseVectorTextExpansionQuery) ToRetrieverQuery() (string, error) { - b, err := json.Marshal(s) - if err != nil { - return "", fmt.Errorf("[ToRetrieverQuery] convert query failed, %w", err) - } - - return string(b), nil +func SearchModeSparseVectorTextExpansion(modelID, vectorFieldName string) es8.SearchMode { + return &sparseVectorTextExpansion{modelID, vectorFieldName} } type sparseVectorTextExpansion struct { - modelID string + modelID string + vectorFieldName string } func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { - options := retriever.GetCommonOptions(&retriever.Options{ + co := 
retriever.GetCommonOptions(&retriever.Options{ Index: ptrWithoutZero(conf.Index), TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) - var sq SparseVectorTextExpansionQuery - if err := json.Unmarshal([]byte(query), &sq); err != nil { - return nil, fmt.Errorf("[BuildRequest][SearchModeSparseVectorTextExpansion] parse query failed, %w", err) - } + io := retriever.GetImplSpecificOptions[es8.ESImplOptions](nil, opts...) - name := fmt.Sprintf("%s.tokens", sq.FieldKV.FieldNameVector) + name := fmt.Sprintf("%s.tokens", s.vectorFieldName) teq := types.TextExpansionQuery{ ModelId: s.modelID, - ModelText: sq.FieldKV.Value, + ModelText: query, } q := &types.Query{ @@ -82,13 +61,13 @@ func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.R Must: []types.Query{ {TextExpansion: map[string]types.TextExpansionQuery{name: teq}}, }, - Filter: sq.Filters, + Filter: io.Filters, }, } - req := &search.Request{Query: q, Size: options.TopK} - if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) + req := &search.Request{Query: q, Size: co.TopK} + if co.ScoreThreshold != nil { + req.MinScore = (*types.Float64)(ptrWithoutZero(*co.ScoreThreshold)) } return req, nil diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go index f4c42a9..019d0ce 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -22,55 +22,28 @@ import ( "testing" . 
"github.com/bytedance/mockey" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/smartystreets/goconvey/convey" - - "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) func TestSearchModeSparseVectorTextExpansion(t *testing.T) { PatchConvey("test SearchModeSparseVectorTextExpansion", t, func() { - PatchConvey("test ToRetrieverQuery", func() { - sq := &SparseVectorTextExpansionQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - } - - ssq, err := sq.ToRetrieverQuery() - convey.So(err, convey.ShouldBeNil) - convey.So(ssq, convey.ShouldEqual, `{"field_kv":{"field_name_vector":"vector_eino_doc_content","field_name":"eino_doc_content","value":"content"},"filters":[{"match":{"label":{"query":"good"}}}]}`) - - }) - PatchConvey("test BuildRequest", func() { ctx := context.Background() - s := SearchModeSparseVectorTextExpansion("mock_model_id") - sq := &SparseVectorTextExpansionQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - } - - query, _ := sq.ToRetrieverQuery() + vectorFieldName := "vector_eino_doc_content" + s := SearchModeSparseVectorTextExpansion("mock_model_id", vectorFieldName) conf := &es8.RetrieverConfig{} - req, err := s.BuildRequest(ctx, conf, query, + req, err := s.BuildRequest(ctx, conf, "content", retriever.WithTopK(10), - 
retriever.WithScoreThreshold(1.1)) + retriever.WithScoreThreshold(1.1), + retriever.WrapImplSpecificOptFn[es8.ESImplOptions](func(o *es8.ESImplOptions) { + o.Filters = []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + } + })) convey.So(err, convey.ShouldBeNil) convey.So(req, convey.ShouldNotBeNil) diff --git a/components/retriever/es8/search_mode/utils.go b/components/retriever/es8/search_mode/utils.go index cc54479..ebddb7e 100644 --- a/components/retriever/es8/search_mode/utils.go +++ b/components/retriever/es8/search_mode/utils.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 *