From b99b464a3f2fe0a4205bb763627b868112e87523 Mon Sep 17 00:00:00 2001 From: xuzhaonan Date: Tue, 26 Nov 2024 16:27:18 +0800 Subject: [PATCH 01/11] feat: es8 indexer retriever --- components/indexer/es8/consts.go | 7 + .../indexer/es8/field_mapping/consts.go | 3 + .../es8/field_mapping/field_mapping.go | 64 +++++ components/indexer/es8/go.mod | 46 ++++ components/indexer/es8/go.sum | 164 +++++++++++++ components/indexer/es8/indexer.go | 220 ++++++++++++++++++ components/indexer/es8/indexer_test.go | 132 +++++++++++ components/indexer/es8/internal/consts.go | 5 + components/indexer/es8/utils.go | 63 +++++ components/retriever/es8/consts.go | 3 + .../retriever/es8/field_mapping/consts.go | 3 + .../retriever/es8/field_mapping/mapping.go | 42 ++++ components/retriever/es8/go.mod | 46 ++++ components/retriever/es8/go.sum | 164 +++++++++++++ components/retriever/es8/internal/consts.go | 6 + components/retriever/es8/retriever.go | 150 ++++++++++++ .../retriever/es8/search_mode/approximate.go | 140 +++++++++++ .../es8/search_mode/approximate_test.go | 148 ++++++++++++ .../search_mode/dense_vector_similarity.go | 109 +++++++++ .../dense_vector_similarity_test.go | 87 +++++++ .../retriever/es8/search_mode/exact_match.go | 32 +++ .../retriever/es8/search_mode/interface.go | 13 ++ .../retriever/es8/search_mode/raw_string.go | 24 ++ .../sparse_vector_text_expansion.go | 68 ++++++ .../sparse_vector_text_expansion_test.go | 59 +++++ components/retriever/es8/search_mode/utils.go | 36 +++ 26 files changed, 1834 insertions(+) create mode 100644 components/indexer/es8/consts.go create mode 100644 components/indexer/es8/field_mapping/consts.go create mode 100644 components/indexer/es8/field_mapping/field_mapping.go create mode 100644 components/indexer/es8/go.mod create mode 100644 components/indexer/es8/go.sum create mode 100644 components/indexer/es8/indexer.go create mode 100644 components/indexer/es8/indexer_test.go create mode 100644 components/indexer/es8/internal/consts.go 
create mode 100644 components/indexer/es8/utils.go create mode 100644 components/retriever/es8/consts.go create mode 100644 components/retriever/es8/field_mapping/consts.go create mode 100644 components/retriever/es8/field_mapping/mapping.go create mode 100644 components/retriever/es8/go.mod create mode 100644 components/retriever/es8/go.sum create mode 100644 components/retriever/es8/internal/consts.go create mode 100644 components/retriever/es8/retriever.go create mode 100644 components/retriever/es8/search_mode/approximate.go create mode 100644 components/retriever/es8/search_mode/approximate_test.go create mode 100644 components/retriever/es8/search_mode/dense_vector_similarity.go create mode 100644 components/retriever/es8/search_mode/dense_vector_similarity_test.go create mode 100644 components/retriever/es8/search_mode/exact_match.go create mode 100644 components/retriever/es8/search_mode/interface.go create mode 100644 components/retriever/es8/search_mode/raw_string.go create mode 100644 components/retriever/es8/search_mode/sparse_vector_text_expansion.go create mode 100644 components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go create mode 100644 components/retriever/es8/search_mode/utils.go diff --git a/components/indexer/es8/consts.go b/components/indexer/es8/consts.go new file mode 100644 index 0000000..dc0521c --- /dev/null +++ b/components/indexer/es8/consts.go @@ -0,0 +1,7 @@ +package es8 + +const typ = "ElasticSearch8" + +const ( + defaultBatchSize = 5 +) diff --git a/components/indexer/es8/field_mapping/consts.go b/components/indexer/es8/field_mapping/consts.go new file mode 100644 index 0000000..80438fa --- /dev/null +++ b/components/indexer/es8/field_mapping/consts.go @@ -0,0 +1,3 @@ +package field_mapping + +const DocFieldNameContent = "eino_doc_content" diff --git a/components/indexer/es8/field_mapping/field_mapping.go b/components/indexer/es8/field_mapping/field_mapping.go new file mode 100644 index 0000000..f87aa47 --- 
/dev/null +++ b/components/indexer/es8/field_mapping/field_mapping.go @@ -0,0 +1,64 @@ +package field_mapping + +import ( + "fmt" + + "code.byted.org/flow/eino-ext/components/indexer/es8/internal" + "code.byted.org/flow/eino/schema" +) + +// SetExtraDataFields set data fields for es +func SetExtraDataFields(doc *schema.Document, fields map[string]interface{}) { + if doc == nil { + return + } + + if doc.MetaData == nil { + doc.MetaData = make(map[string]any) + } + + doc.MetaData[internal.DocExtraKeyEsFields] = fields +} + +// GetExtraDataFields get data fields from *schema.Document +func GetExtraDataFields(doc *schema.Document) (fields map[string]interface{}, ok bool) { + if doc == nil || doc.MetaData == nil { + return nil, false + } + + fields, ok = doc.MetaData[internal.DocExtraKeyEsFields].(map[string]interface{}) + + return fields, ok +} + +// DefaultFieldKV build default names by fieldName +// docFieldName should be DocFieldNameContent or key got from GetExtraDataFields +func DefaultFieldKV(docFieldName FieldName) FieldKV { + return FieldKV{ + FieldNameVector: FieldName(fmt.Sprintf("vector_%s", docFieldName)), + FieldName: docFieldName, + } +} + +type FieldKV struct { + // FieldNameVector vector field name (if needed) + FieldNameVector FieldName `json:"field_name_vector,omitempty"` + // FieldName field name + FieldName FieldName `json:"field_name,omitempty"` +} + +type FieldName string + +func (v FieldName) Find(doc *schema.Document) (string, bool) { + if v == DocFieldNameContent { + return doc.Content, true + } + + kvs, ok := GetExtraDataFields(doc) + if !ok { + return "", false + } + + s, ok := kvs[string(v)].(string) + return s, ok +} diff --git a/components/indexer/es8/go.mod b/components/indexer/es8/go.mod new file mode 100644 index 0000000..2cae4e2 --- /dev/null +++ b/components/indexer/es8/go.mod @@ -0,0 +1,46 @@ +module code.byted.org/flow/eino-ext/components/indexer/es8 + +go 1.22 + +require ( + code.byted.org/flow/eino v0.2.5 + 
github.com/bytedance/mockey v1.2.13 + github.com/elastic/go-elasticsearch/v8 v8.16.0 +) + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/elastic/elastic-transport-go/v8 v8.6.0 // indirect + github.com/getkin/kin-openapi v0.118.0 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-openapi/jsonpointer v0.19.5 // indirect + github.com/go-openapi/swag v0.19.5 // indirect + github.com/goph/emperror v0.17.2 // indirect + github.com/gopherjs/gopherjs v1.17.2 // indirect + github.com/invopop/yaml v0.1.0 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect + github.com/nikolalohinski/gonja v1.5.3 // indirect + github.com/pelletier/go-toml/v2 v2.0.9 // indirect + github.com/perimeterx/marshmallow v1.1.4 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect + github.com/smarty/assertions v1.15.0 // indirect + github.com/smartystreets/goconvey v1.8.1 // indirect + github.com/yargevad/filepathx v1.0.0 // indirect + go.opentelemetry.io/otel v1.28.0 // indirect + go.opentelemetry.io/otel/metric v1.28.0 // indirect + go.opentelemetry.io/otel/trace v1.28.0 // indirect + golang.org/x/arch v0.11.0 // indirect + golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect + golang.org/x/sys v0.26.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/components/indexer/es8/go.sum b/components/indexer/es8/go.sum new file mode 100644 index 
0000000..5288569 --- /dev/null +++ b/components/indexer/es8/go.sum @@ -0,0 +1,164 @@ +code.byted.org/flow/eino v0.2.5 h1:uPXgTMSfGZZvKhY4c44vt+CPe0FlIQ7/tfFZ/AfF8nI= +code.byted.org/flow/eino v0.2.5/go.mod h1:+o4CnsT/qFrbqhRMBbi70qS7mseWv1md/SD0Jo1kKEA= +github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= +github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= +github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= +github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/mockey v1.2.13 h1:jokWZAm/pUEbD939Rhznz615MKUCZNuvCFQlJ2+ntoo= +github.com/bytedance/mockey v1.2.13/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= +github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/elastic/elastic-transport-go/v8 v8.6.0 h1:Y2S/FBjx1LlCv5m6pWAF2kDJAHoSjSRSJCApolgfthA= +github.com/elastic/elastic-transport-go/v8 v8.6.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk= +github.com/elastic/go-elasticsearch/v8 v8.16.0 h1:f7bR+iBz8GTAVhwyFO3hm4ixsz2eMaEy0QroYnXV3jE= +github.com/elastic/go-elasticsearch/v8 v8.16.0/go.mod h1:lGMlgKIbYoRvay3xWBeKahAiJOgmFDsjZC39nmO3H64= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= 
+github.com/getkin/kin-openapi v0.118.0 h1:z43njxPmJ7TaPpMSCQb7PN0dEYno4tyBPQcrFdHoLuM= +github.com/getkin/kin-openapi v0.118.0/go.mod h1:l5e9PaFUo9fyLJCPGQeXI2ML8c3P8BHOEV2VaAVf/pc= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= +github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= 
+github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/invopop/yaml v0.1.0 h1:YW3WGUoJEXYfzWBjn00zIlrw7brGVD0fUKRYDPAPhrc= +github.com/invopop/yaml v0.1.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod 
h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= +github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.5.0/go.mod 
h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= +github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/perimeterx/marshmallow v1.1.4 h1:pZLDH9RjlLGGorbXhcaQLhfuV0pFMNfPO55FuFkxqLw= +github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx 
v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= +github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= +github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= +github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= +github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= +github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= +github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= +go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo= +go.opentelemetry.io/otel v1.28.0/go.mod h1:q68ijF8Fc8CnMHKyzqL6akLO46ePnjkgfIMIjUIX9z4= +go.opentelemetry.io/otel/metric v1.28.0 h1:f0HGvSl1KRAU1DLgLGFjrwVyismPlnuU6JD6bOeuA5Q= +go.opentelemetry.io/otel/metric v1.28.0/go.mod 
h1:Fb1eVBFZmLVTMb6PPohq3TO9IIhUisDsbJoL/+uQW4s= +go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8= +go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E= +go.opentelemetry.io/otel/trace v1.28.0 h1:GhQ9cUuQGmNDd5BTCP2dAvv75RdMxEfTmYejp+lkx9g= +go.opentelemetry.io/otel/trace v1.28.0/go.mod h1:jPyXzNPg6da9+38HEwElrQiHlVMTnVfM3/yv2OlIHaI= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= +golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/term 
v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/components/indexer/es8/indexer.go b/components/indexer/es8/indexer.go new file mode 100644 index 0000000..b028d4b --- /dev/null +++ b/components/indexer/es8/indexer.go @@ -0,0 +1,220 @@ +package es8 + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + + "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/esutil" + + "code.byted.org/flow/eino-ext/components/indexer/es8/field_mapping" + "code.byted.org/flow/eino/callbacks" + "code.byted.org/flow/eino/components" + "code.byted.org/flow/eino/components/embedding" + "code.byted.org/flow/eino/components/indexer" + "code.byted.org/flow/eino/schema" +) + +type IndexerConfig struct { + 
ESConfig elasticsearch.Config `json:"es_config"` + Index string `json:"index"` + BatchSize int `json:"batch_size"` + + // VectorFields dense_vector field mappings + VectorFields []field_mapping.FieldKV `json:"vector_fields"` + // Embedding vectorization method, must provide in two cases + // 1. VectorFields contains fields except doc Content + // 2. VectorFields contains doc Content and vector not provided in doc extra (see Document.Vector method) + Embedding embedding.Embedder +} + +type Indexer struct { + client *elasticsearch.Client + config *IndexerConfig +} + +func NewIndexer(_ context.Context, conf *IndexerConfig) (*Indexer, error) { + client, err := elasticsearch.NewClient(conf.ESConfig) + if err != nil { + return nil, fmt.Errorf("[NewIndexer] new es client failed, %w", err) + } + + if conf.Embedding == nil { + for _, kv := range conf.VectorFields { + if kv.FieldName != field_mapping.DocFieldNameContent { + return nil, fmt.Errorf("[NewIndexer] Embedding not provided in config, but field kv[%s]-[%s] requires", + kv.FieldNameVector, kv.FieldName) + } + } + } + + if conf.BatchSize == 0 { + conf.BatchSize = defaultBatchSize + } + + return &Indexer{ + client: client, + config: conf, + }, nil +} + +func (i *Indexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) { + defer func() { + if err != nil { + callbacks.OnError(ctx, err) + } + }() + + ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs}) + + options := indexer.GetCommonOptions(&indexer.Options{ + Embedding: i.config.Embedding, + }, opts...) 
+ + bi, err := esutil.NewBulkIndexer(esutil.BulkIndexerConfig{ + Index: i.config.Index, + Client: i.client, + }) + if err != nil { + return nil, err + } + + for _, slice := range chunk(docs, i.config.BatchSize) { + var items []esutil.BulkIndexerItem + + if len(i.config.VectorFields) == 0 { + items, err = i.defaultQueryItems(ctx, slice, options) + } else { + items, err = i.vectorQueryItems(ctx, slice, options) + } + if err != nil { + return nil, err + } + + for _, item := range items { + if err = bi.Add(ctx, item); err != nil { + return nil, err + } + } + } + + if err = bi.Close(ctx); err != nil { + return nil, err + } + + ids = iter(docs, func(t *schema.Document) string { return t.ID }) + + callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids}) + + return ids, nil +} + +func (i *Indexer) defaultQueryItems(_ context.Context, docs []*schema.Document, _ *indexer.Options) (items []esutil.BulkIndexerItem, err error) { + items, err = iterWithErr(docs, func(doc *schema.Document) (item esutil.BulkIndexerItem, err error) { + b, err := json.Marshal(toESDoc(doc)) + if err != nil { + return item, err + } + + return esutil.BulkIndexerItem{ + Index: i.config.Index, + Action: "index", + DocumentID: doc.ID, + Body: bytes.NewReader(b), + }, nil + }) + + if err != nil { + return nil, err + } + + return items, nil +} + +func (i *Indexer) vectorQueryItems(ctx context.Context, docs []*schema.Document, options *indexer.Options) (items []esutil.BulkIndexerItem, err error) { + emb := options.Embedding + + items, err = iterWithErr(docs, func(doc *schema.Document) (item esutil.BulkIndexerItem, err error) { + mp := toESDoc(doc) + texts := make([]string, 0, len(i.config.VectorFields)) + for _, kv := range i.config.VectorFields { + str, ok := kv.FieldName.Find(doc) + if !ok { + return item, fmt.Errorf("[vectorQueryItems] field name not found or type incorrect, name=%s, doc=%v", kv.FieldName, doc) + } + + if kv.FieldName == field_mapping.DocFieldNameContent && len(doc.Vector()) > 0 { + 
mp[string(kv.FieldNameVector)] = doc.Vector() + } else { + texts = append(texts, str) + } + } + + if len(texts) > 0 { + if emb == nil { + return item, fmt.Errorf("[vectorQueryItems] embedding not provided") + } + + vectors, err := emb.EmbedStrings(i.makeEmbeddingCtx(ctx, emb), texts) + if err != nil { + return item, fmt.Errorf("[vectorQueryItems] embedding failed, %w", err) + } + + if len(vectors) != len(texts) { + return item, fmt.Errorf("[vectorQueryItems] invalid vector length, expected=%d, got=%d", len(texts), len(vectors)) + } + + vIdx := 0 + for _, kv := range i.config.VectorFields { + if kv.FieldName == field_mapping.DocFieldNameContent && len(doc.Vector()) > 0 { + continue + } + + mp[string(kv.FieldNameVector)] = vectors[vIdx] + vIdx++ + } + } + + b, err := json.Marshal(mp) + if err != nil { + return item, err + } + + return esutil.BulkIndexerItem{ + Index: i.config.Index, + Action: "index", + DocumentID: doc.ID, + Body: bytes.NewReader(b), + }, nil + }) + + if err != nil { + return nil, err + } + + return items, nil +} + +func (i *Indexer) makeEmbeddingCtx(ctx context.Context, emb embedding.Embedder) context.Context { + runInfo := &callbacks.RunInfo{ + Component: components.ComponentOfEmbedding, + } + + if embType, ok := components.GetType(emb); ok { + runInfo.Type = embType + } + + runInfo.Name = runInfo.Type + string(runInfo.Component) + + return callbacks.SwitchRunInfo(ctx, runInfo) +} + +func (i *Indexer) GetType() string { + return typ +} + +func (i *Indexer) IsCallbacksEnabled() bool { + return true +} diff --git a/components/indexer/es8/indexer_test.go b/components/indexer/es8/indexer_test.go new file mode 100644 index 0000000..2de06d9 --- /dev/null +++ b/components/indexer/es8/indexer_test.go @@ -0,0 +1,132 @@ +package es8 + +import ( + "context" + "fmt" + "io" + "testing" + + . 
"github.com/bytedance/mockey" + "github.com/smartystreets/goconvey/convey" + + "code.byted.org/flow/eino-ext/components/indexer/es8/field_mapping" + "code.byted.org/flow/eino/components/embedding" + "code.byted.org/flow/eino/components/indexer" + "code.byted.org/flow/eino/schema" +) + +func TestVectorQueryItems(t *testing.T) { + PatchConvey("test vectorQueryItems", t, func() { + ctx := context.Background() + extField := "extra_field" + + d1 := &schema.Document{ID: "123", Content: "asd"} + d1.WithVector([]float64{2.3, 4.4}) + field_mapping.SetExtraDataFields(d1, map[string]interface{}{extField: "ext_1"}) + + d2 := &schema.Document{ID: "456", Content: "qwe"} + field_mapping.SetExtraDataFields(d2, map[string]interface{}{extField: "ext_2"}) + + docs := []*schema.Document{d1, d2} + + PatchConvey("test field not found", func() { + i := &Indexer{ + config: &IndexerConfig{ + Index: "mock_index", + VectorFields: []field_mapping.FieldKV{ + field_mapping.DefaultFieldKV("not_found_field"), + }, + }, + } + + bulks, err := i.vectorQueryItems(ctx, docs, &indexer.Options{ + Embedding: &mockEmbedding{size: []int{1}, mockVector: []float64{2.1}}, + }) + convey.So(err, convey.ShouldBeError, fmt.Sprintf("[vectorQueryItems] field name not found or type incorrect, name=not_found_field, doc=%v", d1)) + convey.So(len(bulks), convey.ShouldEqual, 0) + }) + + PatchConvey("test emb not provided", func() { + i := &Indexer{ + config: &IndexerConfig{ + Index: "mock_index", + VectorFields: []field_mapping.FieldKV{ + field_mapping.DefaultFieldKV(field_mapping.DocFieldNameContent), + field_mapping.DefaultFieldKV(field_mapping.FieldName(extField)), + }, + }, + } + + bulks, err := i.vectorQueryItems(ctx, docs, &indexer.Options{Embedding: nil}) + convey.So(err, convey.ShouldBeError, "[vectorQueryItems] embedding not provided") + convey.So(len(bulks), convey.ShouldEqual, 0) + }) + + PatchConvey("test vector size invalid", func() { + i := &Indexer{ + config: &IndexerConfig{ + Index: "mock_index", + 
VectorFields: []field_mapping.FieldKV{ + field_mapping.DefaultFieldKV(field_mapping.DocFieldNameContent), + field_mapping.DefaultFieldKV(field_mapping.FieldName(extField)), + }, + }, + } + + bulks, err := i.vectorQueryItems(ctx, docs, &indexer.Options{ + Embedding: &mockEmbedding{size: []int{2, 2}, mockVector: []float64{2.1}}, + }) + convey.So(err, convey.ShouldBeError, "[vectorQueryItems] invalid vector length, expected=1, got=2") + convey.So(len(bulks), convey.ShouldEqual, 0) + }) + + PatchConvey("test success", func() { + i := &Indexer{ + config: &IndexerConfig{ + Index: "mock_index", + VectorFields: []field_mapping.FieldKV{ + field_mapping.DefaultFieldKV(field_mapping.DocFieldNameContent), + field_mapping.DefaultFieldKV(field_mapping.FieldName(extField)), + }, + }, + } + + bulks, err := i.vectorQueryItems(ctx, docs, &indexer.Options{ + Embedding: &mockEmbedding{size: []int{1, 2}, mockVector: []float64{2.1}}, + }) + convey.So(err, convey.ShouldBeNil) + convey.So(len(bulks), convey.ShouldEqual, 2) + exp := []string{ + `{"eino_doc_content":"asd","extra_field":"ext_1","vector_eino_doc_content":[2.3,4.4],"vector_extra_field":[2.1]}`, + `{"eino_doc_content":"qwe","extra_field":"ext_2","vector_eino_doc_content":[2.1],"vector_extra_field":[2.1]}`, + } + + for idx, item := range bulks { + convey.So(item.Index, convey.ShouldEqual, i.config.Index) + b, err := io.ReadAll(item.Body) + convey.So(err, convey.ShouldBeNil) + convey.So(string(b), convey.ShouldEqual, exp[idx]) + } + }) + }) +} + +type mockEmbedding struct { + call int + size []int + mockVector []float64 +} + +func (m *mockEmbedding) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { + if m.call >= len(m.size) { + return nil, fmt.Errorf("call limit error") + } + + resp := make([][]float64, m.size[m.call]) + m.call++ + for i := range resp { + resp[i] = m.mockVector + } + + return resp, nil +} diff --git a/components/indexer/es8/internal/consts.go 
// chunk splits slice into consecutive sub-slices holding at most size
// elements each; only the last chunk may be shorter. It returns nil when
// size is not positive or when slice is empty.
//
// Every returned chunk has its capacity clipped to its length, so appending
// to a chunk never overwrites a neighbouring chunk or the spare capacity of
// the caller's backing array.
func chunk[T any](slice []T, size int) [][]T {
	if size <= 0 || len(slice) == 0 {
		return nil
	}

	// Number of chunks is known up front; computed without len+size-1 so a
	// very large size cannot overflow.
	count := len(slice) / size
	if len(slice)%size != 0 {
		count++
	}

	chunks := make([][]T, 0, count)
	for size < len(slice) {
		slice, chunks = slice[size:], append(chunks, slice[0:size:size])
	}

	// Clip the tail as well, consistent with the chunks produced in the loop.
	return append(chunks, slice[:len(slice):len(slice)])
}

// iter applies fn to every element of src and returns the results in the
// same order. The result always has len(src) elements.
func iter[T, D any](src []T, fn func(T) D) []D {
	resp := make([]D, len(src))
	for i := range src {
		resp[i] = fn(src[i])
	}

	return resp
}

// iterWithErr applies fn to every element of src in order and returns the
// collected results. It stops at the first error and returns (nil, err), so
// no partial result is exposed.
func iterWithErr[T, D any](src []T, fn func(T) (D, error)) ([]D, error) {
	resp := make([]D, 0, len(src))
	for i := range src {
		d, err := fn(src[i])
		if err != nil {
			return nil, err
		}

		resp = append(resp, d)
	}

	return resp, nil
}
0000000..80438fa --- /dev/null +++ b/components/retriever/es8/field_mapping/consts.go @@ -0,0 +1,3 @@ +package field_mapping + +const DocFieldNameContent = "eino_doc_content" diff --git a/components/retriever/es8/field_mapping/mapping.go b/components/retriever/es8/field_mapping/mapping.go new file mode 100644 index 0000000..6db7b02 --- /dev/null +++ b/components/retriever/es8/field_mapping/mapping.go @@ -0,0 +1,42 @@ +package field_mapping + +import ( + "fmt" + + "code.byted.org/flow/eino-ext/components/retriever/es8/internal" + "code.byted.org/flow/eino/schema" +) + +// GetDefaultVectorFieldKeyContent get default es key for Document.Content +func GetDefaultVectorFieldKeyContent() FieldName { + return defaultVectorFieldKeyContent +} + +// GetDefaultVectorFieldKey generate default vector field name from its field name +func GetDefaultVectorFieldKey(fieldName string) FieldName { + return FieldName(fmt.Sprintf("vector_%s", fieldName)) +} + +// GetExtraDataFields get data fields from *schema.Document +func GetExtraDataFields(doc *schema.Document) (fields map[string]interface{}, ok bool) { + if doc == nil || doc.MetaData == nil { + return nil, false + } + + fields, ok = doc.MetaData[internal.DocExtraKeyEsFields].(map[string]interface{}) + + return fields, ok +} + +type FieldKV struct { + // FieldNameVector vector field name (if needed) + FieldNameVector FieldName `json:"field_name_vector,omitempty"` + // FieldName field name + FieldName FieldName `json:"field_name,omitempty"` + // Value original value + Value string `json:"value,omitempty"` +} + +type FieldName string + +var defaultVectorFieldKeyContent = GetDefaultVectorFieldKey(DocFieldNameContent) diff --git a/components/retriever/es8/go.mod b/components/retriever/es8/go.mod new file mode 100644 index 0000000..77fbaca --- /dev/null +++ b/components/retriever/es8/go.mod @@ -0,0 +1,46 @@ +module code.byted.org/flow/eino-ext/components/retriever/es8 + +go 1.22 + +require ( + code.byted.org/flow/eino v0.2.5 + 
github.com/bytedance/mockey v1.2.13 + github.com/elastic/go-elasticsearch/v8 v8.16.0 +) + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/elastic/elastic-transport-go/v8 v8.6.0 // indirect + github.com/getkin/kin-openapi v0.118.0 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-openapi/jsonpointer v0.19.5 // indirect + github.com/go-openapi/swag v0.19.5 // indirect + github.com/goph/emperror v0.17.2 // indirect + github.com/gopherjs/gopherjs v1.17.2 // indirect + github.com/invopop/yaml v0.1.0 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect + github.com/nikolalohinski/gonja v1.5.3 // indirect + github.com/pelletier/go-toml/v2 v2.0.9 // indirect + github.com/perimeterx/marshmallow v1.1.4 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect + github.com/smarty/assertions v1.15.0 // indirect + github.com/smartystreets/goconvey v1.8.1 // indirect + github.com/yargevad/filepathx v1.0.0 // indirect + go.opentelemetry.io/otel v1.28.0 // indirect + go.opentelemetry.io/otel/metric v1.28.0 // indirect + go.opentelemetry.io/otel/trace v1.28.0 // indirect + golang.org/x/arch v0.11.0 // indirect + golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect + golang.org/x/sys v0.26.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/components/retriever/es8/go.sum b/components/retriever/es8/go.sum new file mode 100644 index 
0000000..5288569 --- /dev/null +++ b/components/retriever/es8/go.sum @@ -0,0 +1,164 @@ +code.byted.org/flow/eino v0.2.5 h1:uPXgTMSfGZZvKhY4c44vt+CPe0FlIQ7/tfFZ/AfF8nI= +code.byted.org/flow/eino v0.2.5/go.mod h1:+o4CnsT/qFrbqhRMBbi70qS7mseWv1md/SD0Jo1kKEA= +github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= +github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= +github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= +github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/mockey v1.2.13 h1:jokWZAm/pUEbD939Rhznz615MKUCZNuvCFQlJ2+ntoo= +github.com/bytedance/mockey v1.2.13/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= +github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/elastic/elastic-transport-go/v8 v8.6.0 h1:Y2S/FBjx1LlCv5m6pWAF2kDJAHoSjSRSJCApolgfthA= +github.com/elastic/elastic-transport-go/v8 v8.6.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk= +github.com/elastic/go-elasticsearch/v8 v8.16.0 h1:f7bR+iBz8GTAVhwyFO3hm4ixsz2eMaEy0QroYnXV3jE= +github.com/elastic/go-elasticsearch/v8 v8.16.0/go.mod h1:lGMlgKIbYoRvay3xWBeKahAiJOgmFDsjZC39nmO3H64= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= 
+github.com/getkin/kin-openapi v0.118.0 h1:z43njxPmJ7TaPpMSCQb7PN0dEYno4tyBPQcrFdHoLuM= +github.com/getkin/kin-openapi v0.118.0/go.mod h1:l5e9PaFUo9fyLJCPGQeXI2ML8c3P8BHOEV2VaAVf/pc= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= +github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= 
+github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/invopop/yaml v0.1.0 h1:YW3WGUoJEXYfzWBjn00zIlrw7brGVD0fUKRYDPAPhrc= +github.com/invopop/yaml v0.1.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod 
h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= +github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.5.0/go.mod 
h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= +github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/perimeterx/marshmallow v1.1.4 h1:pZLDH9RjlLGGorbXhcaQLhfuV0pFMNfPO55FuFkxqLw= +github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx 
v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= +github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= +github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= +github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= +github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= +github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= +github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= +go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo= +go.opentelemetry.io/otel v1.28.0/go.mod h1:q68ijF8Fc8CnMHKyzqL6akLO46ePnjkgfIMIjUIX9z4= +go.opentelemetry.io/otel/metric v1.28.0 h1:f0HGvSl1KRAU1DLgLGFjrwVyismPlnuU6JD6bOeuA5Q= +go.opentelemetry.io/otel/metric v1.28.0/go.mod 
h1:Fb1eVBFZmLVTMb6PPohq3TO9IIhUisDsbJoL/+uQW4s= +go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8= +go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E= +go.opentelemetry.io/otel/trace v1.28.0 h1:GhQ9cUuQGmNDd5BTCP2dAvv75RdMxEfTmYejp+lkx9g= +go.opentelemetry.io/otel/trace v1.28.0/go.mod h1:jPyXzNPg6da9+38HEwElrQiHlVMTnVfM3/yv2OlIHaI= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= +golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/term 
v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/components/retriever/es8/internal/consts.go b/components/retriever/es8/internal/consts.go new file mode 100644 index 0000000..27aa995 --- /dev/null +++ b/components/retriever/es8/internal/consts.go @@ -0,0 +1,6 @@ +package internal + +const ( + DocExtraKeyEsFields = "_es_fields" // *schema.Document.MetaData key of es fields except content + DslFilterField = "_dsl_filter_functions" +) diff --git a/components/retriever/es8/retriever.go b/components/retriever/es8/retriever.go new file mode 100644 index 0000000..46fe8f0 --- /dev/null +++ b/components/retriever/es8/retriever.go @@ -0,0 +1,150 @@ +package es8 + +import ( + "context" + "encoding/json" + "fmt" + + 
"github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" + + "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" + "code.byted.org/flow/eino-ext/components/retriever/es8/internal" + "code.byted.org/flow/eino-ext/components/retriever/es8/search_mode" + "code.byted.org/flow/eino/callbacks" + "code.byted.org/flow/eino/components/embedding" + "code.byted.org/flow/eino/components/retriever" + "code.byted.org/flow/eino/schema" +) + +type RetrieverConfig struct { + ESConfig elasticsearch.Config `json:"es_config"` + + Index string `json:"index"` + TopK int `json:"top_k"` + ScoreThreshold *float64 `json:"score_threshold"` + + // SearchMode retrieve strategy, see prepared impls in search_mode package: + // use search_mode.SearchModeExactMatch with string query + // use search_mode.SearchModeApproximate with search_mode.ApproximateQuery + // use search_mode.SearchModeDenseVectorSimilarity with search_mode.DenseVectorSimilarityQuery + // use search_mode.SearchModeSparseVectorTextExpansion with search_mode.SparseVectorTextExpansionQuery + // use search_mode.SearchModeRawStringRequest with json search request + SearchMode search_mode.SearchMode `json:"search_mode"` + // Embedding vectorization method, must provide when SearchMode needed + Embedding embedding.Embedder +} + +type Retriever struct { + client *elasticsearch.TypedClient + config *RetrieverConfig +} + +func NewRetriever(_ context.Context, conf *RetrieverConfig) (*Retriever, error) { + if conf.SearchMode == nil { + return nil, fmt.Errorf("[NewRetriever] search mode not provided") + } + + client, err := elasticsearch.NewTypedClient(conf.ESConfig) + if err != nil { + return nil, fmt.Errorf("[NewRetriever] new es client failed, %w", err) + } + + return &Retriever{ + client: client, + config: conf, + }, nil +} + +func (r *Retriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (docs []*schema.Document, err error) { + defer func() { + 
if err != nil { + callbacks.OnError(ctx, err) + } + }() + + options := retriever.GetCommonOptions(&retriever.Options{ + Index: &r.config.Index, + TopK: &r.config.TopK, + ScoreThreshold: r.config.ScoreThreshold, + Embedding: r.config.Embedding, + }, opts...) + + ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{ + Query: query, + TopK: *options.TopK, + ScoreThreshold: options.ScoreThreshold, + }) + + req, err := r.config.SearchMode.BuildRequest(ctx, query, options) + if err != nil { + return nil, err + } + + resp, err := r.client.Search(). + Index(r.config.Index). + Request(req). + Do(ctx) + if err != nil { + return nil, err + } + + docs, err = r.parseSearchResult(resp) + if err != nil { + return nil, err + } + + callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs}) + + return docs, nil +} + +func (r *Retriever) parseSearchResult(resp *search.Response) (docs []*schema.Document, err error) { + docs = make([]*schema.Document, 0, len(resp.Hits.Hits)) + + for _, hit := range resp.Hits.Hits { + var raw map[string]any + if err = json.Unmarshal(hit.Source_, &raw); err != nil { + return nil, fmt.Errorf("[parseSearchResult] unexpected hit source type, source=%v", string(hit.Source_)) + } + + var id string + if hit.Id_ != nil { + id = *hit.Id_ + } + + content, ok := raw[field_mapping.DocFieldNameContent].(string) + if !ok { + return nil, fmt.Errorf("[parseSearchResult] content type not string, raw=%v", raw) + } + + expMap := make(map[string]any, len(raw)-1) + for k, v := range raw { + if k != internal.DocExtraKeyEsFields { + expMap[k] = v + } + } + + doc := &schema.Document{ + ID: id, + Content: content, + MetaData: map[string]any{internal.DocExtraKeyEsFields: expMap}, + } + + if hit.Score_ != nil { + doc.WithScore(float64(*hit.Score_)) + } + + docs = append(docs, doc) + } + + return docs, nil +} + +func (r *Retriever) GetType() string { + return typ +} + +func (r *Retriever) IsCallbacksEnabled() bool { + return true +} diff --git 
a/components/retriever/es8/search_mode/approximate.go b/components/retriever/es8/search_mode/approximate.go new file mode 100644 index 0000000..65c7882 --- /dev/null +++ b/components/retriever/es8/search_mode/approximate.go @@ -0,0 +1,140 @@ +package search_mode + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + + "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" + "code.byted.org/flow/eino/components/retriever" +) + +// SearchModeApproximate retrieve with multiple approximate strategy (filter+knn+rrf) +// knn: https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html +// rrf: https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html +func SearchModeApproximate(config *ApproximateConfig) SearchMode { + return &approximate{config} +} + +type ApproximateConfig struct { + // Hybrid if true, add filters and rff to knn query + Hybrid bool + // Rrf is a method for combining multiple result sets, is used to + // even the score from the knn query and text query + Rrf bool + // RrfRankConstant determines how much influence documents in + // individual result sets per query have over the final ranked result set + RrfRankConstant *int64 + // RrfWindowSize determines the size of the individual result sets per query + RrfWindowSize *int64 +} + +type ApproximateQuery struct { + // FieldKV es field info, QueryVectorBuilderModelID will be used if embedding not provided in config, + // and Embedding will be used if QueryVectorBuilderModelID is nil + FieldKV field_mapping.FieldKV `json:"field_kv"` + // QueryVectorBuilderModelID the query vector builder model id + // see: https://www.elastic.co/guide/en/machine-learning/8.16/ml-nlp-text-emb-vector-search-example.html + QueryVectorBuilderModelID *string `json:"query_vector_builder_model_id,omitempty"` + // Boost Floating point number used to decrease or 
increase the relevance scores of the query. + // Boost values are relative to the default value of 1.0. + // A boost value between 0 and 1.0 decreases the relevance score. + // A value greater than 1.0 increases the relevance score. + Boost *float32 `json:"boost,omitempty"` + // Filters for the kNN search query + Filters []types.Query `json:"filters,omitempty"` + // K The final number of nearest neighbors to return as top hits + K *int `json:"k,omitempty"` + // NumCandidates The number of nearest neighbor candidates to consider per shard + NumCandidates *int `json:"num_candidates,omitempty"` + // Similarity The minimum similarity for a vector to be considered a match + Similarity *float32 `json:"similarity,omitempty"` +} + +// ToRetrieverQuery convert approximate query to string query +func (a *ApproximateQuery) ToRetrieverQuery() (string, error) { + b, err := json.Marshal(a) + if err != nil { + return "", fmt.Errorf("[ToRetrieverQuery] convert query failed, %w", err) + } + + return string(b), nil +} + +type approximate struct { + config *ApproximateConfig +} + +func (a *approximate) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) { + var appReq ApproximateQuery + if err := json.Unmarshal([]byte(query), &appReq); err != nil { + return nil, fmt.Errorf("[BuildRequest][SearchModeApproximate] parse query failed, %w", err) + } + + knn := types.KnnSearch{ + Boost: appReq.Boost, + Field: string(appReq.FieldKV.FieldNameVector), + Filter: appReq.Filters, + K: appReq.K, + NumCandidates: appReq.NumCandidates, + QueryVector: nil, + QueryVectorBuilder: nil, + Similarity: appReq.Similarity, + } + + if appReq.QueryVectorBuilderModelID != nil { + knn.QueryVectorBuilder = &types.QueryVectorBuilder{TextEmbedding: &types.TextEmbedding{ + ModelId: *appReq.QueryVectorBuilderModelID, + ModelText: appReq.FieldKV.Value, + }} + } else { + emb := options.Embedding + if emb == nil { + return nil, 
fmt.Errorf("[BuildRequest][SearchModeApproximate] embedding not provided") + } + + vector, err := emb.EmbedStrings(makeEmbeddingCtx(ctx, emb), []string{appReq.FieldKV.Value}) + if err != nil { + return nil, fmt.Errorf("[BuildRequest][SearchModeApproximate] embedding failed, %w", err) + } + + if len(vector) != 1 { + return nil, fmt.Errorf("[BuildRequest][SearchModeApproximate] vector len error, expected=1, got=%d", len(vector)) + } + + knn.QueryVector = f64To32(vector[0]) + } + + req := &search.Request{Knn: []types.KnnSearch{knn}, Size: options.TopK} + + if a.config.Hybrid { + req.Query = &types.Query{ + Bool: &types.BoolQuery{ + Filter: appReq.Filters, + Must: []types.Query{ + { + Match: map[string]types.MatchQuery{ + string(appReq.FieldKV.FieldName): {Query: appReq.FieldKV.Value}, + }, + }, + }, + }, + } + + if a.config.Rrf { + req.Rank = &types.RankContainer{Rrf: &types.RrfRank{ + RankConstant: a.config.RrfRankConstant, + RankWindowSize: a.config.RrfWindowSize, + }} + } + } + + if options.ScoreThreshold != nil { + req.MinScore = (*types.Float64)(of(*options.ScoreThreshold)) + } + + return req, nil +} diff --git a/components/retriever/es8/search_mode/approximate_test.go b/components/retriever/es8/search_mode/approximate_test.go new file mode 100644 index 0000000..ffae089 --- /dev/null +++ b/components/retriever/es8/search_mode/approximate_test.go @@ -0,0 +1,148 @@ +package search_mode + +import ( + "context" + "encoding/json" + "testing" + + . 
"github.com/bytedance/mockey" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/smartystreets/goconvey/convey" + + "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" + "code.byted.org/flow/eino/components/embedding" + "code.byted.org/flow/eino/components/retriever" +) + +func TestSearchModeApproximate(t *testing.T) { + PatchConvey("test SearchModeApproximate", t, func() { + PatchConvey("test ToRetrieverQuery", func() { + aq := &ApproximateQuery{ + FieldKV: field_mapping.FieldKV{ + FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), + FieldName: field_mapping.DocFieldNameContent, + Value: "content", + }, + Filters: []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + }, + Boost: of(float32(1.0)), + K: of(10), + NumCandidates: of(100), + Similarity: of(float32(0.5)), + } + + sq, err := aq.ToRetrieverQuery() + convey.So(err, convey.ShouldBeNil) + convey.So(sq, convey.ShouldEqual, `{"field_kv":{"field_name_vector":"vector_eino_doc_content","field_name":"eino_doc_content","value":"content"},"boost":1,"filters":[{"match":{"label":{"query":"good"}}}],"k":10,"num_candidates":100,"similarity":0.5}`) + }) + + PatchConvey("test BuildRequest", func() { + ctx := context.Background() + + PatchConvey("test QueryVectorBuilderModelID", func() { + a := &approximate{config: &ApproximateConfig{}} + aq := &ApproximateQuery{ + FieldKV: field_mapping.FieldKV{ + FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), + FieldName: field_mapping.DocFieldNameContent, + Value: "content", + }, + QueryVectorBuilderModelID: of("mock_model"), + Filters: []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + }, + Boost: of(float32(1.0)), + K: of(10), + NumCandidates: of(100), + Similarity: of(float32(0.5)), + } + + sq, err := aq.ToRetrieverQuery() + convey.So(err, convey.ShouldBeNil) + + req, err := a.BuildRequest(ctx, sq, &retriever.Options{Embedding: nil}) + convey.So(err, 
convey.ShouldBeNil) + b, err := json.Marshal(req) + convey.So(err, convey.ShouldBeNil) + convey.So(string(b), convey.ShouldEqual, `{"knn":[{"boost":1,"field":"vector_eino_doc_content","filter":[{"match":{"label":{"query":"good"}}}],"k":10,"num_candidates":100,"query_vector_builder":{"text_embedding":{"model_id":"mock_model","model_text":"content"}},"similarity":0.5}]}`) + }) + + PatchConvey("test embedding", func() { + a := &approximate{config: &ApproximateConfig{}} + aq := &ApproximateQuery{ + FieldKV: field_mapping.FieldKV{ + FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), + FieldName: field_mapping.DocFieldNameContent, + Value: "content", + }, + Filters: []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + }, + Boost: of(float32(1.0)), + K: of(10), + NumCandidates: of(100), + Similarity: of(float32(0.5)), + } + + sq, err := aq.ToRetrieverQuery() + convey.So(err, convey.ShouldBeNil) + req, err := a.BuildRequest(ctx, sq, &retriever.Options{Embedding: &mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}}) + convey.So(err, convey.ShouldBeNil) + b, err := json.Marshal(req) + convey.So(err, convey.ShouldBeNil) + convey.So(string(b), convey.ShouldEqual, `{"knn":[{"boost":1,"field":"vector_eino_doc_content","filter":[{"match":{"label":{"query":"good"}}}],"k":10,"num_candidates":100,"query_vector":[1.1,1.2],"similarity":0.5}]}`) + }) + + PatchConvey("test hybrid with rrf", func() { + a := &approximate{config: &ApproximateConfig{ + Hybrid: true, + Rrf: true, + RrfRankConstant: of(int64(10)), + RrfWindowSize: of(int64(5)), + }} + + aq := &ApproximateQuery{ + FieldKV: field_mapping.FieldKV{ + FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), + FieldName: field_mapping.DocFieldNameContent, + Value: "content", + }, + Filters: []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + }, + Boost: of(float32(1.0)), + K: of(10), + NumCandidates: of(100), + Similarity: of(float32(0.5)), + 
} + + sq, err := aq.ToRetrieverQuery() + convey.So(err, convey.ShouldBeNil) + req, err := a.BuildRequest(ctx, sq, &retriever.Options{ + Embedding: &mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}, + TopK: of(10), + ScoreThreshold: of(1.1), + }) + convey.So(err, convey.ShouldBeNil) + b, err := json.Marshal(req) + convey.So(err, convey.ShouldBeNil) + convey.So(string(b), convey.ShouldEqual, `{"knn":[{"boost":1,"field":"vector_eino_doc_content","filter":[{"match":{"label":{"query":"good"}}}],"k":10,"num_candidates":100,"query_vector":[1.1,1.2],"similarity":0.5}],"min_score":1.1,"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}],"must":[{"match":{"eino_doc_content":{"query":"content"}}}]}},"rank":{"rrf":{"rank_constant":10,"rank_window_size":5}},"size":10}`) + }) + }) + }) +} + +type mockEmbedding struct { + size int + mockVector []float64 +} + +func (m mockEmbedding) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { + resp := make([][]float64, m.size) + for i := range resp { + resp[i] = m.mockVector + } + + return resp, nil +} diff --git a/components/retriever/es8/search_mode/dense_vector_similarity.go b/components/retriever/es8/search_mode/dense_vector_similarity.go new file mode 100644 index 0000000..f78b456 --- /dev/null +++ b/components/retriever/es8/search_mode/dense_vector_similarity.go @@ -0,0 +1,109 @@ +package search_mode + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + + "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" + "code.byted.org/flow/eino/components/retriever" +) + +// SearchModeDenseVectorSimilarity calculate embedding similarity between dense_vector field and query +// see: https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html#vector-functions +func SearchModeDenseVectorSimilarity(typ 
DenseVectorSimilarityType) SearchMode { + return &denseVectorSimilarity{script: denseVectorScriptMap[typ]} +} + +type DenseVectorSimilarityQuery struct { + FieldKV field_mapping.FieldKV `json:"field_kv"` + Filters []types.Query `json:"filters,omitempty"` +} + +// ToRetrieverQuery convert approximate query to string query +func (d *DenseVectorSimilarityQuery) ToRetrieverQuery() (string, error) { + b, err := json.Marshal(d) + if err != nil { + return "", fmt.Errorf("[ToRetrieverQuery] convert query failed, %w", err) + } + + return string(b), nil +} + +type denseVectorSimilarity struct { + script string +} + +func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) { + var dq DenseVectorSimilarityQuery + if err := json.Unmarshal([]byte(query), &dq); err != nil { + return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] parse query failed, %w", err) + } + + emb := options.Embedding + if emb == nil { + return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] embedding not provided") + } + + vector, err := emb.EmbedStrings(makeEmbeddingCtx(ctx, emb), []string{dq.FieldKV.Value}) + if err != nil { + return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] embedding failed, %w", err) + } + + if len(vector) != 1 { + return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] vector size invalid, expect=1, got=%d", len(vector)) + } + + vb, err := json.Marshal(vector[0]) + if err != nil { + return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] marshal vector to bytes failed, %w", err) + } + + q := &types.Query{ + ScriptScore: &types.ScriptScoreQuery{ + Script: types.Script{ + Source: of(fmt.Sprintf(d.script, dq.FieldKV.FieldNameVector)), + Params: map[string]json.RawMessage{"embedding": vb}, + }, + }, + } + + if len(dq.Filters) > 0 { + q.ScriptScore.Query = &types.Query{ + Bool: &types.BoolQuery{Filter: dq.Filters}, + } + } 
else { + q.ScriptScore.Query = &types.Query{ + MatchAll: &types.MatchAllQuery{}, + } + } + + req := &search.Request{Query: q, Size: options.TopK} + if options.ScoreThreshold != nil { + req.MinScore = (*types.Float64)(of(*options.ScoreThreshold)) + } + + return req, nil +} + +type DenseVectorSimilarityType string + +const ( + DenseVectorSimilarityTypeCosineSimilarity DenseVectorSimilarityType = "cosineSimilarity" + DenseVectorSimilarityTypeDotProduct DenseVectorSimilarityType = "dotProduct" + DenseVectorSimilarityTypeL1Norm DenseVectorSimilarityType = "l1norm" + DenseVectorSimilarityTypeL2Norm DenseVectorSimilarityType = "l2norm" +) + +var denseVectorScriptMap = map[DenseVectorSimilarityType]string{ + DenseVectorSimilarityTypeCosineSimilarity: `cosineSimilarity(params.embedding, '%s') + 1.0`, + DenseVectorSimilarityTypeDotProduct: `"" + double value = dotProduct(params.embedding, '%s'); + return sigmoid(1, Math.E, -value); + ""`, + DenseVectorSimilarityTypeL1Norm: `1 / (1 + l1norm(params.embedding, '%s'))`, + DenseVectorSimilarityTypeL2Norm: `1 / (1 + l2norm(params.embedding, '%s'))`, +} diff --git a/components/retriever/es8/search_mode/dense_vector_similarity_test.go b/components/retriever/es8/search_mode/dense_vector_similarity_test.go new file mode 100644 index 0000000..4742248 --- /dev/null +++ b/components/retriever/es8/search_mode/dense_vector_similarity_test.go @@ -0,0 +1,87 @@ +package search_mode + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + . 
"github.com/bytedance/mockey" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/smartystreets/goconvey/convey" + + "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" + "code.byted.org/flow/eino/components/retriever" +) + +func TestSearchModeDenseVectorSimilarity(t *testing.T) { + PatchConvey("test SearchModeDenseVectorSimilarity", t, func() { + PatchConvey("test ToRetrieverQuery", func() { + dq := &DenseVectorSimilarityQuery{ + FieldKV: field_mapping.FieldKV{ + FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), + FieldName: field_mapping.DocFieldNameContent, + Value: "content", + }, + Filters: []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + }, + } + + sq, err := dq.ToRetrieverQuery() + convey.So(err, convey.ShouldBeNil) + convey.So(sq, convey.ShouldEqual, `{"field_kv":{"field_name_vector":"vector_eino_doc_content","field_name":"eino_doc_content","value":"content"},"filters":[{"match":{"label":{"query":"good"}}}]}`) + }) + + PatchConvey("test BuildRequest", func() { + ctx := context.Background() + d := &denseVectorSimilarity{script: denseVectorScriptMap[DenseVectorSimilarityTypeCosineSimilarity]} + dq := &DenseVectorSimilarityQuery{ + FieldKV: field_mapping.FieldKV{ + FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), + FieldName: field_mapping.DocFieldNameContent, + Value: "content", + }, + Filters: []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + }, + } + sq, _ := dq.ToRetrieverQuery() + + PatchConvey("test embedding not provided", func() { + req, err := d.BuildRequest(ctx, sq, &retriever.Options{Embedding: nil}) + convey.So(err, convey.ShouldBeError, "[BuildRequest][SearchModeDenseVectorSimilarity] embedding not provided") + convey.So(req, convey.ShouldBeNil) + }) + + PatchConvey("test vector size invalid", func() { + req, err := d.BuildRequest(ctx, sq, &retriever.Options{Embedding: mockEmbedding{size: 2, mockVector: 
[]float64{1.1, 1.2}}}) + convey.So(err, convey.ShouldBeError, "[BuildRequest][SearchModeDenseVectorSimilarity] vector size invalid, expect=1, got=2") + convey.So(req, convey.ShouldBeNil) + }) + + PatchConvey("test success", func() { + typ2Exp := map[DenseVectorSimilarityType]string{ + DenseVectorSimilarityTypeCosineSimilarity: `{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"cosineSimilarity(params.embedding, 'vector_eino_doc_content') + 1.0"}}},"size":10}`, + DenseVectorSimilarityTypeDotProduct: `{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"\"\"\n double value = dotProduct(params.embedding, 'vector_eino_doc_content');\n return sigmoid(1, Math.E, -value); \n \"\""}}},"size":10}`, + DenseVectorSimilarityTypeL1Norm: `{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"1 / (1 + l1norm(params.embedding, 'vector_eino_doc_content'))"}}},"size":10}`, + DenseVectorSimilarityTypeL2Norm: `{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"1 / (1 + l2norm(params.embedding, 'vector_eino_doc_content'))"}}},"size":10}`, + } + + for typ, exp := range typ2Exp { + nd := &denseVectorSimilarity{script: denseVectorScriptMap[typ]} + req, err := nd.BuildRequest(ctx, sq, &retriever.Options{ + Embedding: mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}, + TopK: of(10), + ScoreThreshold: of(1.1), + }) + convey.So(err, convey.ShouldBeNil) + b, err := json.Marshal(req) + convey.So(err, convey.ShouldBeNil) + fmt.Println(string(b)) + convey.So(string(b), convey.ShouldEqual, exp) + } + }) + }) + }) +} diff --git 
a/components/retriever/es8/search_mode/exact_match.go b/components/retriever/es8/search_mode/exact_match.go new file mode 100644 index 0000000..048b46f --- /dev/null +++ b/components/retriever/es8/search_mode/exact_match.go @@ -0,0 +1,32 @@ +package search_mode + +import ( + "context" + + "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + + "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" + "code.byted.org/flow/eino/components/retriever" +) + +func SearchModeExactMatch() SearchMode { + return &exactMatch{} +} + +type exactMatch struct{} + +func (e exactMatch) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) { + q := &types.Query{ + Match: map[string]types.MatchQuery{ + field_mapping.DocFieldNameContent: {Query: query}, + }, + } + + req := &search.Request{Query: q, Size: options.TopK} + if options.ScoreThreshold != nil { + req.MinScore = (*types.Float64)(of(*options.ScoreThreshold)) + } + + return req, nil +} diff --git a/components/retriever/es8/search_mode/interface.go b/components/retriever/es8/search_mode/interface.go new file mode 100644 index 0000000..34ad08f --- /dev/null +++ b/components/retriever/es8/search_mode/interface.go @@ -0,0 +1,13 @@ +package search_mode + +import ( + "context" + + "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" + + "code.byted.org/flow/eino/components/retriever" +) + +type SearchMode interface { // nolint: byted_s_interface_name + BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) +} diff --git a/components/retriever/es8/search_mode/raw_string.go b/components/retriever/es8/search_mode/raw_string.go new file mode 100644 index 0000000..7851d71 --- /dev/null +++ b/components/retriever/es8/search_mode/raw_string.go @@ -0,0 +1,24 @@ +package search_mode + +import ( + "context" + + 
"github.com/elastic/go-elasticsearch/v8/typedapi/core/search" + + "code.byted.org/flow/eino/components/retriever" +) + +func SearchModeRawStringRequest() SearchMode { + return &rawString{} +} + +type rawString struct{} + +func (r rawString) BuildRequest(_ context.Context, query string, _ *retriever.Options) (*search.Request, error) { + req, err := search.NewRequest().FromJSON(query) + if err != nil { + return nil, err + } + + return req, nil +} diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go new file mode 100644 index 0000000..8214442 --- /dev/null +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go @@ -0,0 +1,68 @@ +package search_mode + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + + "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" + "code.byted.org/flow/eino/components/retriever" +) + +// SearchModeSparseVectorTextExpansion convert the query text into a list of token-weight pairs, +// which are then used in a query against a sparse vector +// see: https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-text-expansion-query.html +func SearchModeSparseVectorTextExpansion(modelID string) SearchMode { + return &sparseVectorTextExpansion{modelID} +} + +type SparseVectorTextExpansionQuery struct { + FieldKV field_mapping.FieldKV `json:"field_kv"` + Filters []types.Query `json:"filters,omitempty"` +} + +// ToRetrieverQuery convert approximate query to string query +func (s *SparseVectorTextExpansionQuery) ToRetrieverQuery() (string, error) { + b, err := json.Marshal(s) + if err != nil { + return "", fmt.Errorf("[ToRetrieverQuery] convert query failed, %w", err) + } + + return string(b), nil +} + +type sparseVectorTextExpansion struct { + modelID string +} + +func (s 
sparseVectorTextExpansion) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) { + var sq SparseVectorTextExpansionQuery + if err := json.Unmarshal([]byte(query), &sq); err != nil { + return nil, fmt.Errorf("[BuildRequest][SearchModeSparseVectorTextExpansion] parse query failed, %w", err) + } + + name := fmt.Sprintf("%s.tokens", sq.FieldKV.FieldNameVector) + teq := types.TextExpansionQuery{ + ModelId: s.modelID, + ModelText: sq.FieldKV.Value, + } + + q := &types.Query{ + Bool: &types.BoolQuery{ + Must: []types.Query{ + {TextExpansion: map[string]types.TextExpansionQuery{name: teq}}, + }, + Filter: sq.Filters, + }, + } + + req := &search.Request{Query: q, Size: options.TopK} + if options.ScoreThreshold != nil { + req.MinScore = (*types.Float64)(of(*options.ScoreThreshold)) + } + + return req, nil +} diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go new file mode 100644 index 0000000..536cd43 --- /dev/null +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go @@ -0,0 +1,59 @@ +package search_mode + +import ( + "context" + "encoding/json" + "testing" + + . 
"github.com/bytedance/mockey" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/smartystreets/goconvey/convey" + + "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" + "code.byted.org/flow/eino/components/retriever" +) + +func TestSearchModeSparseVectorTextExpansion(t *testing.T) { + PatchConvey("test SearchModeSparseVectorTextExpansion", t, func() { + PatchConvey("test ToRetrieverQuery", func() { + sq := &SparseVectorTextExpansionQuery{ + FieldKV: field_mapping.FieldKV{ + FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), + FieldName: field_mapping.DocFieldNameContent, + Value: "content", + }, + Filters: []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + }, + } + + ssq, err := sq.ToRetrieverQuery() + convey.So(err, convey.ShouldBeNil) + convey.So(ssq, convey.ShouldEqual, `{"field_kv":{"field_name_vector":"vector_eino_doc_content","field_name":"eino_doc_content","value":"content"},"filters":[{"match":{"label":{"query":"good"}}}]}`) + + }) + + PatchConvey("test BuildRequest", func() { + ctx := context.Background() + s := SearchModeSparseVectorTextExpansion("mock_model_id") + sq := &SparseVectorTextExpansionQuery{ + FieldKV: field_mapping.FieldKV{ + FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), + FieldName: field_mapping.DocFieldNameContent, + Value: "content", + }, + Filters: []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + }, + } + + query, _ := sq.ToRetrieverQuery() + req, err := s.BuildRequest(ctx, query, &retriever.Options{TopK: of(10), ScoreThreshold: of(1.1)}) + convey.So(err, convey.ShouldBeNil) + convey.So(req, convey.ShouldNotBeNil) + b, err := json.Marshal(req) + convey.So(err, convey.ShouldBeNil) + convey.So(string(b), convey.ShouldEqual, 
`{"min_score":1.1,"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}],"must":[{"text_expansion":{"vector_eino_doc_content.tokens":{"model_id":"mock_model_id","model_text":"content"}}}]}},"size":10}`) + }) + }) +} diff --git a/components/retriever/es8/search_mode/utils.go b/components/retriever/es8/search_mode/utils.go new file mode 100644 index 0000000..3b502bc --- /dev/null +++ b/components/retriever/es8/search_mode/utils.go @@ -0,0 +1,36 @@ +package search_mode + +import ( + "context" + + "code.byted.org/flow/eino/callbacks" + "code.byted.org/flow/eino/components" + "code.byted.org/flow/eino/components/embedding" +) + +func makeEmbeddingCtx(ctx context.Context, emb embedding.Embedder) context.Context { + runInfo := &callbacks.RunInfo{ + Component: components.ComponentOfEmbedding, + } + + if embType, ok := components.GetType(emb); ok { + runInfo.Type = embType + } + + runInfo.Name = runInfo.Type + string(runInfo.Component) + + return callbacks.SwitchRunInfo(ctx, runInfo) +} + +func f64To32(f64 []float64) []float32 { + f32 := make([]float32, len(f64)) + for i, f := range f64 { + f32[i] = float32(f) + } + + return f32 +} + +func of[T any](v T) *T { + return &v +} From 84cf6691c01e92ec0bfa0085160ae434bc913ecd Mon Sep 17 00:00:00 2001 From: lipandeng Date: Sun, 12 Jan 2025 17:26:41 +0800 Subject: [PATCH 02/11] feat: adjust import path --- .../es8/field_mapping/field_mapping.go | 4 ++-- components/indexer/es8/go.mod | 12 +++++++++--- components/indexer/es8/go.sum | 19 +++++++++++++++++-- components/indexer/es8/indexer.go | 12 ++++++------ components/indexer/es8/indexer_test.go | 8 ++++---- components/indexer/es8/utils.go | 4 ++-- .../retriever/es8/field_mapping/mapping.go | 4 ++-- components/retriever/es8/go.mod | 12 +++++++++--- components/retriever/es8/go.sum | 19 +++++++++++++++++-- components/retriever/es8/retriever.go | 14 +++++++------- .../retriever/es8/search_mode/approximate.go | 4 ++-- .../es8/search_mode/approximate_test.go | 6 +++--- 
.../search_mode/dense_vector_similarity.go | 4 ++-- .../dense_vector_similarity_test.go | 4 ++-- .../retriever/es8/search_mode/exact_match.go | 4 ++-- .../retriever/es8/search_mode/interface.go | 2 +- .../retriever/es8/search_mode/raw_string.go | 2 +- .../sparse_vector_text_expansion.go | 4 ++-- .../sparse_vector_text_expansion_test.go | 4 ++-- components/retriever/es8/search_mode/utils.go | 6 +++--- 20 files changed, 95 insertions(+), 53 deletions(-) diff --git a/components/indexer/es8/field_mapping/field_mapping.go b/components/indexer/es8/field_mapping/field_mapping.go index f87aa47..d57f779 100644 --- a/components/indexer/es8/field_mapping/field_mapping.go +++ b/components/indexer/es8/field_mapping/field_mapping.go @@ -3,8 +3,8 @@ package field_mapping import ( "fmt" - "code.byted.org/flow/eino-ext/components/indexer/es8/internal" - "code.byted.org/flow/eino/schema" + "github.com/cloudwego/eino-ext/components/indexer/es8/internal" + "github.com/cloudwego/eino/schema" ) // SetExtraDataFields set data fields for es diff --git a/components/indexer/es8/go.mod b/components/indexer/es8/go.mod index 2cae4e2..a009e88 100644 --- a/components/indexer/es8/go.mod +++ b/components/indexer/es8/go.mod @@ -1,14 +1,19 @@ -module code.byted.org/flow/eino-ext/components/indexer/es8 +module github.com/cloudwego/eino-ext/components/indexer/es8 go 1.22 require ( - code.byted.org/flow/eino v0.2.5 github.com/bytedance/mockey v1.2.13 + github.com/cloudwego/eino v0.3.5 github.com/elastic/go-elasticsearch/v8 v8.16.0 + github.com/smartystreets/goconvey v1.8.1 ) require ( + github.com/bytedance/sonic v1.12.2 // indirect + github.com/bytedance/sonic/loader v0.2.0 // indirect + github.com/cloudwego/base64x v0.1.4 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/elastic/elastic-transport-go/v8 v8.6.0 // indirect github.com/getkin/kin-openapi v0.118.0 // indirect @@ -22,6 +27,7 @@ require ( github.com/josharian/intern v1.0.0 
// indirect github.com/json-iterator/go v1.1.12 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect @@ -33,7 +39,7 @@ require ( github.com/sirupsen/logrus v1.9.3 // indirect github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect github.com/smarty/assertions v1.15.0 // indirect - github.com/smartystreets/goconvey v1.8.1 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/yargevad/filepathx v1.0.0 // indirect go.opentelemetry.io/otel v1.28.0 // indirect go.opentelemetry.io/otel/metric v1.28.0 // indirect diff --git a/components/indexer/es8/go.sum b/components/indexer/es8/go.sum index 5288569..e701f60 100644 --- a/components/indexer/es8/go.sum +++ b/components/indexer/es8/go.sum @@ -1,5 +1,3 @@ -code.byted.org/flow/eino v0.2.5 h1:uPXgTMSfGZZvKhY4c44vt+CPe0FlIQ7/tfFZ/AfF8nI= -code.byted.org/flow/eino v0.2.5/go.mod h1:+o4CnsT/qFrbqhRMBbi70qS7mseWv1md/SD0Jo1kKEA= github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= @@ -7,7 +5,18 @@ github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqR github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= github.com/bytedance/mockey v1.2.13 h1:jokWZAm/pUEbD939Rhznz615MKUCZNuvCFQlJ2+ntoo= github.com/bytedance/mockey v1.2.13/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= +github.com/bytedance/sonic v1.12.2 h1:oaMFuRTpMHYLpCntGca65YWt5ny+wAceDERTkT2L9lg= +github.com/bytedance/sonic v1.12.2/go.mod 
h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM= +github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/eino v0.3.5 h1:9PkAOX/phFifrGXkfl4L9rdecxOQJBJY1FtZqF4bz3c= +github.com/cloudwego/eino v0.3.5/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -54,6 +63,9 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod 
h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -114,6 +126,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= @@ -162,3 +176,4 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/components/indexer/es8/indexer.go b/components/indexer/es8/indexer.go index b028d4b..5e1af3e 100644 --- a/components/indexer/es8/indexer.go +++ b/components/indexer/es8/indexer.go @@ -9,12 +9,12 @@ import ( "github.com/elastic/go-elasticsearch/v8" "github.com/elastic/go-elasticsearch/v8/esutil" - "code.byted.org/flow/eino-ext/components/indexer/es8/field_mapping" - "code.byted.org/flow/eino/callbacks" - "code.byted.org/flow/eino/components" - "code.byted.org/flow/eino/components/embedding" - "code.byted.org/flow/eino/components/indexer" - 
"code.byted.org/flow/eino/schema" + "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/embedding" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/schema" ) type IndexerConfig struct { diff --git a/components/indexer/es8/indexer_test.go b/components/indexer/es8/indexer_test.go index 2de06d9..daae53e 100644 --- a/components/indexer/es8/indexer_test.go +++ b/components/indexer/es8/indexer_test.go @@ -9,10 +9,10 @@ import ( . "github.com/bytedance/mockey" "github.com/smartystreets/goconvey/convey" - "code.byted.org/flow/eino-ext/components/indexer/es8/field_mapping" - "code.byted.org/flow/eino/components/embedding" - "code.byted.org/flow/eino/components/indexer" - "code.byted.org/flow/eino/schema" + "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" + "github.com/cloudwego/eino/components/embedding" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/schema" ) func TestVectorQueryItems(t *testing.T) { diff --git a/components/indexer/es8/utils.go b/components/indexer/es8/utils.go index 9bb61da..5365eb3 100644 --- a/components/indexer/es8/utils.go +++ b/components/indexer/es8/utils.go @@ -1,8 +1,8 @@ package es8 import ( - "code.byted.org/flow/eino-ext/components/indexer/es8/field_mapping" - "code.byted.org/flow/eino/schema" + "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" + "github.com/cloudwego/eino/schema" ) func GetType() string { diff --git a/components/retriever/es8/field_mapping/mapping.go b/components/retriever/es8/field_mapping/mapping.go index 6db7b02..2413666 100644 --- a/components/retriever/es8/field_mapping/mapping.go +++ b/components/retriever/es8/field_mapping/mapping.go @@ -3,8 +3,8 @@ package field_mapping import ( "fmt" - "code.byted.org/flow/eino-ext/components/retriever/es8/internal" - "code.byted.org/flow/eino/schema" 
+ "github.com/cloudwego/eino-ext/components/retriever/es8/internal" + "github.com/cloudwego/eino/schema" ) // GetDefaultVectorFieldKeyContent get default es key for Document.Content diff --git a/components/retriever/es8/go.mod b/components/retriever/es8/go.mod index 77fbaca..000a30f 100644 --- a/components/retriever/es8/go.mod +++ b/components/retriever/es8/go.mod @@ -1,14 +1,19 @@ -module code.byted.org/flow/eino-ext/components/retriever/es8 +module github.com/cloudwego/eino-ext/components/retriever/es8 go 1.22 require ( - code.byted.org/flow/eino v0.2.5 github.com/bytedance/mockey v1.2.13 + github.com/cloudwego/eino v0.3.5 github.com/elastic/go-elasticsearch/v8 v8.16.0 + github.com/smartystreets/goconvey v1.8.1 ) require ( + github.com/bytedance/sonic v1.12.2 // indirect + github.com/bytedance/sonic/loader v0.2.0 // indirect + github.com/cloudwego/base64x v0.1.4 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/elastic/elastic-transport-go/v8 v8.6.0 // indirect github.com/getkin/kin-openapi v0.118.0 // indirect @@ -22,6 +27,7 @@ require ( github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect @@ -33,7 +39,7 @@ require ( github.com/sirupsen/logrus v1.9.3 // indirect github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect github.com/smarty/assertions v1.15.0 // indirect - github.com/smartystreets/goconvey v1.8.1 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/yargevad/filepathx v1.0.0 // indirect go.opentelemetry.io/otel v1.28.0 // indirect go.opentelemetry.io/otel/metric v1.28.0 // indirect diff --git 
a/components/retriever/es8/go.sum b/components/retriever/es8/go.sum index 5288569..e701f60 100644 --- a/components/retriever/es8/go.sum +++ b/components/retriever/es8/go.sum @@ -1,5 +1,3 @@ -code.byted.org/flow/eino v0.2.5 h1:uPXgTMSfGZZvKhY4c44vt+CPe0FlIQ7/tfFZ/AfF8nI= -code.byted.org/flow/eino v0.2.5/go.mod h1:+o4CnsT/qFrbqhRMBbi70qS7mseWv1md/SD0Jo1kKEA= github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= @@ -7,7 +5,18 @@ github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqR github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= github.com/bytedance/mockey v1.2.13 h1:jokWZAm/pUEbD939Rhznz615MKUCZNuvCFQlJ2+ntoo= github.com/bytedance/mockey v1.2.13/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= +github.com/bytedance/sonic v1.12.2 h1:oaMFuRTpMHYLpCntGca65YWt5ny+wAceDERTkT2L9lg= +github.com/bytedance/sonic v1.12.2/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM= +github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/eino v0.3.5 h1:9PkAOX/phFifrGXkfl4L9rdecxOQJBJY1FtZqF4bz3c= +github.com/cloudwego/eino v0.3.5/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo= +github.com/cloudwego/iasm v0.2.0 
h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -54,6 +63,9 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -114,6 +126,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= 
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= @@ -162,3 +176,4 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/components/retriever/es8/retriever.go b/components/retriever/es8/retriever.go index 46fe8f0..a9de5fa 100644 --- a/components/retriever/es8/retriever.go +++ b/components/retriever/es8/retriever.go @@ -8,13 +8,13 @@ import ( "github.com/elastic/go-elasticsearch/v8" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" - "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" - "code.byted.org/flow/eino-ext/components/retriever/es8/internal" - "code.byted.org/flow/eino-ext/components/retriever/es8/search_mode" - "code.byted.org/flow/eino/callbacks" - "code.byted.org/flow/eino/components/embedding" - "code.byted.org/flow/eino/components/retriever" - "code.byted.org/flow/eino/schema" + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" + "github.com/cloudwego/eino-ext/components/retriever/es8/internal" + "github.com/cloudwego/eino-ext/components/retriever/es8/search_mode" + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components/embedding" + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" ) type RetrieverConfig struct { diff --git a/components/retriever/es8/search_mode/approximate.go b/components/retriever/es8/search_mode/approximate.go index 65c7882..6055470 100644 --- a/components/retriever/es8/search_mode/approximate.go +++ b/components/retriever/es8/search_mode/approximate.go @@ 
-8,8 +8,8 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" - "code.byted.org/flow/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" + "github.com/cloudwego/eino/components/retriever" ) // SearchModeApproximate retrieve with multiple approximate strategy (filter+knn+rrf) diff --git a/components/retriever/es8/search_mode/approximate_test.go b/components/retriever/es8/search_mode/approximate_test.go index ffae089..1709c11 100644 --- a/components/retriever/es8/search_mode/approximate_test.go +++ b/components/retriever/es8/search_mode/approximate_test.go @@ -9,9 +9,9 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/smartystreets/goconvey/convey" - "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" - "code.byted.org/flow/eino/components/embedding" - "code.byted.org/flow/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" + "github.com/cloudwego/eino/components/embedding" + "github.com/cloudwego/eino/components/retriever" ) func TestSearchModeApproximate(t *testing.T) { diff --git a/components/retriever/es8/search_mode/dense_vector_similarity.go b/components/retriever/es8/search_mode/dense_vector_similarity.go index f78b456..718ce64 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity.go @@ -8,8 +8,8 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" - "code.byted.org/flow/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" + "github.com/cloudwego/eino/components/retriever" ) // 
SearchModeDenseVectorSimilarity calculate embedding similarity between dense_vector field and query diff --git a/components/retriever/es8/search_mode/dense_vector_similarity_test.go b/components/retriever/es8/search_mode/dense_vector_similarity_test.go index 4742248..2be0af3 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity_test.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity_test.go @@ -10,8 +10,8 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/smartystreets/goconvey/convey" - "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" - "code.byted.org/flow/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" + "github.com/cloudwego/eino/components/retriever" ) func TestSearchModeDenseVectorSimilarity(t *testing.T) { diff --git a/components/retriever/es8/search_mode/exact_match.go b/components/retriever/es8/search_mode/exact_match.go index 048b46f..b589f20 100644 --- a/components/retriever/es8/search_mode/exact_match.go +++ b/components/retriever/es8/search_mode/exact_match.go @@ -6,8 +6,8 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" - "code.byted.org/flow/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" + "github.com/cloudwego/eino/components/retriever" ) func SearchModeExactMatch() SearchMode { diff --git a/components/retriever/es8/search_mode/interface.go b/components/retriever/es8/search_mode/interface.go index 34ad08f..d509b76 100644 --- a/components/retriever/es8/search_mode/interface.go +++ b/components/retriever/es8/search_mode/interface.go @@ -5,7 +5,7 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" - "code.byted.org/flow/eino/components/retriever" + 
"github.com/cloudwego/eino/components/retriever" ) type SearchMode interface { // nolint: byted_s_interface_name diff --git a/components/retriever/es8/search_mode/raw_string.go b/components/retriever/es8/search_mode/raw_string.go index 7851d71..9bf487c 100644 --- a/components/retriever/es8/search_mode/raw_string.go +++ b/components/retriever/es8/search_mode/raw_string.go @@ -5,7 +5,7 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" - "code.byted.org/flow/eino/components/retriever" + "github.com/cloudwego/eino/components/retriever" ) func SearchModeRawStringRequest() SearchMode { diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go index 8214442..74f40c1 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go @@ -8,8 +8,8 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" - "code.byted.org/flow/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" + "github.com/cloudwego/eino/components/retriever" ) // SearchModeSparseVectorTextExpansion convert the query text into a list of token-weight pairs, diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go index 536cd43..9b858e8 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go @@ -9,8 +9,8 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/smartystreets/goconvey/convey" - "code.byted.org/flow/eino-ext/components/retriever/es8/field_mapping" - 
"code.byted.org/flow/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" + "github.com/cloudwego/eino/components/retriever" ) func TestSearchModeSparseVectorTextExpansion(t *testing.T) { diff --git a/components/retriever/es8/search_mode/utils.go b/components/retriever/es8/search_mode/utils.go index 3b502bc..b673539 100644 --- a/components/retriever/es8/search_mode/utils.go +++ b/components/retriever/es8/search_mode/utils.go @@ -3,9 +3,9 @@ package search_mode import ( "context" - "code.byted.org/flow/eino/callbacks" - "code.byted.org/flow/eino/components" - "code.byted.org/flow/eino/components/embedding" + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/embedding" ) func makeEmbeddingCtx(ctx context.Context, emb embedding.Embedder) context.Context { From db79a95aa0258c5af8c8437da79ff3b87fe48aaa Mon Sep 17 00:00:00 2001 From: lipandeng Date: Sun, 12 Jan 2025 17:28:49 +0800 Subject: [PATCH 03/11] feat: add license header --- components/indexer/es8/consts.go | 16 ++++++++++++++++ components/indexer/es8/field_mapping/consts.go | 16 ++++++++++++++++ .../indexer/es8/field_mapping/field_mapping.go | 16 ++++++++++++++++ components/indexer/es8/indexer.go | 16 ++++++++++++++++ components/indexer/es8/indexer_test.go | 16 ++++++++++++++++ components/indexer/es8/internal/consts.go | 16 ++++++++++++++++ components/indexer/es8/utils.go | 16 ++++++++++++++++ components/retriever/es8/consts.go | 16 ++++++++++++++++ components/retriever/es8/field_mapping/consts.go | 16 ++++++++++++++++ .../retriever/es8/field_mapping/mapping.go | 16 ++++++++++++++++ components/retriever/es8/internal/consts.go | 16 ++++++++++++++++ components/retriever/es8/retriever.go | 16 ++++++++++++++++ .../retriever/es8/search_mode/approximate.go | 16 ++++++++++++++++ .../es8/search_mode/approximate_test.go | 16 ++++++++++++++++ .../es8/search_mode/dense_vector_similarity.go | 16 
++++++++++++++++ .../search_mode/dense_vector_similarity_test.go | 16 ++++++++++++++++ .../retriever/es8/search_mode/exact_match.go | 16 ++++++++++++++++ .../retriever/es8/search_mode/interface.go | 16 ++++++++++++++++ .../retriever/es8/search_mode/raw_string.go | 16 ++++++++++++++++ .../search_mode/sparse_vector_text_expansion.go | 16 ++++++++++++++++ .../sparse_vector_text_expansion_test.go | 16 ++++++++++++++++ components/retriever/es8/search_mode/utils.go | 16 ++++++++++++++++ 22 files changed, 352 insertions(+) diff --git a/components/indexer/es8/consts.go b/components/indexer/es8/consts.go index dc0521c..de8b5bb 100644 --- a/components/indexer/es8/consts.go +++ b/components/indexer/es8/consts.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package es8 const typ = "ElasticSearch8" diff --git a/components/indexer/es8/field_mapping/consts.go b/components/indexer/es8/field_mapping/consts.go index 80438fa..abe46ff 100644 --- a/components/indexer/es8/field_mapping/consts.go +++ b/components/indexer/es8/field_mapping/consts.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package field_mapping const DocFieldNameContent = "eino_doc_content" diff --git a/components/indexer/es8/field_mapping/field_mapping.go b/components/indexer/es8/field_mapping/field_mapping.go index d57f779..ebdbc08 100644 --- a/components/indexer/es8/field_mapping/field_mapping.go +++ b/components/indexer/es8/field_mapping/field_mapping.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package field_mapping import ( diff --git a/components/indexer/es8/indexer.go b/components/indexer/es8/indexer.go index 5e1af3e..611f913 100644 --- a/components/indexer/es8/indexer.go +++ b/components/indexer/es8/indexer.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package es8 import ( diff --git a/components/indexer/es8/indexer_test.go b/components/indexer/es8/indexer_test.go index daae53e..435d798 100644 --- a/components/indexer/es8/indexer_test.go +++ b/components/indexer/es8/indexer_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package es8 import ( diff --git a/components/indexer/es8/internal/consts.go b/components/indexer/es8/internal/consts.go index e298275..c515981 100644 --- a/components/indexer/es8/internal/consts.go +++ b/components/indexer/es8/internal/consts.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package internal const ( diff --git a/components/indexer/es8/utils.go b/components/indexer/es8/utils.go index 5365eb3..9ff0188 100644 --- a/components/indexer/es8/utils.go +++ b/components/indexer/es8/utils.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package es8 import ( diff --git a/components/retriever/es8/consts.go b/components/retriever/es8/consts.go index ec1211c..3f1ffcd 100644 --- a/components/retriever/es8/consts.go +++ b/components/retriever/es8/consts.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package es8 const typ = "ElasticSearch8" diff --git a/components/retriever/es8/field_mapping/consts.go b/components/retriever/es8/field_mapping/consts.go index 80438fa..abe46ff 100644 --- a/components/retriever/es8/field_mapping/consts.go +++ b/components/retriever/es8/field_mapping/consts.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package field_mapping const DocFieldNameContent = "eino_doc_content" diff --git a/components/retriever/es8/field_mapping/mapping.go b/components/retriever/es8/field_mapping/mapping.go index 2413666..6cc35e4 100644 --- a/components/retriever/es8/field_mapping/mapping.go +++ b/components/retriever/es8/field_mapping/mapping.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package field_mapping import ( diff --git a/components/retriever/es8/internal/consts.go b/components/retriever/es8/internal/consts.go index 27aa995..bf3b10d 100644 --- a/components/retriever/es8/internal/consts.go +++ b/components/retriever/es8/internal/consts.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package internal const ( diff --git a/components/retriever/es8/retriever.go b/components/retriever/es8/retriever.go index a9de5fa..535a9bd 100644 --- a/components/retriever/es8/retriever.go +++ b/components/retriever/es8/retriever.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package es8 import ( diff --git a/components/retriever/es8/search_mode/approximate.go b/components/retriever/es8/search_mode/approximate.go index 6055470..fdbc70a 100644 --- a/components/retriever/es8/search_mode/approximate.go +++ b/components/retriever/es8/search_mode/approximate.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package search_mode import ( diff --git a/components/retriever/es8/search_mode/approximate_test.go b/components/retriever/es8/search_mode/approximate_test.go index 1709c11..ba2b5f6 100644 --- a/components/retriever/es8/search_mode/approximate_test.go +++ b/components/retriever/es8/search_mode/approximate_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package search_mode import ( diff --git a/components/retriever/es8/search_mode/dense_vector_similarity.go b/components/retriever/es8/search_mode/dense_vector_similarity.go index 718ce64..c8f4a8a 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package search_mode import ( diff --git a/components/retriever/es8/search_mode/dense_vector_similarity_test.go b/components/retriever/es8/search_mode/dense_vector_similarity_test.go index 2be0af3..6a53e0b 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity_test.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package search_mode import ( diff --git a/components/retriever/es8/search_mode/exact_match.go b/components/retriever/es8/search_mode/exact_match.go index b589f20..0aff12a 100644 --- a/components/retriever/es8/search_mode/exact_match.go +++ b/components/retriever/es8/search_mode/exact_match.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package search_mode import ( diff --git a/components/retriever/es8/search_mode/interface.go b/components/retriever/es8/search_mode/interface.go index d509b76..ab4ef5a 100644 --- a/components/retriever/es8/search_mode/interface.go +++ b/components/retriever/es8/search_mode/interface.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package search_mode import ( diff --git a/components/retriever/es8/search_mode/raw_string.go b/components/retriever/es8/search_mode/raw_string.go index 9bf487c..da6c00f 100644 --- a/components/retriever/es8/search_mode/raw_string.go +++ b/components/retriever/es8/search_mode/raw_string.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package search_mode import ( diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go index 74f40c1..3db3076 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package search_mode import ( diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go index 9b858e8..6a58b41 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package search_mode import ( diff --git a/components/retriever/es8/search_mode/utils.go b/components/retriever/es8/search_mode/utils.go index b673539..97d6f21 100644 --- a/components/retriever/es8/search_mode/utils.go +++ b/components/retriever/es8/search_mode/utils.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package search_mode import ( From 1881930d0b12e1d8bf36b97b178b6365306d9cb2 Mon Sep 17 00:00:00 2001 From: lipandeng Date: Sun, 12 Jan 2025 17:43:03 +0800 Subject: [PATCH 04/11] feat: adjust eino's API --- components/indexer/es8/field_mapping/field_mapping.go | 3 ++- components/indexer/es8/indexer.go | 11 ++++++----- components/indexer/es8/indexer_test.go | 5 +++-- components/indexer/es8/utils.go | 3 ++- components/retriever/es8/retriever.go | 7 ++++--- components/retriever/es8/search_mode/approximate.go | 3 ++- .../retriever/es8/search_mode/approximate_test.go | 3 ++- .../es8/search_mode/dense_vector_similarity.go | 3 ++- .../es8/search_mode/dense_vector_similarity_test.go | 7 ++++--- components/retriever/es8/search_mode/exact_match.go | 3 ++- .../es8/search_mode/sparse_vector_text_expansion.go | 3 ++- .../search_mode/sparse_vector_text_expansion_test.go | 3 ++- components/retriever/es8/search_mode/utils.go | 2 +- 13 files changed, 34 insertions(+), 22 deletions(-) diff --git a/components/indexer/es8/field_mapping/field_mapping.go b/components/indexer/es8/field_mapping/field_mapping.go index ebdbc08..259e7bd 100644 --- a/components/indexer/es8/field_mapping/field_mapping.go +++ b/components/indexer/es8/field_mapping/field_mapping.go @@ -19,8 +19,9 @@ package field_mapping import ( "fmt" - "github.com/cloudwego/eino-ext/components/indexer/es8/internal" "github.com/cloudwego/eino/schema" + + "github.com/cloudwego/eino-ext/components/indexer/es8/internal" ) // SetExtraDataFields set data fields for es diff --git 
a/components/indexer/es8/indexer.go b/components/indexer/es8/indexer.go index 611f913..a826bfe 100644 --- a/components/indexer/es8/indexer.go +++ b/components/indexer/es8/indexer.go @@ -25,12 +25,13 @@ import ( "github.com/elastic/go-elasticsearch/v8" "github.com/elastic/go-elasticsearch/v8/esutil" - "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" + + "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" ) type IndexerConfig struct { @@ -161,8 +162,8 @@ func (i *Indexer) vectorQueryItems(ctx context.Context, docs []*schema.Document, return item, fmt.Errorf("[vectorQueryItems] field name not found or type incorrect, name=%s, doc=%v", kv.FieldName, doc) } - if kv.FieldName == field_mapping.DocFieldNameContent && len(doc.Vector()) > 0 { - mp[string(kv.FieldNameVector)] = doc.Vector() + if kv.FieldName == field_mapping.DocFieldNameContent && len(doc.DenseVector()) > 0 { + mp[string(kv.FieldNameVector)] = doc.DenseVector() } else { texts = append(texts, str) } @@ -184,7 +185,7 @@ func (i *Indexer) vectorQueryItems(ctx context.Context, docs []*schema.Document, vIdx := 0 for _, kv := range i.config.VectorFields { - if kv.FieldName == field_mapping.DocFieldNameContent && len(doc.Vector()) > 0 { + if kv.FieldName == field_mapping.DocFieldNameContent && len(doc.DenseVector()) > 0 { continue } @@ -224,7 +225,7 @@ func (i *Indexer) makeEmbeddingCtx(ctx context.Context, emb embedding.Embedder) runInfo.Name = runInfo.Type + string(runInfo.Component) - return callbacks.SwitchRunInfo(ctx, runInfo) + return callbacks.ReuseHandlers(ctx, runInfo) } func (i *Indexer) GetType() string { diff --git a/components/indexer/es8/indexer_test.go b/components/indexer/es8/indexer_test.go index 435d798..7f1356e 100644 --- 
a/components/indexer/es8/indexer_test.go +++ b/components/indexer/es8/indexer_test.go @@ -25,10 +25,11 @@ import ( . "github.com/bytedance/mockey" "github.com/smartystreets/goconvey/convey" - "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" + + "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" ) func TestVectorQueryItems(t *testing.T) { @@ -37,7 +38,7 @@ func TestVectorQueryItems(t *testing.T) { extField := "extra_field" d1 := &schema.Document{ID: "123", Content: "asd"} - d1.WithVector([]float64{2.3, 4.4}) + d1.WithDenseVector([]float64{2.3, 4.4}) field_mapping.SetExtraDataFields(d1, map[string]interface{}{extField: "ext_1"}) d2 := &schema.Document{ID: "456", Content: "qwe"} diff --git a/components/indexer/es8/utils.go b/components/indexer/es8/utils.go index 9ff0188..a5f8952 100644 --- a/components/indexer/es8/utils.go +++ b/components/indexer/es8/utils.go @@ -17,8 +17,9 @@ package es8 import ( - "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" "github.com/cloudwego/eino/schema" + + "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" ) func GetType() string { diff --git a/components/retriever/es8/retriever.go b/components/retriever/es8/retriever.go index 535a9bd..a44d0dc 100644 --- a/components/retriever/es8/retriever.go +++ b/components/retriever/es8/retriever.go @@ -24,13 +24,14 @@ import ( "github.com/elastic/go-elasticsearch/v8" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" - "github.com/cloudwego/eino-ext/components/retriever/es8/internal" - "github.com/cloudwego/eino-ext/components/retriever/es8/search_mode" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/retriever" 
"github.com/cloudwego/eino/schema" + + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" + "github.com/cloudwego/eino-ext/components/retriever/es8/internal" + "github.com/cloudwego/eino-ext/components/retriever/es8/search_mode" ) type RetrieverConfig struct { diff --git a/components/retriever/es8/search_mode/approximate.go b/components/retriever/es8/search_mode/approximate.go index fdbc70a..e497434 100644 --- a/components/retriever/es8/search_mode/approximate.go +++ b/components/retriever/es8/search_mode/approximate.go @@ -24,8 +24,9 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" "github.com/cloudwego/eino/components/retriever" + + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeApproximate retrieve with multiple approximate strategy (filter+knn+rrf) diff --git a/components/retriever/es8/search_mode/approximate_test.go b/components/retriever/es8/search_mode/approximate_test.go index ba2b5f6..05c6ea5 100644 --- a/components/retriever/es8/search_mode/approximate_test.go +++ b/components/retriever/es8/search_mode/approximate_test.go @@ -25,9 +25,10 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/smartystreets/goconvey/convey" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/retriever" + + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) func TestSearchModeApproximate(t *testing.T) { diff --git a/components/retriever/es8/search_mode/dense_vector_similarity.go b/components/retriever/es8/search_mode/dense_vector_similarity.go index c8f4a8a..460e764 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity.go @@ 
-24,8 +24,9 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" "github.com/cloudwego/eino/components/retriever" + + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeDenseVectorSimilarity calculate embedding similarity between dense_vector field and query diff --git a/components/retriever/es8/search_mode/dense_vector_similarity_test.go b/components/retriever/es8/search_mode/dense_vector_similarity_test.go index 6a53e0b..ea13b57 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity_test.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity_test.go @@ -26,8 +26,9 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/smartystreets/goconvey/convey" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" "github.com/cloudwego/eino/components/retriever" + + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) func TestSearchModeDenseVectorSimilarity(t *testing.T) { @@ -85,8 +86,8 @@ func TestSearchModeDenseVectorSimilarity(t *testing.T) { } for typ, exp := range typ2Exp { - nd := &denseVectorSimilarity{script: denseVectorScriptMap[typ]} - req, err := nd.BuildRequest(ctx, sq, &retriever.Options{ + similarity := &denseVectorSimilarity{script: denseVectorScriptMap[typ]} + req, err := similarity.BuildRequest(ctx, sq, &retriever.Options{ Embedding: mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}, TopK: of(10), ScoreThreshold: of(1.1), diff --git a/components/retriever/es8/search_mode/exact_match.go b/components/retriever/es8/search_mode/exact_match.go index 0aff12a..d8041f5 100644 --- a/components/retriever/es8/search_mode/exact_match.go +++ b/components/retriever/es8/search_mode/exact_match.go @@ -22,8 +22,9 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" 
"github.com/elastic/go-elasticsearch/v8/typedapi/types" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" "github.com/cloudwego/eino/components/retriever" + + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) func SearchModeExactMatch() SearchMode { diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go index 3db3076..ab08596 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go @@ -24,8 +24,9 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" "github.com/cloudwego/eino/components/retriever" + + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeSparseVectorTextExpansion convert the query text into a list of token-weight pairs, diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go index 6a58b41..eae9178 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go @@ -25,8 +25,9 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/smartystreets/goconvey/convey" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" "github.com/cloudwego/eino/components/retriever" + + "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) func TestSearchModeSparseVectorTextExpansion(t *testing.T) { diff --git a/components/retriever/es8/search_mode/utils.go b/components/retriever/es8/search_mode/utils.go index 97d6f21..7501ada 100644 --- a/components/retriever/es8/search_mode/utils.go 
+++ b/components/retriever/es8/search_mode/utils.go @@ -35,7 +35,7 @@ func makeEmbeddingCtx(ctx context.Context, emb embedding.Embedder) context.Conte runInfo.Name = runInfo.Type + string(runInfo.Component) - return callbacks.SwitchRunInfo(ctx, runInfo) + return callbacks.ReuseHandlers(ctx, runInfo) } func f64To32(f64 []float64) []float32 { From e6a7c990ae18306d98d3adaa1af504aefa67facb Mon Sep 17 00:00:00 2001 From: lipandeng Date: Sun, 12 Jan 2025 18:09:36 +0800 Subject: [PATCH 05/11] feat: redefine es8 retriever's SearchMode --- components/retriever/es8/retriever.go | 4 ++++ .../retriever/es8/search_mode/approximate.go | 13 +++++++++++-- .../es8/search_mode/approximate_test.go | 17 ++++++++++------- .../search_mode/dense_vector_similarity.go | 14 ++++++++++++-- .../dense_vector_similarity_test.go | 19 ++++++++++++------- .../retriever/es8/search_mode/exact_match.go | 14 ++++++++++++-- .../retriever/es8/search_mode/raw_string.go | 8 ++++++-- .../sparse_vector_text_expansion.go | 14 ++++++++++++-- .../sparse_vector_text_expansion_test.go | 8 +++++++- 9 files changed, 86 insertions(+), 25 deletions(-) diff --git a/components/retriever/es8/retriever.go b/components/retriever/es8/retriever.go index a44d0dc..82063ff 100644 --- a/components/retriever/es8/retriever.go +++ b/components/retriever/es8/retriever.go @@ -52,6 +52,10 @@ type RetrieverConfig struct { Embedding embedding.Embedder } +type SearchMode interface { // nolint: byted_s_interface_name + BuildRequest(ctx context.Context, conf *RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) +} + type Retriever struct { client *elasticsearch.TypedClient config *RetrieverConfig diff --git a/components/retriever/es8/search_mode/approximate.go b/components/retriever/es8/search_mode/approximate.go index e497434..163c332 100644 --- a/components/retriever/es8/search_mode/approximate.go +++ b/components/retriever/es8/search_mode/approximate.go @@ -26,13 +26,14 @@ import ( 
"github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeApproximate retrieve with multiple approximate strategy (filter+knn+rrf) // knn: https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html // rrf: https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html -func SearchModeApproximate(config *ApproximateConfig) SearchMode { +func SearchModeApproximate(config *ApproximateConfig) es8.SearchMode { return &approximate{config} } @@ -85,7 +86,15 @@ type approximate struct { config *ApproximateConfig } -func (a *approximate) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) { +func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { + + options := retriever.GetCommonOptions(&retriever.Options{ + Index: &conf.Index, + TopK: &conf.TopK, + ScoreThreshold: conf.ScoreThreshold, + Embedding: conf.Embedding, + }, opts...) 
+ var appReq ApproximateQuery if err := json.Unmarshal([]byte(query), &appReq); err != nil { return nil, fmt.Errorf("[BuildRequest][SearchModeApproximate] parse query failed, %w", err) diff --git a/components/retriever/es8/search_mode/approximate_test.go b/components/retriever/es8/search_mode/approximate_test.go index 05c6ea5..634a4ba 100644 --- a/components/retriever/es8/search_mode/approximate_test.go +++ b/components/retriever/es8/search_mode/approximate_test.go @@ -28,6 +28,7 @@ import ( "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) @@ -78,7 +79,8 @@ func TestSearchModeApproximate(t *testing.T) { sq, err := aq.ToRetrieverQuery() convey.So(err, convey.ShouldBeNil) - req, err := a.BuildRequest(ctx, sq, &retriever.Options{Embedding: nil}) + conf := &es8.RetrieverConfig{} + req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(nil)) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) @@ -104,7 +106,8 @@ func TestSearchModeApproximate(t *testing.T) { sq, err := aq.ToRetrieverQuery() convey.So(err, convey.ShouldBeNil) - req, err := a.BuildRequest(ctx, sq, &retriever.Options{Embedding: &mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}}) + conf := &es8.RetrieverConfig{} + req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}})) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) @@ -136,11 +139,11 @@ func TestSearchModeApproximate(t *testing.T) { sq, err := aq.ToRetrieverQuery() convey.So(err, convey.ShouldBeNil) - req, err := a.BuildRequest(ctx, sq, &retriever.Options{ - Embedding: &mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}, - TopK: of(10), - ScoreThreshold: of(1.1), - }) + + conf := 
&es8.RetrieverConfig{} + req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), + retriever.WithTopK(10), + retriever.WithScoreThreshold(1.1)) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) diff --git a/components/retriever/es8/search_mode/dense_vector_similarity.go b/components/retriever/es8/search_mode/dense_vector_similarity.go index 460e764..07b54d5 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity.go @@ -26,12 +26,13 @@ import ( "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeDenseVectorSimilarity calculate embedding similarity between dense_vector field and query // see: https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html#vector-functions -func SearchModeDenseVectorSimilarity(typ DenseVectorSimilarityType) SearchMode { +func SearchModeDenseVectorSimilarity(typ DenseVectorSimilarityType) es8.SearchMode { return &denseVectorSimilarity{script: denseVectorScriptMap[typ]} } @@ -54,7 +55,16 @@ type denseVectorSimilarity struct { script string } -func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) { +func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, + opts ...retriever.Option) (*search.Request, error) { + + options := retriever.GetCommonOptions(&retriever.Options{ + Index: &conf.Index, + TopK: &conf.TopK, + ScoreThreshold: conf.ScoreThreshold, + Embedding: conf.Embedding, + }, opts...) 
+ var dq DenseVectorSimilarityQuery if err := json.Unmarshal([]byte(query), &dq); err != nil { return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] parse query failed, %w", err) diff --git a/components/retriever/es8/search_mode/dense_vector_similarity_test.go b/components/retriever/es8/search_mode/dense_vector_similarity_test.go index ea13b57..648dd05 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity_test.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity_test.go @@ -28,6 +28,7 @@ import ( "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) @@ -66,13 +67,16 @@ func TestSearchModeDenseVectorSimilarity(t *testing.T) { sq, _ := dq.ToRetrieverQuery() PatchConvey("test embedding not provided", func() { - req, err := d.BuildRequest(ctx, sq, &retriever.Options{Embedding: nil}) + + conf := &es8.RetrieverConfig{} + req, err := d.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(nil)) convey.So(err, convey.ShouldBeError, "[BuildRequest][SearchModeDenseVectorSimilarity] embedding not provided") convey.So(req, convey.ShouldBeNil) }) PatchConvey("test vector size invalid", func() { - req, err := d.BuildRequest(ctx, sq, &retriever.Options{Embedding: mockEmbedding{size: 2, mockVector: []float64{1.1, 1.2}}}) + conf := &es8.RetrieverConfig{} + req, err := d.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(mockEmbedding{size: 2, mockVector: []float64{1.1, 1.2}})) convey.So(err, convey.ShouldBeError, "[BuildRequest][SearchModeDenseVectorSimilarity] vector size invalid, expect=1, got=2") convey.So(req, convey.ShouldBeNil) }) @@ -87,11 +91,12 @@ func TestSearchModeDenseVectorSimilarity(t *testing.T) { for typ, exp := range typ2Exp { similarity := &denseVectorSimilarity{script: denseVectorScriptMap[typ]} - req, err := similarity.BuildRequest(ctx, sq, &retriever.Options{ - Embedding: 
mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}, - TopK: of(10), - ScoreThreshold: of(1.1), - }) + + conf := &es8.RetrieverConfig{} + req, err := similarity.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), + retriever.WithTopK(10), + retriever.WithScoreThreshold(1.1)) + convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) diff --git a/components/retriever/es8/search_mode/exact_match.go b/components/retriever/es8/search_mode/exact_match.go index d8041f5..8f5f731 100644 --- a/components/retriever/es8/search_mode/exact_match.go +++ b/components/retriever/es8/search_mode/exact_match.go @@ -24,16 +24,26 @@ import ( "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) -func SearchModeExactMatch() SearchMode { +func SearchModeExactMatch() es8.SearchMode { return &exactMatch{} } type exactMatch struct{} -func (e exactMatch) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) { +func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, + opts ...retriever.Option) (*search.Request, error) { + + options := retriever.GetCommonOptions(&retriever.Options{ + Index: &conf.Index, + TopK: &conf.TopK, + ScoreThreshold: conf.ScoreThreshold, + Embedding: conf.Embedding, + }, opts...) 
+ q := &types.Query{ Match: map[string]types.MatchQuery{ field_mapping.DocFieldNameContent: {Query: query}, diff --git a/components/retriever/es8/search_mode/raw_string.go b/components/retriever/es8/search_mode/raw_string.go index da6c00f..674d0a2 100644 --- a/components/retriever/es8/search_mode/raw_string.go +++ b/components/retriever/es8/search_mode/raw_string.go @@ -22,15 +22,19 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/cloudwego/eino/components/retriever" + + "github.com/cloudwego/eino-ext/components/retriever/es8" ) -func SearchModeRawStringRequest() SearchMode { +func SearchModeRawStringRequest() es8.SearchMode { return &rawString{} } type rawString struct{} -func (r rawString) BuildRequest(_ context.Context, query string, _ *retriever.Options) (*search.Request, error) { +func (r rawString) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, + opts ...retriever.Option) (*search.Request, error) { + req, err := search.NewRequest().FromJSON(query) if err != nil { return nil, err diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go index ab08596..7476bb0 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go @@ -26,13 +26,14 @@ import ( "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeSparseVectorTextExpansion convert the query text into a list of token-weight pairs, // which are then used in a query against a sparse vector // see: https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-text-expansion-query.html -func SearchModeSparseVectorTextExpansion(modelID string) SearchMode { +func SearchModeSparseVectorTextExpansion(modelID string) 
es8.SearchMode { return &sparseVectorTextExpansion{modelID} } @@ -55,7 +56,16 @@ type sparseVectorTextExpansion struct { modelID string } -func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) { +func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, + opts ...retriever.Option) (*search.Request, error) { + + options := retriever.GetCommonOptions(&retriever.Options{ + Index: &conf.Index, + TopK: &conf.TopK, + ScoreThreshold: conf.ScoreThreshold, + Embedding: conf.Embedding, + }, opts...) + var sq SparseVectorTextExpansionQuery if err := json.Unmarshal([]byte(query), &sq); err != nil { return nil, fmt.Errorf("[BuildRequest][SearchModeSparseVectorTextExpansion] parse query failed, %w", err) diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go index eae9178..e59ba2e 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go @@ -27,6 +27,7 @@ import ( "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) @@ -65,7 +66,12 @@ func TestSearchModeSparseVectorTextExpansion(t *testing.T) { } query, _ := sq.ToRetrieverQuery() - req, err := s.BuildRequest(ctx, query, &retriever.Options{TopK: of(10), ScoreThreshold: of(1.1)}) + + conf := &es8.RetrieverConfig{} + req, err := s.BuildRequest(ctx, conf, query, + retriever.WithTopK(10), + retriever.WithScoreThreshold(1.1)) + convey.So(err, convey.ShouldBeNil) convey.So(req, convey.ShouldNotBeNil) b, err := json.Marshal(req) From b9aee471251af17ed5bd305be4c0fb9fc41a0fa2 Mon Sep 17 00:00:00 2001 From: lipandeng Date: Sun, 12 Jan 2025 20:17:23 +0800 Subject: [PATCH 
06/11] feat: to nil ptr when zero value --- components/retriever/es8/retriever.go | 7 ++-- .../retriever/es8/search_mode/approximate.go | 18 ++++----- .../es8/search_mode/approximate_test.go | 40 +++++++++---------- .../search_mode/dense_vector_similarity.go | 10 ++--- .../dense_vector_similarity_test.go | 2 +- .../retriever/es8/search_mode/exact_match.go | 8 ++-- .../retriever/es8/search_mode/interface.go | 29 -------------- .../retriever/es8/search_mode/raw_string.go | 2 +- .../sparse_vector_text_expansion.go | 10 ++--- .../sparse_vector_text_expansion_test.go | 2 +- components/retriever/es8/search_mode/utils.go | 8 +++- 11 files changed, 55 insertions(+), 81 deletions(-) delete mode 100644 components/retriever/es8/search_mode/interface.go diff --git a/components/retriever/es8/retriever.go b/components/retriever/es8/retriever.go index 82063ff..b441f08 100644 --- a/components/retriever/es8/retriever.go +++ b/components/retriever/es8/retriever.go @@ -31,7 +31,6 @@ import ( "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" "github.com/cloudwego/eino-ext/components/retriever/es8/internal" - "github.com/cloudwego/eino-ext/components/retriever/es8/search_mode" ) type RetrieverConfig struct { @@ -47,12 +46,12 @@ type RetrieverConfig struct { // use search_mode.SearchModeDenseVectorSimilarity with search_mode.DenseVectorSimilarityQuery // use search_mode.SearchModeSparseVectorTextExpansion with search_mode.SparseVectorTextExpansionQuery // use search_mode.SearchModeRawStringRequest with json search request - SearchMode search_mode.SearchMode `json:"search_mode"` + SearchMode SearchMode `json:"search_mode"` // Embedding vectorization method, must provide when SearchMode needed Embedding embedding.Embedder } -type SearchMode interface { // nolint: byted_s_interface_name +type SearchMode interface { BuildRequest(ctx context.Context, conf *RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) } @@ -97,7 +96,7 @@ func (r 
*Retriever) Retrieve(ctx context.Context, query string, opts ...retrieve ScoreThreshold: options.ScoreThreshold, }) - req, err := r.config.SearchMode.BuildRequest(ctx, query, options) + req, err := r.config.SearchMode.BuildRequest(ctx, r.config, query, opts...) if err != nil { return nil, err } diff --git a/components/retriever/es8/search_mode/approximate.go b/components/retriever/es8/search_mode/approximate.go index 163c332..0474b55 100644 --- a/components/retriever/es8/search_mode/approximate.go +++ b/components/retriever/es8/search_mode/approximate.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -46,7 +46,7 @@ type ApproximateConfig struct { // RrfRankConstant determines how much influence documents in // individual result sets per query have over the final ranked result set RrfRankConstant *int64 - // RrfWindowSize determines the size of the individual result sets per query + // RrfWindowSize determines the size ptrWithoutZero the individual result sets per query RrfWindowSize *int64 } @@ -57,16 +57,16 @@ type ApproximateQuery struct { // QueryVectorBuilderModelID the query vector builder model id // see: https://www.elastic.co/guide/en/machine-learning/8.16/ml-nlp-text-emb-vector-search-example.html QueryVectorBuilderModelID *string `json:"query_vector_builder_model_id,omitempty"` - // Boost Floating point number used to decrease or increase the relevance scores of the query. - // Boost values are relative to the default value of 1.0. + // Boost Floating point number used to decrease or increase the relevance scores ptrWithoutZero the query. + // Boost values are relative to the default value ptrWithoutZero 1.0. // A boost value between 0 and 1.0 decreases the relevance score. 
// A value greater than 1.0 increases the relevance score. Boost *float32 `json:"boost,omitempty"` // Filters for the kNN search query Filters []types.Query `json:"filters,omitempty"` - // K The final number of nearest neighbors to return as top hits + // K The final number ptrWithoutZero nearest neighbors to return as top hits K *int `json:"k,omitempty"` - // NumCandidates The number of nearest neighbor candidates to consider per shard + // NumCandidates The number ptrWithoutZero nearest neighbor candidates to consider per shard NumCandidates *int `json:"num_candidates,omitempty"` // Similarity The minimum similarity for a vector to be considered a match Similarity *float32 `json:"similarity,omitempty"` @@ -89,8 +89,8 @@ type approximate struct { func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { options := retriever.GetCommonOptions(&retriever.Options{ - Index: &conf.Index, - TopK: &conf.TopK, + Index: ptrWithoutZero(conf.Index), + TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) @@ -159,7 +159,7 @@ func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfi } if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(of(*options.ScoreThreshold)) + req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) } return req, nil diff --git a/components/retriever/es8/search_mode/approximate_test.go b/components/retriever/es8/search_mode/approximate_test.go index 634a4ba..4d79280 100644 --- a/components/retriever/es8/search_mode/approximate_test.go +++ b/components/retriever/es8/search_mode/approximate_test.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -44,10 +44,10 @@ func TestSearchModeApproximate(t *testing.T) { Filters: []types.Query{ {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, }, - Boost: of(float32(1.0)), - K: of(10), - NumCandidates: of(100), - Similarity: of(float32(0.5)), + Boost: ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), } sq, err := aq.ToRetrieverQuery() @@ -66,14 +66,14 @@ func TestSearchModeApproximate(t *testing.T) { FieldName: field_mapping.DocFieldNameContent, Value: "content", }, - QueryVectorBuilderModelID: of("mock_model"), + QueryVectorBuilderModelID: ptrWithoutZero("mock_model"), Filters: []types.Query{ {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, }, - Boost: of(float32(1.0)), - K: of(10), - NumCandidates: of(100), - Similarity: of(float32(0.5)), + Boost: ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), } sq, err := aq.ToRetrieverQuery() @@ -98,10 +98,10 @@ func TestSearchModeApproximate(t *testing.T) { Filters: []types.Query{ {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, }, - Boost: of(float32(1.0)), - K: of(10), - NumCandidates: of(100), - Similarity: of(float32(0.5)), + Boost: ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), } sq, err := aq.ToRetrieverQuery() @@ -118,8 +118,8 @@ func TestSearchModeApproximate(t *testing.T) { a := &approximate{config: &ApproximateConfig{ Hybrid: true, Rrf: true, - RrfRankConstant: of(int64(10)), - RrfWindowSize: of(int64(5)), + RrfRankConstant: ptrWithoutZero(int64(10)), + RrfWindowSize: ptrWithoutZero(int64(5)), }} aq := &ApproximateQuery{ @@ -131,10 +131,10 @@ func 
TestSearchModeApproximate(t *testing.T) { Filters: []types.Query{ {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, }, - Boost: of(float32(1.0)), - K: of(10), - NumCandidates: of(100), - Similarity: of(float32(0.5)), + Boost: ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), } sq, err := aq.ToRetrieverQuery() diff --git a/components/retriever/es8/search_mode/dense_vector_similarity.go b/components/retriever/es8/search_mode/dense_vector_similarity.go index 07b54d5..40c4333 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -59,8 +59,8 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr opts ...retriever.Option) (*search.Request, error) { options := retriever.GetCommonOptions(&retriever.Options{ - Index: &conf.Index, - TopK: &conf.TopK, + Index: ptrWithoutZero(conf.Index), + TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) 
@@ -92,7 +92,7 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr q := &types.Query{ ScriptScore: &types.ScriptScoreQuery{ Script: types.Script{ - Source: of(fmt.Sprintf(d.script, dq.FieldKV.FieldNameVector)), + Source: ptrWithoutZero(fmt.Sprintf(d.script, dq.FieldKV.FieldNameVector)), Params: map[string]json.RawMessage{"embedding": vb}, }, }, @@ -110,7 +110,7 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr req := &search.Request{Query: q, Size: options.TopK} if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(of(*options.ScoreThreshold)) + req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) } return req, nil diff --git a/components/retriever/es8/search_mode/dense_vector_similarity_test.go b/components/retriever/es8/search_mode/dense_vector_similarity_test.go index 648dd05..72af92b 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity_test.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity_test.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * diff --git a/components/retriever/es8/search_mode/exact_match.go b/components/retriever/es8/search_mode/exact_match.go index 8f5f731..20e178c 100644 --- a/components/retriever/es8/search_mode/exact_match.go +++ b/components/retriever/es8/search_mode/exact_match.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -38,8 +38,8 @@ func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, opts ...retriever.Option) (*search.Request, error) { options := retriever.GetCommonOptions(&retriever.Options{ - Index: &conf.Index, - TopK: &conf.TopK, + Index: ptrWithoutZero(conf.Index), + TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) @@ -52,7 +52,7 @@ func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, req := &search.Request{Query: q, Size: options.TopK} if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(of(*options.ScoreThreshold)) + req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) } return req, nil diff --git a/components/retriever/es8/search_mode/interface.go b/components/retriever/es8/search_mode/interface.go deleted file mode 100644 index ab4ef5a..0000000 --- a/components/retriever/es8/search_mode/interface.go +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package search_mode - -import ( - "context" - - "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" - - "github.com/cloudwego/eino/components/retriever" -) - -type SearchMode interface { // nolint: byted_s_interface_name - BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) -} diff --git a/components/retriever/es8/search_mode/raw_string.go b/components/retriever/es8/search_mode/raw_string.go index 674d0a2..855840c 100644 --- a/components/retriever/es8/search_mode/raw_string.go +++ b/components/retriever/es8/search_mode/raw_string.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go index 7476bb0..f4c2c42 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -30,7 +30,7 @@ import ( "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) -// SearchModeSparseVectorTextExpansion convert the query text into a list of token-weight pairs, +// SearchModeSparseVectorTextExpansion convert the query text into a list ptrWithoutZero token-weight pairs, // which are then used in a query against a sparse vector // see: https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-text-expansion-query.html func SearchModeSparseVectorTextExpansion(modelID string) es8.SearchMode { @@ -60,8 +60,8 @@ func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.R opts ...retriever.Option) (*search.Request, error) { options := retriever.GetCommonOptions(&retriever.Options{ - Index: &conf.Index, - TopK: &conf.TopK, + Index: ptrWithoutZero(conf.Index), + TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) @@ -88,7 +88,7 @@ func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.R req := &search.Request{Query: q, Size: options.TopK} if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(of(*options.ScoreThreshold)) + req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) } return req, nil diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go index e59ba2e..b3f5826 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * diff --git a/components/retriever/es8/search_mode/utils.go b/components/retriever/es8/search_mode/utils.go index 7501ada..b69ed15 100644 --- a/components/retriever/es8/search_mode/utils.go +++ b/components/retriever/es8/search_mode/utils.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -47,6 +47,10 @@ func f64To32(f64 []float64) []float32 { return f32 } -func of[T any](v T) *T { +func ptrWithoutZero[T string | int64 | int | float64 | float32](v T) *T { + var zero T + if zero == v { + return nil + } return &v } From 69f1b8d6e622a3102b643fa54b3cafb57dbc1c24 Mon Sep 17 00:00:00 2001 From: lipandeng Date: Sun, 12 Jan 2025 20:56:06 +0800 Subject: [PATCH 07/11] feat: add ut for retriever --- components/retriever/es8/go.mod | 3 + components/retriever/es8/retriever_test.go | 75 ++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 components/retriever/es8/retriever_test.go diff --git a/components/retriever/es8/go.mod b/components/retriever/es8/go.mod index 000a30f..5d6f722 100644 --- a/components/retriever/es8/go.mod +++ b/components/retriever/es8/go.mod @@ -7,6 +7,7 @@ require ( github.com/cloudwego/eino v0.3.5 github.com/elastic/go-elasticsearch/v8 v8.16.0 github.com/smartystreets/goconvey v1.8.1 + github.com/stretchr/testify v1.9.0 ) require ( @@ -14,6 +15,7 @@ require ( github.com/bytedance/sonic/loader v0.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect 
github.com/elastic/elastic-transport-go/v8 v8.6.0 // indirect github.com/getkin/kin-openapi v0.118.0 // indirect @@ -36,6 +38,7 @@ require ( github.com/pelletier/go-toml/v2 v2.0.9 // indirect github.com/perimeterx/marshmallow v1.1.4 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect github.com/smarty/assertions v1.15.0 // indirect diff --git a/components/retriever/es8/retriever_test.go b/components/retriever/es8/retriever_test.go new file mode 100644 index 0000000..f7b2e60 --- /dev/null +++ b/components/retriever/es8/retriever_test.go @@ -0,0 +1,75 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package es8 + +import ( + "context" + "encoding/json" + "testing" + + "github.com/bytedance/mockey" + "github.com/cloudwego/eino/components/retriever" + "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/stretchr/testify/assert" +) + +func TestNewRetriever(t *testing.T) { + ctx := context.Background() + + t.Run("retrieve_documents", func(t *testing.T) { + r, err := NewRetriever(ctx, &RetrieverConfig{ + ESConfig: elasticsearch.Config{}, + Index: "eino_ut", + TopK: 10, + SearchMode: &mockSearchMode{}, + }) + assert.NoError(t, err) + + defer mockey.Mock(mockey.GetMethod(r.client.Search(), "Index")). + Return(r.client.Search()).Build().Patch().UnPatch() + + defer mockey.Mock(mockey.GetMethod(r.client.Search(), "Request")). + Return(r.client.Search()).Build().Patch().UnPatch() + + defer mockey.Mock(mockey.GetMethod(r.client.Search(), "Do")).Return(&search.Response{ + Hits: types.HitsMetadata{ + Hits: []types.Hit{ + { + Source_: json.RawMessage([]byte(`{ + "eino_doc_content": "i'm fine, thank you" +}`)), + }, + }, + }, + }, nil).Build().Patch().UnPatch() + + docs, err := r.Retrieve(ctx, "how are you") + assert.NoError(t, err) + + assert.Len(t, docs, 1) + assert.Equal(t, "i'm fine, thank you", docs[0].Content) + }) + +} + +type mockSearchMode struct{} + +func (m *mockSearchMode) BuildRequest(ctx context.Context, conf *RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { + return &search.Request{}, nil +} From 25badfbc88b58138752536813efcdfe431257236 Mon Sep 17 00:00:00 2001 From: lipandeng Date: Sun, 12 Jan 2025 21:03:41 +0800 Subject: [PATCH 08/11] feat: license header --- components/retriever/es8/search_mode/approximate.go | 2 +- components/retriever/es8/search_mode/approximate_test.go | 2 +- components/retriever/es8/search_mode/dense_vector_similarity.go | 2 +- 
.../retriever/es8/search_mode/dense_vector_similarity_test.go | 2 +- components/retriever/es8/search_mode/exact_match.go | 2 +- components/retriever/es8/search_mode/raw_string.go | 2 +- .../retriever/es8/search_mode/sparse_vector_text_expansion.go | 2 +- .../es8/search_mode/sparse_vector_text_expansion_test.go | 2 +- components/retriever/es8/search_mode/utils.go | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/components/retriever/es8/search_mode/approximate.go b/components/retriever/es8/search_mode/approximate.go index 0474b55..ee1825d 100644 --- a/components/retriever/es8/search_mode/approximate.go +++ b/components/retriever/es8/search_mode/approximate.go @@ -1,5 +1,5 @@ /* - * Copyright 2024 CloudWeGo Authors + * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/components/retriever/es8/search_mode/approximate_test.go b/components/retriever/es8/search_mode/approximate_test.go index 4d79280..36e3afb 100644 --- a/components/retriever/es8/search_mode/approximate_test.go +++ b/components/retriever/es8/search_mode/approximate_test.go @@ -1,5 +1,5 @@ /* - * Copyright 2024 CloudWeGo Authors + * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/components/retriever/es8/search_mode/dense_vector_similarity.go b/components/retriever/es8/search_mode/dense_vector_similarity.go index 40c4333..18e58dd 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity.go @@ -1,5 +1,5 @@ /* - * Copyright 2024 CloudWeGo Authors + * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/components/retriever/es8/search_mode/dense_vector_similarity_test.go b/components/retriever/es8/search_mode/dense_vector_similarity_test.go index 72af92b..a68c333 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity_test.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity_test.go @@ -1,5 +1,5 @@ /* - * Copyright 2024 CloudWeGo Authors + * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/components/retriever/es8/search_mode/exact_match.go b/components/retriever/es8/search_mode/exact_match.go index 20e178c..81c7324 100644 --- a/components/retriever/es8/search_mode/exact_match.go +++ b/components/retriever/es8/search_mode/exact_match.go @@ -1,5 +1,5 @@ /* - * Copyright 2024 CloudWeGo Authors + * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/components/retriever/es8/search_mode/raw_string.go b/components/retriever/es8/search_mode/raw_string.go index 855840c..82a7aa4 100644 --- a/components/retriever/es8/search_mode/raw_string.go +++ b/components/retriever/es8/search_mode/raw_string.go @@ -1,5 +1,5 @@ /* - * Copyright 2024 CloudWeGo Authors + * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go index f4c2c42..bfa75be 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go @@ -1,5 +1,5 @@ /* - * Copyright 2024 CloudWeGo Authors + * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go index b3f5826..f4c42a9 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go @@ -1,5 +1,5 @@ /* - * Copyright 2024 CloudWeGo Authors + * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/components/retriever/es8/search_mode/utils.go b/components/retriever/es8/search_mode/utils.go index b69ed15..cc54479 100644 --- a/components/retriever/es8/search_mode/utils.go +++ b/components/retriever/es8/search_mode/utils.go @@ -1,5 +1,5 @@ /* - * Copyright 2024 CloudWeGo Authors + * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
From da0e4494b4d6003379434640c3b70f9b75aaf459 Mon Sep 17 00:00:00 2001 From: xuzhaonan Date: Mon, 13 Jan 2025 18:18:53 +0800 Subject: [PATCH 09/11] refactor: refactor field mapping and search mode --- .../indexer/es8/field_mapping/consts.go | 19 --- .../es8/field_mapping/field_mapping.go | 81 --------- components/indexer/es8/go.mod | 2 +- components/indexer/es8/go.sum | 4 +- components/indexer/es8/indexer.go | 92 +++-------- components/indexer/es8/indexer_test.go | 77 ++++----- components/indexer/es8/internal/consts.go | 21 --- components/indexer/es8/utils.go | 20 +-- components/retriever/es8/consts.go | 4 + .../retriever/es8/field_mapping/consts.go | 19 --- .../retriever/es8/field_mapping/mapping.go | 58 ------- components/retriever/es8/go.mod | 2 +- components/retriever/es8/go.sum | 2 + .../es8/{internal/consts.go => option.go} | 13 +- components/retriever/es8/retriever.go | 52 ++---- components/retriever/es8/retriever_test.go | 30 +++- .../retriever/es8/search_mode/approximate.go | 99 +++++------ .../es8/search_mode/approximate_test.go | 154 ++++++++---------- .../search_mode/dense_vector_similarity.go | 59 +++---- .../dense_vector_similarity_test.go | 59 ++----- .../retriever/es8/search_mode/exact_match.go | 19 +-- .../es8/search_mode/exact_match_test.go | 41 +++++ .../retriever/es8/search_mode/raw_string.go | 8 +- .../es8/search_mode/raw_string_test.go | 48 ++++++ .../sparse_vector_text_expansion.go | 51 ++---- .../sparse_vector_text_expansion_test.go | 51 ++---- components/retriever/es8/search_mode/utils.go | 2 +- 27 files changed, 398 insertions(+), 689 deletions(-) delete mode 100644 components/indexer/es8/field_mapping/consts.go delete mode 100644 components/indexer/es8/field_mapping/field_mapping.go delete mode 100644 components/indexer/es8/internal/consts.go delete mode 100644 components/retriever/es8/field_mapping/consts.go delete mode 100644 components/retriever/es8/field_mapping/mapping.go rename components/retriever/es8/{internal/consts.go => 
option.go} (67%) create mode 100644 components/retriever/es8/search_mode/exact_match_test.go create mode 100644 components/retriever/es8/search_mode/raw_string_test.go diff --git a/components/indexer/es8/field_mapping/consts.go b/components/indexer/es8/field_mapping/consts.go deleted file mode 100644 index abe46ff..0000000 --- a/components/indexer/es8/field_mapping/consts.go +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package field_mapping - -const DocFieldNameContent = "eino_doc_content" diff --git a/components/indexer/es8/field_mapping/field_mapping.go b/components/indexer/es8/field_mapping/field_mapping.go deleted file mode 100644 index 259e7bd..0000000 --- a/components/indexer/es8/field_mapping/field_mapping.go +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package field_mapping - -import ( - "fmt" - - "github.com/cloudwego/eino/schema" - - "github.com/cloudwego/eino-ext/components/indexer/es8/internal" -) - -// SetExtraDataFields set data fields for es -func SetExtraDataFields(doc *schema.Document, fields map[string]interface{}) { - if doc == nil { - return - } - - if doc.MetaData == nil { - doc.MetaData = make(map[string]any) - } - - doc.MetaData[internal.DocExtraKeyEsFields] = fields -} - -// GetExtraDataFields get data fields from *schema.Document -func GetExtraDataFields(doc *schema.Document) (fields map[string]interface{}, ok bool) { - if doc == nil || doc.MetaData == nil { - return nil, false - } - - fields, ok = doc.MetaData[internal.DocExtraKeyEsFields].(map[string]interface{}) - - return fields, ok -} - -// DefaultFieldKV build default names by fieldName -// docFieldName should be DocFieldNameContent or key got from GetExtraDataFields -func DefaultFieldKV(docFieldName FieldName) FieldKV { - return FieldKV{ - FieldNameVector: FieldName(fmt.Sprintf("vector_%s", docFieldName)), - FieldName: docFieldName, - } -} - -type FieldKV struct { - // FieldNameVector vector field name (if needed) - FieldNameVector FieldName `json:"field_name_vector,omitempty"` - // FieldName field name - FieldName FieldName `json:"field_name,omitempty"` -} - -type FieldName string - -func (v FieldName) Find(doc *schema.Document) (string, bool) { - if v == DocFieldNameContent { - return doc.Content, true - } - - kvs, ok := GetExtraDataFields(doc) - if !ok { - return "", false - } - - s, ok := kvs[string(v)].(string) - return s, ok -} diff --git a/components/indexer/es8/go.mod b/components/indexer/es8/go.mod index a009e88..a9547e7 100644 --- a/components/indexer/es8/go.mod +++ b/components/indexer/es8/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bytedance/mockey v1.2.13 - github.com/cloudwego/eino v0.3.5 + github.com/cloudwego/eino v0.3.6 github.com/elastic/go-elasticsearch/v8 v8.16.0 github.com/smartystreets/goconvey v1.8.1 
) diff --git a/components/indexer/es8/go.sum b/components/indexer/es8/go.sum index e701f60..889b85b 100644 --- a/components/indexer/es8/go.sum +++ b/components/indexer/es8/go.sum @@ -13,8 +13,8 @@ github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4 github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= -github.com/cloudwego/eino v0.3.5 h1:9PkAOX/phFifrGXkfl4L9rdecxOQJBJY1FtZqF4bz3c= -github.com/cloudwego/eino v0.3.5/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo= +github.com/cloudwego/eino v0.3.6 h1:3yfdKKxMVWefdOyGXHuqUMM5cc9iioijj2mpPsDZKIg= +github.com/cloudwego/eino v0.3.6/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/components/indexer/es8/indexer.go b/components/indexer/es8/indexer.go index a826bfe..ddfebb0 100644 --- a/components/indexer/es8/indexer.go +++ b/components/indexer/es8/indexer.go @@ -30,8 +30,6 @@ import ( "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" - - "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" ) type IndexerConfig struct { @@ -39,8 +37,10 @@ type IndexerConfig struct { Index string `json:"index"` BatchSize int `json:"batch_size"` - // VectorFields dense_vector field mappings - VectorFields []field_mapping.FieldKV `json:"vector_fields"` + // FieldMapping supports customize es fields from eino document, returns: + // needEmbeddingFields will be embedded by Embedding firstly, then join fields 
with its keys, + // and joined fields will be saved as bulk item. + FieldMapping func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) // Embedding vectorization method, must provide in two cases // 1. VectorFields contains fields except doc Content // 2. VectorFields contains doc Content and vector not provided in doc extra (see Document.Vector method) @@ -58,13 +58,8 @@ func NewIndexer(_ context.Context, conf *IndexerConfig) (*Indexer, error) { return nil, fmt.Errorf("[NewIndexer] new es client failed, %w", err) } - if conf.Embedding == nil { - for _, kv := range conf.VectorFields { - if kv.FieldName != field_mapping.DocFieldNameContent { - return nil, fmt.Errorf("[NewIndexer] Embedding not provided in config, but field kv[%s]-[%s] requires", - kv.FieldNameVector, kv.FieldName) - } - } + if conf.FieldMapping == nil { + return nil, fmt.Errorf("[NewIndexer] field mapping method not provided") } if conf.BatchSize == 0 { @@ -99,13 +94,7 @@ func (i *Indexer) Store(ctx context.Context, docs []*schema.Document, opts ...in } for _, slice := range chunk(docs, i.config.BatchSize) { - var items []esutil.BulkIndexerItem - - if len(i.config.VectorFields) == 0 { - items, err = i.defaultQueryItems(ctx, slice, options) - } else { - items, err = i.vectorQueryItems(ctx, slice, options) - } + items, err := i.makeBulkItems(ctx, slice, options) if err != nil { return nil, err } @@ -128,73 +117,42 @@ func (i *Indexer) Store(ctx context.Context, docs []*schema.Document, opts ...in return ids, nil } -func (i *Indexer) defaultQueryItems(_ context.Context, docs []*schema.Document, _ *indexer.Options) (items []esutil.BulkIndexerItem, err error) { - items, err = iterWithErr(docs, func(doc *schema.Document) (item esutil.BulkIndexerItem, err error) { - b, err := json.Marshal(toESDoc(doc)) - if err != nil { - return item, err - } - - return esutil.BulkIndexerItem{ - Index: i.config.Index, - Action: "index", - DocumentID: 
doc.ID, - Body: bytes.NewReader(b), - }, nil - }) - - if err != nil { - return nil, err - } - - return items, nil -} - -func (i *Indexer) vectorQueryItems(ctx context.Context, docs []*schema.Document, options *indexer.Options) (items []esutil.BulkIndexerItem, err error) { +func (i *Indexer) makeBulkItems(ctx context.Context, docs []*schema.Document, options *indexer.Options) (items []esutil.BulkIndexerItem, err error) { emb := options.Embedding items, err = iterWithErr(docs, func(doc *schema.Document) (item esutil.BulkIndexerItem, err error) { - mp := toESDoc(doc) - texts := make([]string, 0, len(i.config.VectorFields)) - for _, kv := range i.config.VectorFields { - str, ok := kv.FieldName.Find(doc) - if !ok { - return item, fmt.Errorf("[vectorQueryItems] field name not found or type incorrect, name=%s, doc=%v", kv.FieldName, doc) - } - - if kv.FieldName == field_mapping.DocFieldNameContent && len(doc.DenseVector()) > 0 { - mp[string(kv.FieldNameVector)] = doc.DenseVector() - } else { - texts = append(texts, str) - } + fields, needEmbeddingFields, err := i.config.FieldMapping(ctx, doc) + if err != nil { + return item, fmt.Errorf("[makeBulkItems] FieldMapping failed, %w", err) } - if len(texts) > 0 { + if len(needEmbeddingFields) > 0 { if emb == nil { - return item, fmt.Errorf("[vectorQueryItems] embedding not provided") + return item, fmt.Errorf("[makeBulkItems] embedding method not provided") + } + + tuples := make([]tuple[string, int], 0, len(fields)) + texts := make([]string, 0, len(fields)) + for k, text := range needEmbeddingFields { + tuples = append(tuples, tuple[string, int]{k, len(texts)}) + texts = append(texts, text) } vectors, err := emb.EmbedStrings(i.makeEmbeddingCtx(ctx, emb), texts) if err != nil { - return item, fmt.Errorf("[vectorQueryItems] embedding failed, %w", err) + return item, fmt.Errorf("[makeBulkItems] embedding failed, %w", err) } if len(vectors) != len(texts) { - return item, fmt.Errorf("[vectorQueryItems] invalid vector length, 
expected=%d, got=%d", len(texts), len(vectors)) + return item, fmt.Errorf("[makeBulkItems] invalid vector length, expected=%d, got=%d", len(texts), len(vectors)) } - vIdx := 0 - for _, kv := range i.config.VectorFields { - if kv.FieldName == field_mapping.DocFieldNameContent && len(doc.DenseVector()) > 0 { - continue - } - - mp[string(kv.FieldNameVector)] = vectors[vIdx] - vIdx++ + for _, t := range tuples { + fields[t.A] = vectors[t.B] } } - b, err := json.Marshal(mp) + b, err := json.Marshal(fields) if err != nil { return item, err } diff --git a/components/indexer/es8/indexer_test.go b/components/indexer/es8/indexer_test.go index 7f1356e..a1535a4 100644 --- a/components/indexer/es8/indexer_test.go +++ b/components/indexer/es8/indexer_test.go @@ -28,99 +28,85 @@ import ( "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" - - "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" ) func TestVectorQueryItems(t *testing.T) { - PatchConvey("test vectorQueryItems", t, func() { + PatchConvey("test makeBulkItems", t, func() { ctx := context.Background() extField := "extra_field" - d1 := &schema.Document{ID: "123", Content: "asd"} - d1.WithDenseVector([]float64{2.3, 4.4}) - field_mapping.SetExtraDataFields(d1, map[string]interface{}{extField: "ext_1"}) - - d2 := &schema.Document{ID: "456", Content: "qwe"} - field_mapping.SetExtraDataFields(d2, map[string]interface{}{extField: "ext_2"}) - + d1 := &schema.Document{ID: "123", Content: "asd", MetaData: map[string]any{extField: "ext_1"}} + d2 := &schema.Document{ID: "456", Content: "qwe", MetaData: map[string]any{extField: "ext_2"}} docs := []*schema.Document{d1, d2} - PatchConvey("test field not found", func() { + PatchConvey("test FieldMapping error", func() { + mockErr := fmt.Errorf("test err") i := &Indexer{ config: &IndexerConfig{ Index: "mock_index", - VectorFields: []field_mapping.FieldKV{ - 
field_mapping.DefaultFieldKV("not_found_field"), + FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { + return nil, nil, mockErr }, }, } - bulks, err := i.vectorQueryItems(ctx, docs, &indexer.Options{ + bulks, err := i.makeBulkItems(ctx, docs, &indexer.Options{ Embedding: &mockEmbedding{size: []int{1}, mockVector: []float64{2.1}}, }) - convey.So(err, convey.ShouldBeError, fmt.Sprintf("[vectorQueryItems] field name not found or type incorrect, name=not_found_field, doc=%v", d1)) + convey.So(err, convey.ShouldBeError, fmt.Errorf("[makeBulkItems] FieldMapping failed, %w", mockErr)) convey.So(len(bulks), convey.ShouldEqual, 0) }) PatchConvey("test emb not provided", func() { i := &Indexer{ config: &IndexerConfig{ - Index: "mock_index", - VectorFields: []field_mapping.FieldKV{ - field_mapping.DefaultFieldKV(field_mapping.DocFieldNameContent), - field_mapping.DefaultFieldKV(field_mapping.FieldName(extField)), - }, + Index: "mock_index", + FieldMapping: defaultFieldMapping, }, } - bulks, err := i.vectorQueryItems(ctx, docs, &indexer.Options{Embedding: nil}) - convey.So(err, convey.ShouldBeError, "[vectorQueryItems] embedding not provided") + bulks, err := i.makeBulkItems(ctx, docs, &indexer.Options{Embedding: nil}) + convey.So(err, convey.ShouldBeError, "[makeBulkItems] embedding method not provided") convey.So(len(bulks), convey.ShouldEqual, 0) }) PatchConvey("test vector size invalid", func() { i := &Indexer{ config: &IndexerConfig{ - Index: "mock_index", - VectorFields: []field_mapping.FieldKV{ - field_mapping.DefaultFieldKV(field_mapping.DocFieldNameContent), - field_mapping.DefaultFieldKV(field_mapping.FieldName(extField)), - }, + Index: "mock_index", + FieldMapping: defaultFieldMapping, }, } - bulks, err := i.vectorQueryItems(ctx, docs, &indexer.Options{ + bulks, err := i.makeBulkItems(ctx, docs, &indexer.Options{ Embedding: &mockEmbedding{size: []int{2, 2}, mockVector: 
[]float64{2.1}}, }) - convey.So(err, convey.ShouldBeError, "[vectorQueryItems] invalid vector length, expected=1, got=2") + convey.So(err, convey.ShouldBeError, "[makeBulkItems] invalid vector length, expected=1, got=2") convey.So(len(bulks), convey.ShouldEqual, 0) }) PatchConvey("test success", func() { i := &Indexer{ config: &IndexerConfig{ - Index: "mock_index", - VectorFields: []field_mapping.FieldKV{ - field_mapping.DefaultFieldKV(field_mapping.DocFieldNameContent), - field_mapping.DefaultFieldKV(field_mapping.FieldName(extField)), - }, + Index: "mock_index", + FieldMapping: defaultFieldMapping, }, } - bulks, err := i.vectorQueryItems(ctx, docs, &indexer.Options{ - Embedding: &mockEmbedding{size: []int{1, 2}, mockVector: []float64{2.1}}, + bulks, err := i.makeBulkItems(ctx, docs, &indexer.Options{ + Embedding: &mockEmbedding{size: []int{1, 1}, mockVector: []float64{2.1}}, }) convey.So(err, convey.ShouldBeNil) convey.So(len(bulks), convey.ShouldEqual, 2) exp := []string{ - `{"eino_doc_content":"asd","extra_field":"ext_1","vector_eino_doc_content":[2.3,4.4],"vector_extra_field":[2.1]}`, - `{"eino_doc_content":"qwe","extra_field":"ext_2","vector_eino_doc_content":[2.1],"vector_extra_field":[2.1]}`, + `{"content":"asd","meta_data":{"extra_field":"ext_1"},"vector_content":[2.1]}`, + `{"content":"qwe","meta_data":{"extra_field":"ext_2"},"vector_content":[2.1]}`, } for idx, item := range bulks { convey.So(item.Index, convey.ShouldEqual, i.config.Index) b, err := io.ReadAll(item.Body) + fmt.Println(string(b)) convey.So(err, convey.ShouldBeNil) convey.So(string(b), convey.ShouldEqual, exp[idx]) } @@ -147,3 +133,18 @@ func (m *mockEmbedding) EmbedStrings(ctx context.Context, texts []string, opts . 
return resp, nil } + +func defaultFieldMapping(ctx context.Context, doc *schema.Document) ( + fields map[string]any, needEmbeddingFields map[string]string, err error) { + + fields = map[string]any{ + "content": doc.Content, + "meta_data": doc.MetaData, + } + + needEmbeddingFields = map[string]string{ + "vector_content": doc.Content, + } + + return fields, needEmbeddingFields, nil +} diff --git a/components/indexer/es8/internal/consts.go b/components/indexer/es8/internal/consts.go deleted file mode 100644 index c515981..0000000 --- a/components/indexer/es8/internal/consts.go +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package internal - -const ( - DocExtraKeyEsFields = "_es_fields" // *schema.Document.MetaData key of es fields except content -) diff --git a/components/indexer/es8/utils.go b/components/indexer/es8/utils.go index a5f8952..e669079 100644 --- a/components/indexer/es8/utils.go +++ b/components/indexer/es8/utils.go @@ -16,27 +16,13 @@ package es8 -import ( - "github.com/cloudwego/eino/schema" - - "github.com/cloudwego/eino-ext/components/indexer/es8/field_mapping" -) - func GetType() string { return typ } -func toESDoc(doc *schema.Document) map[string]any { - mp := make(map[string]any) - if kvs, ok := field_mapping.GetExtraDataFields(doc); ok { - for k, v := range kvs { - mp[k] = v - } - } - - mp[field_mapping.DocFieldNameContent] = doc.Content - - return mp +type tuple[A, B any] struct { + A A + B B } func chunk[T any](slice []T, size int) [][]T { diff --git a/components/retriever/es8/consts.go b/components/retriever/es8/consts.go index 3f1ffcd..7f5da11 100644 --- a/components/retriever/es8/consts.go +++ b/components/retriever/es8/consts.go @@ -17,3 +17,7 @@ package es8 const typ = "ElasticSearch8" + +func GetType() string { + return typ +} diff --git a/components/retriever/es8/field_mapping/consts.go b/components/retriever/es8/field_mapping/consts.go deleted file mode 100644 index abe46ff..0000000 --- a/components/retriever/es8/field_mapping/consts.go +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package field_mapping - -const DocFieldNameContent = "eino_doc_content" diff --git a/components/retriever/es8/field_mapping/mapping.go b/components/retriever/es8/field_mapping/mapping.go deleted file mode 100644 index 6cc35e4..0000000 --- a/components/retriever/es8/field_mapping/mapping.go +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package field_mapping - -import ( - "fmt" - - "github.com/cloudwego/eino-ext/components/retriever/es8/internal" - "github.com/cloudwego/eino/schema" -) - -// GetDefaultVectorFieldKeyContent get default es key for Document.Content -func GetDefaultVectorFieldKeyContent() FieldName { - return defaultVectorFieldKeyContent -} - -// GetDefaultVectorFieldKey generate default vector field name from its field name -func GetDefaultVectorFieldKey(fieldName string) FieldName { - return FieldName(fmt.Sprintf("vector_%s", fieldName)) -} - -// GetExtraDataFields get data fields from *schema.Document -func GetExtraDataFields(doc *schema.Document) (fields map[string]interface{}, ok bool) { - if doc == nil || doc.MetaData == nil { - return nil, false - } - - fields, ok = doc.MetaData[internal.DocExtraKeyEsFields].(map[string]interface{}) - - return fields, ok -} - -type FieldKV struct { - // FieldNameVector vector field name (if needed) - FieldNameVector FieldName `json:"field_name_vector,omitempty"` - // FieldName field name - FieldName FieldName `json:"field_name,omitempty"` - // Value original value - Value string `json:"value,omitempty"` -} - -type FieldName string - -var defaultVectorFieldKeyContent = GetDefaultVectorFieldKey(DocFieldNameContent) diff --git a/components/retriever/es8/go.mod b/components/retriever/es8/go.mod index 5d6f722..d62c84e 100644 --- a/components/retriever/es8/go.mod +++ b/components/retriever/es8/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bytedance/mockey v1.2.13 - github.com/cloudwego/eino v0.3.5 + github.com/cloudwego/eino v0.3.6 github.com/elastic/go-elasticsearch/v8 v8.16.0 github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.9.0 diff --git a/components/retriever/es8/go.sum b/components/retriever/es8/go.sum index e701f60..7b9198f 100644 --- a/components/retriever/es8/go.sum +++ b/components/retriever/es8/go.sum @@ -15,6 +15,8 @@ github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/ 
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/eino v0.3.5 h1:9PkAOX/phFifrGXkfl4L9rdecxOQJBJY1FtZqF4bz3c= github.com/cloudwego/eino v0.3.5/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo= +github.com/cloudwego/eino v0.3.6 h1:3yfdKKxMVWefdOyGXHuqUMM5cc9iioijj2mpPsDZKIg= +github.com/cloudwego/eino v0.3.6/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/components/retriever/es8/internal/consts.go b/components/retriever/es8/option.go similarity index 67% rename from components/retriever/es8/internal/consts.go rename to components/retriever/es8/option.go index bf3b10d..2dd3f21 100644 --- a/components/retriever/es8/internal/consts.go +++ b/components/retriever/es8/option.go @@ -14,9 +14,14 @@ * limitations under the License. */ -package internal +package es8 -const ( - DocExtraKeyEsFields = "_es_fields" // *schema.Document.MetaData key of es fields except content - DslFilterField = "_dsl_filter_functions" +import ( + "github.com/elastic/go-elasticsearch/v8/typedapi/types" ) + +// ESImplOptions es specified options +// Use retriever.GetImplSpecificOptions[ESImplOptions] to get ESImplOptions from options. 
+type ESImplOptions struct { + Filters []types.Query `json:"filters,omitempty"` +} diff --git a/components/retriever/es8/retriever.go b/components/retriever/es8/retriever.go index b441f08..b26c519 100644 --- a/components/retriever/es8/retriever.go +++ b/components/retriever/es8/retriever.go @@ -18,19 +18,16 @@ package es8 import ( "context" - "encoding/json" "fmt" "github.com/elastic/go-elasticsearch/v8" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/schema" - - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" - "github.com/cloudwego/eino-ext/components/retriever/es8/internal" ) type RetrieverConfig struct { @@ -47,11 +44,17 @@ type RetrieverConfig struct { // use search_mode.SearchModeSparseVectorTextExpansion with search_mode.SparseVectorTextExpansionQuery // use search_mode.SearchModeRawStringRequest with json search request SearchMode SearchMode `json:"search_mode"` + // ResultParser parse document from es search hits. + // If ResultParser not provided, defaultResultParser will be used as default + ResultParser func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) // Embedding vectorization method, must provide when SearchMode needed Embedding embedding.Embedder } type SearchMode interface { + // BuildRequest generate search request from config, query and options. + // Additionally, some specified options (like filters for query) will be provided in options, + // and use retriever.GetImplSpecificOptions[options.ESImplOptions] to get it. 
BuildRequest(ctx context.Context, conf *RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) } @@ -65,6 +68,10 @@ func NewRetriever(_ context.Context, conf *RetrieverConfig) (*Retriever, error) return nil, fmt.Errorf("[NewRetriever] search mode not provided") } + if conf.ResultParser == nil { + return nil, fmt.Errorf("[NewRetriever] result parser not provided") + } + client, err := elasticsearch.NewTypedClient(conf.ESConfig) if err != nil { return nil, fmt.Errorf("[NewRetriever] new es client failed, %w", err) @@ -109,7 +116,7 @@ func (r *Retriever) Retrieve(ctx context.Context, query string, opts ...retrieve return nil, err } - docs, err = r.parseSearchResult(resp) + docs, err = r.parseSearchResult(ctx, resp) if err != nil { return nil, err } @@ -119,40 +126,13 @@ func (r *Retriever) Retrieve(ctx context.Context, query string, opts ...retrieve return docs, nil } -func (r *Retriever) parseSearchResult(resp *search.Response) (docs []*schema.Document, err error) { +func (r *Retriever) parseSearchResult(ctx context.Context, resp *search.Response) (docs []*schema.Document, err error) { docs = make([]*schema.Document, 0, len(resp.Hits.Hits)) for _, hit := range resp.Hits.Hits { - var raw map[string]any - if err = json.Unmarshal(hit.Source_, &raw); err != nil { - return nil, fmt.Errorf("[parseSearchResult] unexpected hit source type, source=%v", string(hit.Source_)) - } - - var id string - if hit.Id_ != nil { - id = *hit.Id_ - } - - content, ok := raw[field_mapping.DocFieldNameContent].(string) - if !ok { - return nil, fmt.Errorf("[parseSearchResult] content type not string, raw=%v", raw) - } - - expMap := make(map[string]any, len(raw)-1) - for k, v := range raw { - if k != internal.DocExtraKeyEsFields { - expMap[k] = v - } - } - - doc := &schema.Document{ - ID: id, - Content: content, - MetaData: map[string]any{internal.DocExtraKeyEsFields: expMap}, - } - - if hit.Score_ != nil { - doc.WithScore(float64(*hit.Score_)) + doc, err := 
r.config.ResultParser(ctx, hit) + if err != nil { + return nil, err } docs = append(docs, doc) diff --git a/components/retriever/es8/retriever_test.go b/components/retriever/es8/retriever_test.go index f7b2e60..38f9164 100644 --- a/components/retriever/es8/retriever_test.go +++ b/components/retriever/es8/retriever_test.go @@ -19,10 +19,12 @@ package es8 import ( "context" "encoding/json" + "fmt" "testing" "github.com/bytedance/mockey" "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" "github.com/elastic/go-elasticsearch/v8" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" @@ -34,9 +36,31 @@ func TestNewRetriever(t *testing.T) { t.Run("retrieve_documents", func(t *testing.T) { r, err := NewRetriever(ctx, &RetrieverConfig{ - ESConfig: elasticsearch.Config{}, - Index: "eino_ut", - TopK: 10, + ESConfig: elasticsearch.Config{}, + Index: "eino_ut", + TopK: 10, + ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) { + var mp map[string]any + if err := json.Unmarshal(hit.Source_, &mp); err != nil { + return nil, err + } + + var id string + if hit.Id_ != nil { + id = *hit.Id_ + } + + content, ok := mp["eino_doc_content"].(string) + if !ok { + return nil, fmt.Errorf("content not found") + } + + return &schema.Document{ + ID: id, + Content: content, + MetaData: nil, + }, nil + }, SearchMode: &mockSearchMode{}, }) assert.NoError(t, err) diff --git a/components/retriever/es8/search_mode/approximate.go b/components/retriever/es8/search_mode/approximate.go index ee1825d..87e6e31 100644 --- a/components/retriever/es8/search_mode/approximate.go +++ b/components/retriever/es8/search_mode/approximate.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -18,16 +18,12 @@ package search_mode import ( "context" - "encoding/json" "fmt" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - - "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeApproximate retrieve with multiple approximate strategy (filter+knn+rrf) @@ -38,48 +34,34 @@ func SearchModeApproximate(config *ApproximateConfig) es8.SearchMode { } type ApproximateConfig struct { + // QueryFieldName the name of query field, required when using Hybrid + QueryFieldName string + // VectorFieldName the name of the vector field to search against, required + VectorFieldName string // Hybrid if true, add filters and rff to knn query Hybrid bool - // Rrf is a method for combining multiple result sets, is used to + // RRF (Reciprocal Rank Fusion) is a method for combining multiple result sets, is used to // even the score from the knn query and text query - Rrf bool - // RrfRankConstant determines how much influence documents in + RRF bool + // RRFRankConstant determines how much influence documents in // individual result sets per query have over the final ranked result set - RrfRankConstant *int64 - // RrfWindowSize determines the size ptrWithoutZero the individual result sets per query - RrfWindowSize *int64 -} - -type ApproximateQuery struct { - // FieldKV es field info, QueryVectorBuilderModelID will be used if embedding not provided in config, - // and Embedding will be used if QueryVectorBuilderModelID is nil - FieldKV field_mapping.FieldKV `json:"field_kv"` + RRFRankConstant *int64 + // 
RRFWindowSize determines the size ptrWithoutZero the individual result sets per query + RRFWindowSize *int64 // QueryVectorBuilderModelID the query vector builder model id // see: https://www.elastic.co/guide/en/machine-learning/8.16/ml-nlp-text-emb-vector-search-example.html - QueryVectorBuilderModelID *string `json:"query_vector_builder_model_id,omitempty"` + QueryVectorBuilderModelID *string // Boost Floating point number used to decrease or increase the relevance scores ptrWithoutZero the query. // Boost values are relative to the default value ptrWithoutZero 1.0. // A boost value between 0 and 1.0 decreases the relevance score. // A value greater than 1.0 increases the relevance score. - Boost *float32 `json:"boost,omitempty"` - // Filters for the kNN search query - Filters []types.Query `json:"filters,omitempty"` + Boost *float32 // K The final number ptrWithoutZero nearest neighbors to return as top hits - K *int `json:"k,omitempty"` + K *int // NumCandidates The number ptrWithoutZero nearest neighbor candidates to consider per shard - NumCandidates *int `json:"num_candidates,omitempty"` + NumCandidates *int // Similarity The minimum similarity for a vector to be considered a match - Similarity *float32 `json:"similarity,omitempty"` -} - -// ToRetrieverQuery convert approximate query to string query -func (a *ApproximateQuery) ToRetrieverQuery() (string, error) { - b, err := json.Marshal(a) - if err != nil { - return "", fmt.Errorf("[ToRetrieverQuery] convert query failed, %w", err) - } - - return string(b), nil + Similarity *float32 } type approximate struct { @@ -88,41 +70,38 @@ type approximate struct { func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { - options := retriever.GetCommonOptions(&retriever.Options{ + co := retriever.GetCommonOptions(&retriever.Options{ Index: ptrWithoutZero(conf.Index), TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: 
conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) - var appReq ApproximateQuery - if err := json.Unmarshal([]byte(query), &appReq); err != nil { - return nil, fmt.Errorf("[BuildRequest][SearchModeApproximate] parse query failed, %w", err) - } + io := retriever.GetImplSpecificOptions[es8.ESImplOptions](nil, opts...) knn := types.KnnSearch{ - Boost: appReq.Boost, - Field: string(appReq.FieldKV.FieldNameVector), - Filter: appReq.Filters, - K: appReq.K, - NumCandidates: appReq.NumCandidates, + Boost: a.config.Boost, + Field: a.config.VectorFieldName, + Filter: io.Filters, + K: a.config.K, + NumCandidates: a.config.NumCandidates, QueryVector: nil, QueryVectorBuilder: nil, - Similarity: appReq.Similarity, + Similarity: a.config.Similarity, } - if appReq.QueryVectorBuilderModelID != nil { + if a.config.QueryVectorBuilderModelID != nil { knn.QueryVectorBuilder = &types.QueryVectorBuilder{TextEmbedding: &types.TextEmbedding{ - ModelId: *appReq.QueryVectorBuilderModelID, - ModelText: appReq.FieldKV.Value, + ModelId: *a.config.QueryVectorBuilderModelID, + ModelText: query, }} } else { - emb := options.Embedding + emb := co.Embedding if emb == nil { return nil, fmt.Errorf("[BuildRequest][SearchModeApproximate] embedding not provided") } - vector, err := emb.EmbedStrings(makeEmbeddingCtx(ctx, emb), []string{appReq.FieldKV.Value}) + vector, err := emb.EmbedStrings(makeEmbeddingCtx(ctx, emb), []string{query}) if err != nil { return nil, fmt.Errorf("[BuildRequest][SearchModeApproximate] embedding failed, %w", err) } @@ -134,32 +113,32 @@ func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfi knn.QueryVector = f64To32(vector[0]) } - req := &search.Request{Knn: []types.KnnSearch{knn}, Size: options.TopK} + req := &search.Request{Knn: []types.KnnSearch{knn}, Size: co.TopK} if a.config.Hybrid { req.Query = &types.Query{ Bool: &types.BoolQuery{ - Filter: appReq.Filters, + Filter: io.Filters, Must: []types.Query{ { Match: map[string]types.MatchQuery{ 
- string(appReq.FieldKV.FieldName): {Query: appReq.FieldKV.Value}, + a.config.QueryFieldName: {Query: query}, }, }, }, }, } - if a.config.Rrf { + if a.config.RRF { req.Rank = &types.RankContainer{Rrf: &types.RrfRank{ - RankConstant: a.config.RrfRankConstant, - RankWindowSize: a.config.RrfWindowSize, + RankConstant: a.config.RRFRankConstant, + RankWindowSize: a.config.RRFWindowSize, }} } } - if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) + if co.ScoreThreshold != nil { + req.MinScore = (*types.Float64)(ptrWithoutZero(*co.ScoreThreshold)) } return req, nil diff --git a/components/retriever/es8/search_mode/approximate_test.go b/components/retriever/es8/search_mode/approximate_test.go index 36e3afb..897dd94 100644 --- a/components/retriever/es8/search_mode/approximate_test.go +++ b/components/retriever/es8/search_mode/approximate_test.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -22,65 +22,44 @@ import ( "testing" . 
"github.com/bytedance/mockey" - "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "github.com/smartystreets/goconvey/convey" - + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/smartystreets/goconvey/convey" ) func TestSearchModeApproximate(t *testing.T) { PatchConvey("test SearchModeApproximate", t, func() { - PatchConvey("test ToRetrieverQuery", func() { - aq := &ApproximateQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - Boost: ptrWithoutZero(float32(1.0)), - K: ptrWithoutZero(10), - NumCandidates: ptrWithoutZero(100), - Similarity: ptrWithoutZero(float32(0.5)), - } - - sq, err := aq.ToRetrieverQuery() - convey.So(err, convey.ShouldBeNil) - convey.So(sq, convey.ShouldEqual, `{"field_kv":{"field_name_vector":"vector_eino_doc_content","field_name":"eino_doc_content","value":"content"},"boost":1,"filters":[{"match":{"label":{"query":"good"}}}],"k":10,"num_candidates":100,"similarity":0.5}`) - }) - PatchConvey("test BuildRequest", func() { ctx := context.Background() + queryFieldName := "eino_doc_content" + vectorFieldName := "vector_eino_doc_content" + query := "content" PatchConvey("test QueryVectorBuilderModelID", func() { - a := &approximate{config: &ApproximateConfig{}} - aq := &ApproximateQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, + a := &approximate{config: &ApproximateConfig{ + 
QueryFieldName: queryFieldName, + VectorFieldName: vectorFieldName, + Hybrid: false, + RRF: false, + RRFRankConstant: nil, + RRFWindowSize: nil, QueryVectorBuilderModelID: ptrWithoutZero("mock_model"), - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - Boost: ptrWithoutZero(float32(1.0)), - K: ptrWithoutZero(10), - NumCandidates: ptrWithoutZero(100), - Similarity: ptrWithoutZero(float32(0.5)), - } - - sq, err := aq.ToRetrieverQuery() - convey.So(err, convey.ShouldBeNil) + Boost: ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), + }} conf := &es8.RetrieverConfig{} - req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(nil)) + req, err := a.BuildRequest(ctx, conf, query, + retriever.WithEmbedding(nil), + retriever.WrapImplSpecificOptFn[es8.ESImplOptions](func(o *es8.ESImplOptions) { + o.Filters = []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + } + })) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) @@ -88,26 +67,28 @@ func TestSearchModeApproximate(t *testing.T) { }) PatchConvey("test embedding", func() { - a := &approximate{config: &ApproximateConfig{}} - aq := &ApproximateQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - Boost: ptrWithoutZero(float32(1.0)), - K: ptrWithoutZero(10), - NumCandidates: ptrWithoutZero(100), - Similarity: ptrWithoutZero(float32(0.5)), - } + a := &approximate{config: &ApproximateConfig{ + QueryFieldName: queryFieldName, + VectorFieldName: vectorFieldName, + Hybrid: false, + RRF: false, + RRFRankConstant: nil, + RRFWindowSize: nil, + QueryVectorBuilderModelID: nil, + Boost: 
ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), + }} - sq, err := aq.ToRetrieverQuery() - convey.So(err, convey.ShouldBeNil) conf := &es8.RetrieverConfig{} - req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}})) + req, err := a.BuildRequest(ctx, conf, query, + retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), + retriever.WrapImplSpecificOptFn[es8.ESImplOptions](func(o *es8.ESImplOptions) { + o.Filters = []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + } + })) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) @@ -116,34 +97,29 @@ func TestSearchModeApproximate(t *testing.T) { PatchConvey("test hybrid with rrf", func() { a := &approximate{config: &ApproximateConfig{ - Hybrid: true, - Rrf: true, - RrfRankConstant: ptrWithoutZero(int64(10)), - RrfWindowSize: ptrWithoutZero(int64(5)), + QueryFieldName: queryFieldName, + VectorFieldName: vectorFieldName, + Hybrid: true, + RRF: true, + RRFRankConstant: ptrWithoutZero(int64(10)), + RRFWindowSize: ptrWithoutZero(int64(5)), + QueryVectorBuilderModelID: nil, + Boost: ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), }} - aq := &ApproximateQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - Boost: ptrWithoutZero(float32(1.0)), - K: ptrWithoutZero(10), - NumCandidates: ptrWithoutZero(100), - Similarity: ptrWithoutZero(float32(0.5)), - } - - sq, err := aq.ToRetrieverQuery() - convey.So(err, convey.ShouldBeNil) - conf := &es8.RetrieverConfig{} - req, 
err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), + req, err := a.BuildRequest(ctx, conf, query, + retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), retriever.WithTopK(10), - retriever.WithScoreThreshold(1.1)) + retriever.WithScoreThreshold(1.1), + retriever.WrapImplSpecificOptFn[es8.ESImplOptions](func(o *es8.ESImplOptions) { + o.Filters = []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + } + })) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) diff --git a/components/retriever/es8/search_mode/dense_vector_similarity.go b/components/retriever/es8/search_mode/dense_vector_similarity.go index 18e58dd..cd4df4d 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -21,34 +21,16 @@ import ( "encoding/json" "fmt" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - - "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeDenseVectorSimilarity calculate embedding similarity between dense_vector field and query // see: https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html#vector-functions -func SearchModeDenseVectorSimilarity(typ DenseVectorSimilarityType) es8.SearchMode { - return &denseVectorSimilarity{script: denseVectorScriptMap[typ]} -} - -type DenseVectorSimilarityQuery struct { - FieldKV field_mapping.FieldKV `json:"field_kv"` - Filters []types.Query `json:"filters,omitempty"` -} - -// ToRetrieverQuery convert approximate query to string query -func (d *DenseVectorSimilarityQuery) ToRetrieverQuery() (string, error) { - b, err := json.Marshal(d) - if err != nil { - return "", fmt.Errorf("[ToRetrieverQuery] convert query failed, %w", err) - } - - return string(b), nil +func SearchModeDenseVectorSimilarity(typ DenseVectorSimilarityType, vectorFieldName string) es8.SearchMode { + return &denseVectorSimilarity{fmt.Sprintf(denseVectorScriptMap[typ], vectorFieldName)} } type denseVectorSimilarity struct { @@ -58,24 +40,21 @@ type denseVectorSimilarity struct { func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { - options := retriever.GetCommonOptions(&retriever.Options{ + co := 
retriever.GetCommonOptions(&retriever.Options{ Index: ptrWithoutZero(conf.Index), TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) - var dq DenseVectorSimilarityQuery - if err := json.Unmarshal([]byte(query), &dq); err != nil { - return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] parse query failed, %w", err) - } + io := retriever.GetImplSpecificOptions[es8.ESImplOptions](nil, opts...) - emb := options.Embedding + emb := co.Embedding if emb == nil { return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] embedding not provided") } - vector, err := emb.EmbedStrings(makeEmbeddingCtx(ctx, emb), []string{dq.FieldKV.Value}) + vector, err := emb.EmbedStrings(makeEmbeddingCtx(ctx, emb), []string{query}) if err != nil { return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] embedding failed, %w", err) } @@ -92,15 +71,15 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr q := &types.Query{ ScriptScore: &types.ScriptScoreQuery{ Script: types.Script{ - Source: ptrWithoutZero(fmt.Sprintf(d.script, dq.FieldKV.FieldNameVector)), + Source: ptrWithoutZero(d.script), Params: map[string]json.RawMessage{"embedding": vb}, }, }, } - if len(dq.Filters) > 0 { + if len(io.Filters) > 0 { q.ScriptScore.Query = &types.Query{ - Bool: &types.BoolQuery{Filter: dq.Filters}, + Bool: &types.BoolQuery{Filter: io.Filters}, } } else { q.ScriptScore.Query = &types.Query{ @@ -108,9 +87,9 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr } } - req := &search.Request{Query: q, Size: options.TopK} - if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) + req := &search.Request{Query: q, Size: co.TopK} + if co.ScoreThreshold != nil { + req.MinScore = (*types.Float64)(ptrWithoutZero(*co.ScoreThreshold)) } return req, nil @@ -127,10 +106,10 @@ const ( var denseVectorScriptMap 
= map[DenseVectorSimilarityType]string{ DenseVectorSimilarityTypeCosineSimilarity: `cosineSimilarity(params.embedding, '%s') + 1.0`, - DenseVectorSimilarityTypeDotProduct: `"" - double value = dotProduct(params.embedding, '%s'); - return sigmoid(1, Math.E, -value); - ""`, + DenseVectorSimilarityTypeDotProduct: ` + double value = dotProduct(params.query_vector, '%s'); + return sigmoid(1, Math.E, -value); + `, DenseVectorSimilarityTypeL1Norm: `1 / (1 + l1norm(params.embedding, '%s'))`, DenseVectorSimilarityTypeL2Norm: `1 / (1 + l2norm(params.embedding, '%s'))`, } diff --git a/components/retriever/es8/search_mode/dense_vector_similarity_test.go b/components/retriever/es8/search_mode/dense_vector_similarity_test.go index a68c333..1a7f548 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity_test.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity_test.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -23,60 +23,30 @@ import ( "testing" . 
"github.com/bytedance/mockey" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/smartystreets/goconvey/convey" - - "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) func TestSearchModeDenseVectorSimilarity(t *testing.T) { PatchConvey("test SearchModeDenseVectorSimilarity", t, func() { - PatchConvey("test ToRetrieverQuery", func() { - dq := &DenseVectorSimilarityQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - } - - sq, err := dq.ToRetrieverQuery() - convey.So(err, convey.ShouldBeNil) - convey.So(sq, convey.ShouldEqual, `{"field_kv":{"field_name_vector":"vector_eino_doc_content","field_name":"eino_doc_content","value":"content"},"filters":[{"match":{"label":{"query":"good"}}}]}`) - }) - PatchConvey("test BuildRequest", func() { ctx := context.Background() - d := &denseVectorSimilarity{script: denseVectorScriptMap[DenseVectorSimilarityTypeCosineSimilarity]} - dq := &DenseVectorSimilarityQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - } - sq, _ := dq.ToRetrieverQuery() + vectorFieldName := "vector_eino_doc_content" + d := SearchModeDenseVectorSimilarity(DenseVectorSimilarityTypeCosineSimilarity, vectorFieldName) + query := "content" PatchConvey("test embedding not provided", func() { - conf := &es8.RetrieverConfig{} - req, err := d.BuildRequest(ctx, 
conf, sq, retriever.WithEmbedding(nil)) + req, err := d.BuildRequest(ctx, conf, query, retriever.WithEmbedding(nil)) convey.So(err, convey.ShouldBeError, "[BuildRequest][SearchModeDenseVectorSimilarity] embedding not provided") convey.So(req, convey.ShouldBeNil) }) PatchConvey("test vector size invalid", func() { conf := &es8.RetrieverConfig{} - req, err := d.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(mockEmbedding{size: 2, mockVector: []float64{1.1, 1.2}})) + req, err := d.BuildRequest(ctx, conf, query, retriever.WithEmbedding(mockEmbedding{size: 2, mockVector: []float64{1.1, 1.2}})) convey.So(err, convey.ShouldBeError, "[BuildRequest][SearchModeDenseVectorSimilarity] vector size invalid, expect=1, got=2") convey.So(req, convey.ShouldBeNil) }) @@ -84,18 +54,23 @@ func TestSearchModeDenseVectorSimilarity(t *testing.T) { PatchConvey("test success", func() { typ2Exp := map[DenseVectorSimilarityType]string{ DenseVectorSimilarityTypeCosineSimilarity: `{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"cosineSimilarity(params.embedding, 'vector_eino_doc_content') + 1.0"}}},"size":10}`, - DenseVectorSimilarityTypeDotProduct: `{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"\"\"\n double value = dotProduct(params.embedding, 'vector_eino_doc_content');\n return sigmoid(1, Math.E, -value); \n \"\""}}},"size":10}`, + DenseVectorSimilarityTypeDotProduct: `{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"\n double value = dotProduct(params.query_vector, 'vector_eino_doc_content');\n return sigmoid(1, Math.E, -value);\n "}}},"size":10}`, DenseVectorSimilarityTypeL1Norm: 
`{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"1 / (1 + l1norm(params.embedding, 'vector_eino_doc_content'))"}}},"size":10}`, DenseVectorSimilarityTypeL2Norm: `{"min_score":1.1,"query":{"script_score":{"query":{"bool":{"filter":[{"match":{"label":{"query":"good"}}}]}},"script":{"params":{"embedding":[1.1,1.2]},"source":"1 / (1 + l2norm(params.embedding, 'vector_eino_doc_content'))"}}},"size":10}`, } for typ, exp := range typ2Exp { - similarity := &denseVectorSimilarity{script: denseVectorScriptMap[typ]} + similarity := SearchModeDenseVectorSimilarity(typ, vectorFieldName) conf := &es8.RetrieverConfig{} - req, err := similarity.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), + req, err := similarity.BuildRequest(ctx, conf, query, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), retriever.WithTopK(10), - retriever.WithScoreThreshold(1.1)) + retriever.WithScoreThreshold(1.1), + retriever.WrapImplSpecificOptFn[es8.ESImplOptions](func(o *es8.ESImplOptions) { + o.Filters = []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + } + })) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) diff --git a/components/retriever/es8/search_mode/exact_match.go b/components/retriever/es8/search_mode/exact_match.go index 81c7324..0282eef 100644 --- a/components/retriever/es8/search_mode/exact_match.go +++ b/components/retriever/es8/search_mode/exact_match.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -19,20 +19,19 @@ package search_mode import ( "context" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - - "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) -func SearchModeExactMatch() es8.SearchMode { - return &exactMatch{} +func SearchModeExactMatch(queryFieldName string) es8.SearchMode { + return &exactMatch{queryFieldName} } -type exactMatch struct{} +type exactMatch struct { + name string +} func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { @@ -46,7 +45,7 @@ func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, q := &types.Query{ Match: map[string]types.MatchQuery{ - field_mapping.DocFieldNameContent: {Query: query}, + e.name: {Query: query}, }, } diff --git a/components/retriever/es8/search_mode/exact_match_test.go b/components/retriever/es8/search_mode/exact_match_test.go new file mode 100644 index 0000000..fc90250 --- /dev/null +++ b/components/retriever/es8/search_mode/exact_match_test.go @@ -0,0 +1,41 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package search_mode + +import ( + "context" + "encoding/json" + "testing" + + . "github.com/bytedance/mockey" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/smartystreets/goconvey/convey" +) + +func TestSearchModeExactMatch(t *testing.T) { + PatchConvey("test SearchModeExactMatch", t, func() { + ctx := context.Background() + conf := &es8.RetrieverConfig{} + searchMode := SearchModeExactMatch("test_field") + req, err := searchMode.BuildRequest(ctx, conf, "test_query") + convey.So(err, convey.ShouldBeNil) + b, err := json.Marshal(req) + convey.So(err, convey.ShouldBeNil) + convey.So(string(b), convey.ShouldEqual, `{"query":{"match":{"test_field":{"query":"test_query"}}}}`) + }) + +} diff --git a/components/retriever/es8/search_mode/raw_string.go b/components/retriever/es8/search_mode/raw_string.go index 82a7aa4..01eccd9 100644 --- a/components/retriever/es8/search_mode/raw_string.go +++ b/components/retriever/es8/search_mode/raw_string.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -19,11 +19,9 @@ package search_mode import ( "context" - "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" - - "github.com/cloudwego/eino/components/retriever" - "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" + "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" ) func SearchModeRawStringRequest() es8.SearchMode { diff --git a/components/retriever/es8/search_mode/raw_string_test.go b/components/retriever/es8/search_mode/raw_string_test.go new file mode 100644 index 0000000..75d2619 --- /dev/null +++ b/components/retriever/es8/search_mode/raw_string_test.go @@ -0,0 +1,48 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package search_mode + +import ( + "context" + "testing" + + . 
"github.com/bytedance/mockey" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/smartystreets/goconvey/convey" +) + +func TestSearchModeRawStringRequest(t *testing.T) { + PatchConvey("test SearchModeRawStringRequest", t, func() { + ctx := context.Background() + conf := &es8.RetrieverConfig{} + searchMode := SearchModeRawStringRequest() + + PatchConvey("test from json error", func() { + r, err := searchMode.BuildRequest(ctx, conf, "test_query") + convey.So(err, convey.ShouldNotBeNil) + convey.So(r, convey.ShouldBeNil) + }) + + PatchConvey("test success", func() { + q := `{"query":{"match":{"test_field":{"query":"test_query"}}}}` + r, err := searchMode.BuildRequest(ctx, conf, q) + convey.So(err, convey.ShouldBeNil) + convey.So(r, convey.ShouldNotBeNil) + convey.So(r.Query.Match["test_field"].Query, convey.ShouldEqual, "test_query") + }) + }) +} diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go index bfa75be..0b9a999 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
- * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -18,63 +18,42 @@ package search_mode import ( "context" - "encoding/json" "fmt" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - - "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeSparseVectorTextExpansion convert the query text into a list ptrWithoutZero token-weight pairs, // which are then used in a query against a sparse vector // see: https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-text-expansion-query.html -func SearchModeSparseVectorTextExpansion(modelID string) es8.SearchMode { - return &sparseVectorTextExpansion{modelID} -} - -type SparseVectorTextExpansionQuery struct { - FieldKV field_mapping.FieldKV `json:"field_kv"` - Filters []types.Query `json:"filters,omitempty"` -} - -// ToRetrieverQuery convert approximate query to string query -func (s *SparseVectorTextExpansionQuery) ToRetrieverQuery() (string, error) { - b, err := json.Marshal(s) - if err != nil { - return "", fmt.Errorf("[ToRetrieverQuery] convert query failed, %w", err) - } - - return string(b), nil +func SearchModeSparseVectorTextExpansion(modelID, vectorFieldName string) es8.SearchMode { + return &sparseVectorTextExpansion{modelID, vectorFieldName} } type sparseVectorTextExpansion struct { - modelID string + modelID string + vectorFieldName string } func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { - options := retriever.GetCommonOptions(&retriever.Options{ + co := 
retriever.GetCommonOptions(&retriever.Options{ Index: ptrWithoutZero(conf.Index), TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) - var sq SparseVectorTextExpansionQuery - if err := json.Unmarshal([]byte(query), &sq); err != nil { - return nil, fmt.Errorf("[BuildRequest][SearchModeSparseVectorTextExpansion] parse query failed, %w", err) - } + io := retriever.GetImplSpecificOptions[es8.ESImplOptions](nil, opts...) - name := fmt.Sprintf("%s.tokens", sq.FieldKV.FieldNameVector) + name := fmt.Sprintf("%s.tokens", s.vectorFieldName) teq := types.TextExpansionQuery{ ModelId: s.modelID, - ModelText: sq.FieldKV.Value, + ModelText: query, } q := &types.Query{ @@ -82,13 +61,13 @@ func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.R Must: []types.Query{ {TextExpansion: map[string]types.TextExpansionQuery{name: teq}}, }, - Filter: sq.Filters, + Filter: io.Filters, }, } - req := &search.Request{Query: q, Size: options.TopK} - if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) + req := &search.Request{Query: q, Size: co.TopK} + if co.ScoreThreshold != nil { + req.MinScore = (*types.Float64)(ptrWithoutZero(*co.ScoreThreshold)) } return req, nil diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go index f4c42a9..019d0ce 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -22,55 +22,28 @@ import ( "testing" . 
"github.com/bytedance/mockey" + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino/components/retriever" "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/smartystreets/goconvey/convey" - - "github.com/cloudwego/eino/components/retriever" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) func TestSearchModeSparseVectorTextExpansion(t *testing.T) { PatchConvey("test SearchModeSparseVectorTextExpansion", t, func() { - PatchConvey("test ToRetrieverQuery", func() { - sq := &SparseVectorTextExpansionQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - } - - ssq, err := sq.ToRetrieverQuery() - convey.So(err, convey.ShouldBeNil) - convey.So(ssq, convey.ShouldEqual, `{"field_kv":{"field_name_vector":"vector_eino_doc_content","field_name":"eino_doc_content","value":"content"},"filters":[{"match":{"label":{"query":"good"}}}]}`) - - }) - PatchConvey("test BuildRequest", func() { ctx := context.Background() - s := SearchModeSparseVectorTextExpansion("mock_model_id") - sq := &SparseVectorTextExpansionQuery{ - FieldKV: field_mapping.FieldKV{ - FieldNameVector: field_mapping.GetDefaultVectorFieldKeyContent(), - FieldName: field_mapping.DocFieldNameContent, - Value: "content", - }, - Filters: []types.Query{ - {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, - }, - } - - query, _ := sq.ToRetrieverQuery() + vectorFieldName := "vector_eino_doc_content" + s := SearchModeSparseVectorTextExpansion("mock_model_id", vectorFieldName) conf := &es8.RetrieverConfig{} - req, err := s.BuildRequest(ctx, conf, query, + req, err := s.BuildRequest(ctx, conf, "content", retriever.WithTopK(10), - 
retriever.WithScoreThreshold(1.1)) + retriever.WithScoreThreshold(1.1), + retriever.WrapImplSpecificOptFn[es8.ESImplOptions](func(o *es8.ESImplOptions) { + o.Filters = []types.Query{ + {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, + } + })) convey.So(err, convey.ShouldBeNil) convey.So(req, convey.ShouldNotBeNil) diff --git a/components/retriever/es8/search_mode/utils.go b/components/retriever/es8/search_mode/utils.go index cc54479..ebddb7e 100644 --- a/components/retriever/es8/search_mode/utils.go +++ b/components/retriever/es8/search_mode/utils.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy ptrWithoutZero the License at + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * From 3a78abfd5b2d47deb338ff012a536f10e7378de5 Mon Sep 17 00:00:00 2001 From: xuzhaonan Date: Tue, 14 Jan 2025 11:45:36 +0800 Subject: [PATCH 10/11] refactor: es indexer batch size for embedding --- components/indexer/es8/indexer.go | 149 ++++++++++++-------- components/indexer/es8/indexer_test.go | 180 ++++++++++++++++++------- components/indexer/es8/utils.go | 36 ----- 3 files changed, 220 insertions(+), 145 deletions(-) diff --git a/components/indexer/es8/indexer.go b/components/indexer/es8/indexer.go index ddfebb0..431eb7e 100644 --- a/components/indexer/es8/indexer.go +++ b/components/indexer/es8/indexer.go @@ -33,10 +33,10 @@ import ( ) type IndexerConfig struct { - ESConfig elasticsearch.Config `json:"es_config"` - Index string `json:"index"` - BatchSize int `json:"batch_size"` - + ESConfig elasticsearch.Config `json:"es_config"` + Index string `json:"index"` + // BatchSize controls max texts size for embedding + BatchSize int `json:"batch_size"` // FieldMapping supports customize es fields from eino document, returns: // needEmbeddingFields will be embedded by Embedding firstly, then join fields with 
its keys, // and joined fields will be saved as bulk item. @@ -85,28 +85,7 @@ func (i *Indexer) Store(ctx context.Context, docs []*schema.Document, opts ...in Embedding: i.config.Embedding, }, opts...) - bi, err := esutil.NewBulkIndexer(esutil.BulkIndexerConfig{ - Index: i.config.Index, - Client: i.client, - }) - if err != nil { - return nil, err - } - - for _, slice := range chunk(docs, i.config.BatchSize) { - items, err := i.makeBulkItems(ctx, slice, options) - if err != nil { - return nil, err - } - - for _, item := range items { - if err = bi.Add(ctx, item); err != nil { - return nil, err - } - } - } - - if err = bi.Close(ctx); err != nil { + if err = i.bulkAdd(ctx, docs, options); err != nil { return nil, err } @@ -117,59 +96,107 @@ func (i *Indexer) Store(ctx context.Context, docs []*schema.Document, opts ...in return ids, nil } -func (i *Indexer) makeBulkItems(ctx context.Context, docs []*schema.Document, options *indexer.Options) (items []esutil.BulkIndexerItem, err error) { +func (i *Indexer) bulkAdd(ctx context.Context, docs []*schema.Document, options *indexer.Options) error { emb := options.Embedding + bi, err := esutil.NewBulkIndexer(esutil.BulkIndexerConfig{ + Index: i.config.Index, + Client: i.client, + }) + if err != nil { + return err + } - items, err = iterWithErr(docs, func(doc *schema.Document) (item esutil.BulkIndexerItem, err error) { - fields, needEmbeddingFields, err := i.config.FieldMapping(ctx, doc) - if err != nil { - return item, fmt.Errorf("[makeBulkItems] FieldMapping failed, %w", err) - } + var ( + tuples []tuple + texts []string + ) - if len(needEmbeddingFields) > 0 { - if emb == nil { - return item, fmt.Errorf("[makeBulkItems] embedding method not provided") - } + embAndAdd := func() error { + var vectors [][]float64 - tuples := make([]tuple[string, int], 0, len(fields)) - texts := make([]string, 0, len(fields)) - for k, text := range needEmbeddingFields { - tuples = append(tuples, tuple[string, int]{k, len(texts)}) - texts = 
append(texts, text) + if len(texts) > 0 { + if emb == nil { + return fmt.Errorf("[bulkAdd] embedding method not provided") } - vectors, err := emb.EmbedStrings(i.makeEmbeddingCtx(ctx, emb), texts) + vectors, err = emb.EmbedStrings(i.makeEmbeddingCtx(ctx, emb), texts) if err != nil { - return item, fmt.Errorf("[makeBulkItems] embedding failed, %w", err) + return fmt.Errorf("[bulkAdd] embedding failed, %w", err) } if len(vectors) != len(texts) { - return item, fmt.Errorf("[makeBulkItems] invalid vector length, expected=%d, got=%d", len(texts), len(vectors)) + return fmt.Errorf("[bulkAdd] invalid vector length, expected=%d, got=%d", len(texts), len(vectors)) + } + } + + for _, t := range tuples { + fields := t.fields + for k, idx := range t.key2Idx { + fields[k] = vectors[idx] + } + + b, err := json.Marshal(fields) + if err != nil { + return fmt.Errorf("[bulkAdd] marshal bulk item failed, %w", err) } - for _, t := range tuples { - fields[t.A] = vectors[t.B] + if err = bi.Add(ctx, esutil.BulkIndexerItem{ + Index: i.config.Index, + Action: "index", + DocumentID: t.id, + Body: bytes.NewReader(b), + }); err != nil { + return err } } - b, err := json.Marshal(fields) + tuples = tuples[:0] + texts = texts[:0] + + return nil + } + + for idx := range docs { + doc := docs[idx] + fields, needEmbeddingFields, err := i.config.FieldMapping(ctx, doc) if err != nil { - return item, err + return fmt.Errorf("[bulkAdd] FieldMapping failed, %w", err) + } + if fields == nil { + fields = make(map[string]any) } - return esutil.BulkIndexerItem{ - Index: i.config.Index, - Action: "index", - DocumentID: doc.ID, - Body: bytes.NewReader(b), - }, nil - }) + if len(needEmbeddingFields) > i.config.BatchSize { + return fmt.Errorf("[bulkAdd] needEmbeddingFields length over batch size, batch size=%d, got size=%d", + i.config.BatchSize, len(needEmbeddingFields)) + } - if err != nil { - return nil, err + if len(texts)+len(needEmbeddingFields) > i.config.BatchSize { + if err = embAndAdd(); err != nil { + 
return err + } + } + + key2Idx := make(map[string]int, len(needEmbeddingFields)) + for k, text := range needEmbeddingFields { + key2Idx[k] = len(texts) + texts = append(texts, text) + } + + tuples = append(tuples, tuple{ + id: doc.ID, + fields: fields, + key2Idx: key2Idx, + }) + } + + if len(tuples) > 0 { + if err = embAndAdd(); err != nil { + return err + } } - return items, nil + return bi.Close(ctx) } func (i *Indexer) makeEmbeddingCtx(ctx context.Context, emb embedding.Embedder) context.Context { @@ -193,3 +220,9 @@ func (i *Indexer) GetType() string { func (i *Indexer) IsCallbacksEnabled() bool { return true } + +type tuple struct { + id string + fields map[string]any + key2Idx map[string]int +} diff --git a/components/indexer/es8/indexer_test.go b/components/indexer/es8/indexer_test.go index a1535a4..0ac4187 100644 --- a/components/indexer/es8/indexer_test.go +++ b/components/indexer/es8/indexer_test.go @@ -18,29 +18,50 @@ package es8 import ( "context" + "encoding/json" "fmt" "io" "testing" . 
"github.com/bytedance/mockey" - "github.com/smartystreets/goconvey/convey" - "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" + "github.com/elastic/go-elasticsearch/v8/esutil" + "github.com/smartystreets/goconvey/convey" ) -func TestVectorQueryItems(t *testing.T) { - PatchConvey("test makeBulkItems", t, func() { +func TestBulkAdd(t *testing.T) { + PatchConvey("test bulkAdd", t, func() { ctx := context.Background() extField := "extra_field" d1 := &schema.Document{ID: "123", Content: "asd", MetaData: map[string]any{extField: "ext_1"}} d2 := &schema.Document{ID: "456", Content: "qwe", MetaData: map[string]any{extField: "ext_2"}} docs := []*schema.Document{d1, d2} + bi, err := esutil.NewBulkIndexer(esutil.BulkIndexerConfig{}) + convey.So(err, convey.ShouldBeNil) + + PatchConvey("test NewBulkIndexer error", func() { + mockErr := fmt.Errorf("test err") + Mock(esutil.NewBulkIndexer).Return(nil, mockErr).Build() + i := &Indexer{ + config: &IndexerConfig{ + Index: "mock_index", + FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { + return nil, nil, nil + }, + }, + } + err := i.bulkAdd(ctx, docs, &indexer.Options{ + Embedding: &mockEmbedding{size: []int{1}, mockVector: []float64{2.1}}, + }) + convey.So(err, convey.ShouldBeError, mockErr) + }) PatchConvey("test FieldMapping error", func() { mockErr := fmt.Errorf("test err") + Mock(esutil.NewBulkIndexer).Return(bi, nil).Build() i := &Indexer{ config: &IndexerConfig{ Index: "mock_index", @@ -49,78 +70,150 @@ func TestVectorQueryItems(t *testing.T) { }, }, } - - bulks, err := i.makeBulkItems(ctx, docs, &indexer.Options{ + err := i.bulkAdd(ctx, docs, &indexer.Options{ Embedding: &mockEmbedding{size: []int{1}, mockVector: []float64{2.1}}, }) - convey.So(err, convey.ShouldBeError, fmt.Errorf("[makeBulkItems] FieldMapping failed, %w", mockErr)) - 
convey.So(len(bulks), convey.ShouldEqual, 0) + convey.So(err, convey.ShouldBeError, fmt.Errorf("[bulkAdd] FieldMapping failed, %w", mockErr)) }) - PatchConvey("test emb not provided", func() { + PatchConvey("test len(needEmbeddingFields) > i.config.BatchSize", func() { + Mock(esutil.NewBulkIndexer).Return(bi, nil).Build() i := &Indexer{ config: &IndexerConfig{ - Index: "mock_index", - FieldMapping: defaultFieldMapping, + Index: "mock_index", + BatchSize: 1, + FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { + return nil, map[string]string{ + "k1": "v1", "k2": "v2", + }, nil + }, }, } + err := i.bulkAdd(ctx, docs, &indexer.Options{ + Embedding: &mockEmbedding{size: []int{1}, mockVector: []float64{2.1}}, + }) + convey.So(err, convey.ShouldBeError, fmt.Errorf("[bulkAdd] needEmbeddingFields length over batch size, batch size=%d, got size=%d", i.config.BatchSize, 2)) + }) - bulks, err := i.makeBulkItems(ctx, docs, &indexer.Options{Embedding: nil}) - convey.So(err, convey.ShouldBeError, "[makeBulkItems] embedding method not provided") - convey.So(len(bulks), convey.ShouldEqual, 0) + PatchConvey("test embedding not provided", func() { + Mock(esutil.NewBulkIndexer).Return(bi, nil).Build() + i := &Indexer{ + config: &IndexerConfig{ + Index: "mock_index", + BatchSize: 2, + FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { + return map[string]any{ + "k0": "v0", "k1": "v1", "k3": 123, + }, map[string]string{ + "k1": "v1", "k2": "v2", + }, nil + }, + }, + } + err := i.bulkAdd(ctx, docs, &indexer.Options{ + Embedding: nil, + }) + convey.So(err, convey.ShouldBeError, fmt.Errorf("[bulkAdd] embedding method not provided")) }) - PatchConvey("test vector size invalid", func() { + PatchConvey("test embed failed", func() { + mockErr := fmt.Errorf("test err") + Mock(esutil.NewBulkIndexer).Return(bi, nil).Build() 
i := &Indexer{ config: &IndexerConfig{ - Index: "mock_index", - FieldMapping: defaultFieldMapping, + Index: "mock_index", + BatchSize: 2, + FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { + return map[string]any{ + "k0": "v0", "k1": "v1", "k3": 123, + }, map[string]string{ + "k1": "v1", "k2": "v2", + }, nil + }, }, } + err := i.bulkAdd(ctx, docs, &indexer.Options{ + Embedding: &mockEmbedding{err: mockErr}, + }) + convey.So(err, convey.ShouldBeError, fmt.Errorf("[bulkAdd] embedding failed, %w", mockErr)) + }) - bulks, err := i.makeBulkItems(ctx, docs, &indexer.Options{ - Embedding: &mockEmbedding{size: []int{2, 2}, mockVector: []float64{2.1}}, + PatchConvey("test len(vectors) != len(texts)", func() { + Mock(esutil.NewBulkIndexer).Return(bi, nil).Build() + i := &Indexer{ + config: &IndexerConfig{ + Index: "mock_index", + BatchSize: 2, + FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { + return map[string]any{ + "k0": "v0", "k1": "v1", "k3": 123, + }, map[string]string{ + "k1": "v1", "k2": "v2", + }, nil + }, + }, + } + err := i.bulkAdd(ctx, docs, &indexer.Options{ + Embedding: &mockEmbedding{size: []int{1}, mockVector: []float64{2.1}}, }) - convey.So(err, convey.ShouldBeError, "[makeBulkItems] invalid vector length, expected=1, got=2") - convey.So(len(bulks), convey.ShouldEqual, 0) + convey.So(err, convey.ShouldBeError, fmt.Errorf("[bulkAdd] invalid vector length, expected=%d, got=%d", 2, 1)) }) PatchConvey("test success", func() { + var mps []esutil.BulkIndexerItem + Mock(esutil.NewBulkIndexer).Return(bi, nil).Build() + Mock(GetMethod(bi, "Add")).To(func(ctx context.Context, item esutil.BulkIndexerItem) error { + mps = append(mps, item) + return nil + }).Build() + Mock(GetMethod(bi, "Close")).Return(nil).Build() + i := &Indexer{ config: &IndexerConfig{ - Index: "mock_index", - FieldMapping: 
defaultFieldMapping, + Index: "mock_index", + BatchSize: 2, + FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { + return map[string]any{ + "k0": doc.Content, "k1": "v1", "k3": 123, + }, map[string]string{ + "k1": "v1", "k2": "v2", + }, nil + }, }, } - - bulks, err := i.makeBulkItems(ctx, docs, &indexer.Options{ - Embedding: &mockEmbedding{size: []int{1, 1}, mockVector: []float64{2.1}}, + err := i.bulkAdd(ctx, docs, &indexer.Options{ + Embedding: &mockEmbedding{size: []int{2, 2}, mockVector: []float64{2.1}}, }) convey.So(err, convey.ShouldBeNil) - convey.So(len(bulks), convey.ShouldEqual, 2) - exp := []string{ - `{"content":"asd","meta_data":{"extra_field":"ext_1"},"vector_content":[2.1]}`, - `{"content":"qwe","meta_data":{"extra_field":"ext_2"},"vector_content":[2.1]}`, - } - - for idx, item := range bulks { - convey.So(item.Index, convey.ShouldEqual, i.config.Index) + convey.So(len(mps), convey.ShouldEqual, 2) + for j, doc := range docs { + item := mps[j] + convey.So(item.DocumentID, convey.ShouldEqual, doc.ID) b, err := io.ReadAll(item.Body) - fmt.Println(string(b)) convey.So(err, convey.ShouldBeNil) - convey.So(string(b), convey.ShouldEqual, exp[idx]) + var mp map[string]interface{} + convey.So(json.Unmarshal(b, &mp), convey.ShouldBeNil) + convey.So(mp["k0"], convey.ShouldEqual, doc.Content) + convey.So(mp["k1"], convey.ShouldEqual, []any{2.1}) + convey.So(mp["k2"], convey.ShouldEqual, []any{2.1}) + convey.So(mp["k3"], convey.ShouldEqual, 123) } }) }) } type mockEmbedding struct { + err error call int size []int mockVector []float64 } func (m *mockEmbedding) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { + if m.err != nil { + return nil, m.err + } + if m.call >= len(m.size) { return nil, fmt.Errorf("call limit error") } @@ -133,18 +226,3 @@ func (m *mockEmbedding) EmbedStrings(ctx context.Context, texts []string, opts . 
return resp, nil } - -func defaultFieldMapping(ctx context.Context, doc *schema.Document) ( - fields map[string]any, needEmbeddingFields map[string]string, err error) { - - fields = map[string]any{ - "content": doc.Content, - "meta_data": doc.MetaData, - } - - needEmbeddingFields = map[string]string{ - "vector_content": doc.Content, - } - - return fields, needEmbeddingFields, nil -} diff --git a/components/indexer/es8/utils.go b/components/indexer/es8/utils.go index e669079..937e3b5 100644 --- a/components/indexer/es8/utils.go +++ b/components/indexer/es8/utils.go @@ -20,28 +20,6 @@ func GetType() string { return typ } -type tuple[A, B any] struct { - A A - B B -} - -func chunk[T any](slice []T, size int) [][]T { - if size <= 0 { - return nil - } - - var chunks [][]T - for size < len(slice) { - slice, chunks = slice[size:], append(chunks, slice[0:size:size]) - } - - if len(slice) > 0 { - chunks = append(chunks, slice) - } - - return chunks -} - func iter[T, D any](src []T, fn func(T) D) []D { resp := make([]D, len(src)) for i := range src { @@ -50,17 +28,3 @@ func iter[T, D any](src []T, fn func(T) D) []D { return resp } - -func iterWithErr[T, D any](src []T, fn func(T) (D, error)) ([]D, error) { - resp := make([]D, 0, len(src)) - for i := range src { - d, err := fn(src[i]) - if err != nil { - return nil, err - } - - resp = append(resp, d) - } - - return resp, nil -} From 0092596f2181879f3724eaae7f311e93cb9a8825 Mon Sep 17 00:00:00 2001 From: xuzhaonan Date: Tue, 14 Jan 2025 19:47:24 +0800 Subject: [PATCH 11/11] refactor: es indexer field mapping --- components/indexer/es8/indexer.go | 76 ++++++++++++++++++------- components/indexer/es8/indexer_test.go | 79 +++++++++++++++----------- 2 files changed, 102 insertions(+), 53 deletions(-) diff --git a/components/indexer/es8/indexer.go b/components/indexer/es8/indexer.go index 431eb7e..27234f4 100644 --- a/components/indexer/es8/indexer.go +++ b/components/indexer/es8/indexer.go @@ -22,14 +22,13 @@ import ( 
"encoding/json" "fmt" - "github.com/elastic/go-elasticsearch/v8" - "github.com/elastic/go-elasticsearch/v8/esutil" - "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" + "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/esutil" ) type IndexerConfig struct { @@ -37,16 +36,27 @@ type IndexerConfig struct { Index string `json:"index"` // BatchSize controls max texts size for embedding BatchSize int `json:"batch_size"` - // FieldMapping supports customize es fields from eino document, returns: - // needEmbeddingFields will be embedded by Embedding firstly, then join fields with its keys, - // and joined fields will be saved as bulk item. - FieldMapping func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) + // FieldMapping supports customize es fields from eino document. + // Each key - FieldValue.Value from field2Value will be saved, and + // vector of FieldValue.Value will be saved if FieldValue.EmbedKey is not empty. + DocumentToFields func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) // Embedding vectorization method, must provide in two cases // 1. VectorFields contains fields except doc Content // 2. VectorFields contains doc Content and vector not provided in doc extra (see Document.Vector method) Embedding embedding.Embedder } +type FieldValue struct { + // Value original Value + Value any + // EmbedKey if set, Value will be vectorized and saved to es. + // If Stringify method is provided, Embedding input text will be Stringify(Value). + // If Stringify method not set, retriever will try to assert Value as string. 
+ EmbedKey string + // Stringify converts Value to string + Stringify func(val any) (string, error) +} + type Indexer struct { client *elasticsearch.Client config *IndexerConfig @@ -58,8 +68,8 @@ func NewIndexer(_ context.Context, conf *IndexerConfig) (*Indexer, error) { return nil, fmt.Errorf("[NewIndexer] new es client failed, %w", err) } - if conf.FieldMapping == nil { - return nil, fmt.Errorf("[NewIndexer] field mapping method not provided") + if conf.DocumentToFields == nil { + return nil, fmt.Errorf("[NewIndexer] DocumentToFields method not provided") } if conf.BatchSize == 0 { @@ -158,34 +168,60 @@ func (i *Indexer) bulkAdd(ctx context.Context, docs []*schema.Document, options for idx := range docs { doc := docs[idx] - fields, needEmbeddingFields, err := i.config.FieldMapping(ctx, doc) + fields, err := i.config.DocumentToFields(ctx, doc) if err != nil { return fmt.Errorf("[bulkAdd] FieldMapping failed, %w", err) } - if fields == nil { - fields = make(map[string]any) + + rawFields := make(map[string]any) + embSize := 0 + for k, v := range fields { + rawFields[k] = v.Value + if v.EmbedKey != "" { + embSize++ + } } - if len(needEmbeddingFields) > i.config.BatchSize { + if embSize > i.config.BatchSize { return fmt.Errorf("[bulkAdd] needEmbeddingFields length over batch size, batch size=%d, got size=%d", - i.config.BatchSize, len(needEmbeddingFields)) + i.config.BatchSize, embSize) } - if len(texts)+len(needEmbeddingFields) > i.config.BatchSize { + if len(texts)+embSize > i.config.BatchSize { if err = embAndAdd(); err != nil { return err } } - key2Idx := make(map[string]int, len(needEmbeddingFields)) - for k, text := range needEmbeddingFields { - key2Idx[k] = len(texts) - texts = append(texts, text) + key2Idx := make(map[string]int, embSize) + for k, v := range fields { + if v.EmbedKey != "" { + if v.EmbedKey == k { + return fmt.Errorf("[bulkAdd] duplicate key for value and vector, field=%s", k) + } + + var text string + if v.Stringify != nil { + text, err = 
v.Stringify(v.Value) + if err != nil { + return err + } + } else { + var ok bool + text, ok = v.Value.(string) + if !ok { + return fmt.Errorf("[bulkAdd] assert value as string failed, key=%s, emb_key=%s", k, v.EmbedKey) + } + } + + key2Idx[v.EmbedKey] = len(texts) + texts = append(texts, text) + } } tuples = append(tuples, tuple{ id: doc.ID, - fields: fields, + fields: rawFields, key2Idx: key2Idx, }) } diff --git a/components/indexer/es8/indexer_test.go b/components/indexer/es8/indexer_test.go index 0ac4187..59623e6 100644 --- a/components/indexer/es8/indexer_test.go +++ b/components/indexer/es8/indexer_test.go @@ -48,8 +48,8 @@ func TestBulkAdd(t *testing.T) { i := &Indexer{ config: &IndexerConfig{ Index: "mock_index", - FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { - return nil, nil, nil + DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) { + return nil, nil }, }, } @@ -65,8 +65,8 @@ func TestBulkAdd(t *testing.T) { i := &Indexer{ config: &IndexerConfig{ Index: "mock_index", - FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { - return nil, nil, mockErr + DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) { + return nil, mockErr }, }, } @@ -82,9 +82,10 @@ func TestBulkAdd(t *testing.T) { config: &IndexerConfig{ Index: "mock_index", BatchSize: 1, - FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { - return nil, map[string]string{ - "k1": "v1", "k2": "v2", + DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) { + return map[string]FieldValue{ + "k1": {Value: "v1", EmbedKey: "k"}, + "k2": {Value: "v2", EmbedKey: "kk"}, }, 
nil }, }, @@ -101,12 +102,15 @@ func TestBulkAdd(t *testing.T) { config: &IndexerConfig{ Index: "mock_index", BatchSize: 2, - FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { - return map[string]any{ - "k0": "v0", "k1": "v1", "k3": 123, - }, map[string]string{ - "k1": "v1", "k2": "v2", - }, nil + DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) { + return map[string]FieldValue{ + "k0": {Value: "v0"}, + "k1": {Value: "v1", EmbedKey: "vk1"}, + "k2": {Value: 222, EmbedKey: "vk2", Stringify: func(val any) (string, error) { + return "222", nil + }}, + "k3": {Value: 123}, + }, nil }, }, } @@ -123,12 +127,15 @@ func TestBulkAdd(t *testing.T) { config: &IndexerConfig{ Index: "mock_index", BatchSize: 2, - FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { - return map[string]any{ - "k0": "v0", "k1": "v1", "k3": 123, - }, map[string]string{ - "k1": "v1", "k2": "v2", - }, nil + DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) { + return map[string]FieldValue{ + "k0": {Value: "v0"}, + "k1": {Value: "v1", EmbedKey: "vk1"}, + "k2": {Value: 222, EmbedKey: "vk2", Stringify: func(val any) (string, error) { + return "222", nil + }}, + "k3": {Value: 123}, + }, nil }, }, } @@ -144,12 +151,15 @@ func TestBulkAdd(t *testing.T) { config: &IndexerConfig{ Index: "mock_index", BatchSize: 2, - FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { - return map[string]any{ - "k0": "v0", "k1": "v1", "k3": 123, - }, map[string]string{ - "k1": "v1", "k2": "v2", - }, nil + DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) { + return map[string]FieldValue{ 
+ "k0": {Value: "v0"}, + "k1": {Value: "v1", EmbedKey: "vk1"}, + "k2": {Value: 222, EmbedKey: "vk2", Stringify: func(val any) (string, error) { + return "222", nil + }}, + "k3": {Value: 123}, + }, nil }, }, } @@ -172,12 +182,13 @@ func TestBulkAdd(t *testing.T) { config: &IndexerConfig{ Index: "mock_index", BatchSize: 2, - FieldMapping: func(ctx context.Context, doc *schema.Document) (fields map[string]any, needEmbeddingFields map[string]string, err error) { - return map[string]any{ - "k0": doc.Content, "k1": "v1", "k3": 123, - }, map[string]string{ - "k1": "v1", "k2": "v2", - }, nil + DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]FieldValue, err error) { + return map[string]FieldValue{ + "k0": {Value: doc.Content}, + "k1": {Value: "v1", EmbedKey: "vk1"}, + "k2": {Value: 222, EmbedKey: "vk2", Stringify: func(val any) (string, error) { return "222", nil }}, + "k3": {Value: 123}, + }, nil }, }, } @@ -194,9 +205,11 @@ func TestBulkAdd(t *testing.T) { var mp map[string]interface{} convey.So(json.Unmarshal(b, &mp), convey.ShouldBeNil) convey.So(mp["k0"], convey.ShouldEqual, doc.Content) - convey.So(mp["k1"], convey.ShouldEqual, []any{2.1}) - convey.So(mp["k2"], convey.ShouldEqual, []any{2.1}) + convey.So(mp["k1"], convey.ShouldEqual, "v1") + convey.So(mp["k2"], convey.ShouldEqual, 222) convey.So(mp["k3"], convey.ShouldEqual, 123) + convey.So(mp["vk1"], convey.ShouldEqual, []any{2.1}) + convey.So(mp["vk2"], convey.ShouldEqual, []any{2.1}) } }) })