diff --git a/compose/field_mapping.go b/compose/field_mapping.go new file mode 100644 index 0000000..5700edd --- /dev/null +++ b/compose/field_mapping.go @@ -0,0 +1,228 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "errors" + "fmt" + "reflect" + + "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino/utils/generic" +) + +func takeOne(input any, from string) (any, error) { + if len(from) == 0 { + return input, nil + } + + inputValue := reflect.ValueOf(input) + + f, err := checkAndExtractFromField(from, inputValue) + if err != nil { + return nil, err + } + + return f.Interface(), nil +} + +func assignOne[T any](dest T, taken any, to string) (T, error) { + destValue := reflect.ValueOf(dest) + + if !destValue.CanAddr() { + destValue = reflect.ValueOf(&dest).Elem() + } + + if len(to) == 0 { // assign to output directly + toSet := reflect.ValueOf(taken) + if !toSet.Type().AssignableTo(destValue.Type()) { + return dest, fmt.Errorf("mapping entire value has a mismatched type. from=%v, to=%v", toSet.Type(), destValue.Type()) + } + + destValue.Set(toSet) + + return destValue.Interface().(T), nil + } + + toSet := reflect.ValueOf(taken) + + field, err := checkAndExtractToField(to, destValue, toSet) + if err != nil { + return dest, err + } + + field.Set(toSet) + return destValue.Interface().(T), nil + +} + +func convertTo[T any](mappings map[string]any) (T, error) { + t := generic.NewInstance[T]() + if len(mappings) == 0 { + return t, errors.New("mapper has no Mappings") + } + + var err error + for fieldName, taken := range mappings { + t, err = assignOne(t, taken, fieldName) + if err != nil { + return t, err + } + } + + return t, nil +} + +type fieldMapFn func(any) (map[string]any, error) +type streamFieldMapFn func(streamReader) streamReader + +func mappingAssign[T any](in map[string]any) (any, error) { + return convertTo[T](in) +} + +func mappingStreamAssign[T any](in streamReader) streamReader { + return packStreamReader(schema.StreamReaderWithConvert(in.toAnyStreamReader(), func(v any) (T, error) { + var t T + mappings, ok := v.(map[string]any) + if !ok { + return t, fmt.Errorf("stream mapping expects trunk of type map[Mapping]any, but got %T", v) + } + + return convertTo[T](mappings) + })) +} + +func fieldMap(mappings []*Mapping) fieldMapFn { + return func(input any) (map[string]any, error) { + result := make(map[string]any, len(mappings)) + for _, mapping := range mappings { + taken, err := takeOne(input, mapping.from) + if err != nil { + return nil, err + } + + if _, ok := result[mapping.to]; ok { + return nil, fmt.Errorf("mapping has duplicate to. field=%v", mapping.to) + } + + result[mapping.to] = taken + } + + return result, nil + } +} + +func streamFieldMap(mappings []*Mapping) streamFieldMapFn { + return func(input streamReader) streamReader { + return packStreamReader(schema.StreamReaderWithConvert(input.toAnyStreamReader(), fieldMap(mappings))) + } +} + +var anyType = reflect.TypeOf((*any)(nil)).Elem() + +func checkAndExtractFromField(fromField string, input reflect.Value) (reflect.Value, error) { + if input.Kind() == reflect.Ptr { + input = input.Elem() + } + + if input.Kind() != reflect.Struct { + return reflect.Value{}, fmt.Errorf("mapping has from but input is not struct or struct ptr, type= %v", input.Type()) + } + + f := input.FieldByName(fromField) + if !f.IsValid() { + return reflect.Value{}, fmt.Errorf("mapping has from not found. field=%v, inputType=%v", fromField, input.Type()) + } + + if !f.CanInterface() { + return reflect.Value{}, fmt.Errorf("mapping has from not exported. field= %v, inputType=%v", fromField, input.Type()) + } + + return f, nil +} + +func checkAndExtractToField(toField string, output, toSet reflect.Value) (reflect.Value, error) { + if output.Kind() == reflect.Ptr { + output = output.Elem() + } + + if output.Kind() != reflect.Struct { + return reflect.Value{}, fmt.Errorf("mapping has to but output is not a struct, type=%v", output.Type()) + } + + field := output.FieldByName(toField) + if !field.IsValid() { + return reflect.Value{}, fmt.Errorf("mapping has to not found. field=%v, outputType=%v", toField, output.Type()) + } + + if !field.CanSet() { + return reflect.Value{}, fmt.Errorf("mapping has to not exported. field=%v, outputType=%v", toField, output.Type()) + } + + if !toSet.Type().AssignableTo(field.Type()) { + return reflect.Value{}, fmt.Errorf("mapping to has a mismatched type. field=%s, from=%v, to=%v", toField, toSet.Type(), field.Type()) + } + + return field, nil +} + +func checkAndExtractFieldType(field string, typ reflect.Type) (reflect.Type, error) { + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + if typ.Kind() != reflect.Struct { + return nil, fmt.Errorf("type[%v] is not a struct", typ) + } + + f, ok := typ.FieldByName(field) + if !ok { + return nil, fmt.Errorf("type[%v] has no field[%s]", typ, field) + } + + if !f.IsExported() { + return nil, fmt.Errorf("type[%v] has an unexported field[%s]", typ.String(), field) + } + + return f.Type, nil +} + +func checkMappingGroup(mappings []*Mapping) error { + if len(mappings) <= 1 { + return nil + } + + var toMap = make(map[string]bool, len(mappings)) + + for _, mapping := range mappings { + if mapping.empty() { + return errors.New("multiple mappings have an empty mapping") + } + + if len(mapping.to) == 0 { + return fmt.Errorf("multiple mappings have a mapping to entire output, mapping= %s", mapping) + } + + if _, ok := toMap[mapping.to]; ok { + return fmt.Errorf("multiple mappings have the same To = %s, mappings=%v", mapping.to, mappings) + } + + toMap[mapping.to] = true + } + + return nil +} diff --git a/compose/field_mapping_test.go b/compose/field_mapping_test.go new file mode 100644 index 0000000..8d37d95 --- /dev/null +++ b/compose/field_mapping_test.go @@ -0,0 +1,191 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/utils/generic" +) + +func mapFrom[T any](input any, mappings []*Mapping) (T, error) { + f := fieldMap(mappings) + m, err := f(input) + if err != nil { + var t T + return t, err + } + + a, err := mappingAssign[T](m) + if err != nil { + var t T + return t, err + } + + return a.(T), nil +} + +func TestFieldMapping(t *testing.T) { + t.Run("whole mapped to whole", func(t *testing.T) { + m := []*Mapping{NewMapping("1")} + + out, err := mapFrom[string]("correct input", m) + assert.NoError(t, err) + assert.Equal(t, "correct input", out) + + out, err = mapFrom[string]("", m) + assert.NoError(t, err) + assert.Equal(t, "", out) + + _, err = mapFrom[string](1, m) + assert.ErrorContains(t, err, "mismatched type") + + out1, err := mapFrom[any]("correct input", m) + assert.NoError(t, err) + assert.Equal(t, "correct input", out1) + }) + + t.Run("field mapped to whole", func(t *testing.T) { + type up struct { + F1 string + f2 string + F3 int + } + + m := []*Mapping{NewMapping("1").From("F1")} + + out, err := mapFrom[string](&up{F1: "field1"}, m) + assert.NoError(t, err) + assert.Equal(t, "field1", out) + + out, err = mapFrom[string](&up{F1: ""}, m) + assert.NoError(t, err) + assert.Equal(t, "", out) + + out, err = mapFrom[string](up{F1: "field1"}, m) + assert.NoError(t, err) + assert.Equal(t, "field1", out) + + m[0].from = "f2" + _, err = mapFrom[string](&up{f2: "f2"}, m) + assert.ErrorContains(t, err, "not exported") + + m[0].from = "field3" + _, err = mapFrom[string](&up{F3: 3}, m) + assert.ErrorContains(t, err, "from not found") + + m[0].from = "F3" + _, err = mapFrom[string](&up{F3: 3}, m) + assert.ErrorContains(t, err, "mismatched type") + + m = []*Mapping{NewMapping("1").From("F1")} + out1, err := mapFrom[any](&up{F1: "field1"}, m) + assert.NoError(t, err) + assert.Equal(t, "field1", out1) + }) + + t.Run("whole mapped to field", func(t *testing.T) { + type down struct { + F1 string + f3 string + } + + m := []*Mapping{NewMapping("1").To("F1")} + + out, err := mapFrom[down]("from", m) + assert.NoError(t, err) + assert.Equal(t, down{F1: "from"}, out) + + out, err = mapFrom[down](1, m) + assert.ErrorContains(t, err, "mismatched type") + + m[0].to = "f2" + _, err = mapFrom[down]("from", m) + assert.ErrorContains(t, err, "to not found") + + m[0].to = "f3" + _, err = mapFrom[down]("from", m) + assert.ErrorContains(t, err, "not exported") + + m = []*Mapping{NewMapping("1").To("F1")} + out1, err := mapFrom[*down]("from", m) + assert.NoError(t, err) + assert.Equal(t, &down{F1: "from"}, out1) + }) + + t.Run("field mapped to field", func(t *testing.T) { + type inner struct { + in string + } + + type up struct { + F1 *inner + } + + type down struct { + F1 *inner + } + + m := []*Mapping{NewMapping("1").From("F1").To("F1")} + + out, err := mapFrom[*down](&up{F1: &inner{in: "in"}}, m) + assert.NoError(t, err) + assert.Equal(t, &down{F1: &inner{in: "in"}}, out) + }) + + t.Run("multiple mappings", func(t *testing.T) { + type up struct { + F1 string + F2 int + } + + type down struct { + F1 string + F2 int + } + + m := []*Mapping{ + NewMapping("1").From("F1").To("F1"), + NewMapping("1").From("F2").To("F2"), + } + + out, err := mapFrom[*down](&up{F1: "v1", F2: 2}, m) + assert.NoError(t, err) + assert.Equal(t, &down{F1: "v1", F2: 2}, out) + + m[0].fromNodeKey = "1" + m[0].to = "" + out, err = mapFrom[*down](&up{F1: "v1", F2: 2}, m) + assert.ErrorContains(t, err, "mapping entire value has a mismatched type") + + m = []*Mapping{} + out, err = mapFrom[*down](map[string]any{"key1": "v1", "key2": 2}, m) + assert.ErrorContains(t, err, "mapper has no Mappings") + }) + + t.Run("invalid mapping", func(t *testing.T) { + m := []*Mapping{NewMapping("1").From("F1")} + _, err := mapFrom[string](generic.PtrOf("input"), m) + assert.ErrorContains(t, err, "mapping has from but input is not struct or struct ptr") + + m = []*Mapping{NewMapping("1").To("F1")} + _, err = mapFrom[string]("input", m) + assert.ErrorContains(t, err, "mapping has to but output is not a struct") + }) +} diff --git a/compose/generic_graph.go b/compose/generic_graph.go index 36b541d..fec69e7 100644 --- a/compose/generic_graph.go +++ b/compose/generic_graph.go @@ -80,6 +80,8 @@ func NewGraph[I, O any](opts ...NewGraphOption) *Graph[I, O] { ComponentOfGraph, options.withState, options.withState != nil, + buildConverter[I](), + buildConverter[O](), ), } diff --git a/compose/graph.go b/compose/graph.go index 5168e8a..aa52a44 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -160,10 +160,16 @@ type graph struct { inputStreamConverter streamConverter outputValueChecker valueChecker outputStreamConverter streamConverter + preConverter *composableRunnable + postConverter *composableRunnable runtimeCheckEdges map[string]map[string]bool runtimeCheckBranches map[string][]bool + edge2FieldMapFn map[string]map[string]fieldMapFn + edge2StreamFieldMapFn map[string]map[string]streamFieldMapFn + node2Mappings map[string][]*Mapping + buildError error cmp component @@ -181,6 +187,7 @@ func newGraph( // nolint: byted_s_args_length_limit cmp component, runCtx func(ctx context.Context) context.Context, enableState bool, + preConverter, postConverter *composableRunnable, ) *graph { return &graph{ nodes: make(map[string]*graphNode), @@ -196,10 +203,16 @@ func newGraph( // nolint: byted_s_args_length_limit inputStreamConverter: inputConv, outputValueChecker: outputChecker, outputStreamConverter: outputConv, + preConverter: preConverter, + postConverter: postConverter, runtimeCheckEdges: make(map[string]map[string]bool), runtimeCheckBranches: make(map[string][]bool), + edge2FieldMapFn: make(map[string]map[string]fieldMapFn), + edge2StreamFieldMapFn: make(map[string]map[string]streamFieldMapFn), + node2Mappings: make(map[string][]*Mapping), + cmp: cmp, runCtx: runCtx, @@ -271,6 +284,10 @@ func (g *graph) addNode(key string, node *graphNode, options *graphAddNodeOpts) // // err := graph.AddEdge("start_node_key", "end_node_key") func (g *graph) AddEdge(startNode, endNode string) (err error) { + return g.addEdgeWithMappings(startNode, endNode) +} + +func (g *graph) addEdgeWithMappings(startNode, endNode string, mappings ...*Mapping) (err error) { if g.buildError != nil { return g.buildError } @@ -307,7 +324,7 @@ func (g *graph) AddEdge(startNode, endNode string) (err error) { return fmt.Errorf("edge end node '%s' needs to be added to graph first", endNode) } - err = g.validateAndInferType(startNode, endNode) + err = g.validateAndInferType(startNode, endNode, mappings...) if err != nil { return err } @@ -327,6 +344,26 @@ func (g *graph) AddEdge(startNode, endNode string) (err error) { return err } + if len(mappings) > 0 { + if _, ok := g.edge2FieldMapFn[startNode]; !ok { + g.edge2FieldMapFn[startNode] = make(map[string]fieldMapFn) + } + + g.edge2FieldMapFn[startNode][endNode] = fieldMap(mappings) + + if _, ok := g.edge2StreamFieldMapFn[startNode]; !ok { + g.edge2StreamFieldMapFn[startNode] = make(map[string]streamFieldMapFn) + } + + g.edge2StreamFieldMapFn[startNode][endNode] = streamFieldMap(mappings) + + if _, ok := g.node2Mappings[endNode]; !ok { + g.node2Mappings[endNode] = make([]*Mapping, 0, len(mappings)) + } + + g.node2Mappings[endNode] = append(g.node2Mappings[endNode], mappings...) + } + return nil } @@ -443,7 +480,7 @@ func (g *graph) AddLambdaNode(key string, node *Lambda, opts ...GraphAddNodeOpt) return g.addNode(key, gNode, options) } -// AddGraphNode add one kind of Graph[I, O]、Chain[I, O]、StateChain[I, O, S] as a node. +// AddGraphNode add one kind of Graph[I, O], Chain[I, O] or Workflow[I, O] as a node. // for Graph[I, O], comes from NewGraph[I, O]() // for Chain[I, O], comes from NewChain[I, O]() func (g *graph) AddGraphNode(key string, node AnyGraph, opts ...GraphAddNodeOpt) error { @@ -543,12 +580,12 @@ func (g *graph) AddBranch(startNode string, branch *GraphBranch) (err error) { return nil } -func (g *graph) validateAndInferType(startNode, endNode string) error { +func (g *graph) validateAndInferType(startNode, endNode string, mappings ...*Mapping) (err error) { startNodeOutputType := g.getNodeOutputType(startNode) endNodeInputType := g.getNodeInputType(endNode) // assume that START and END type isn't empty - // check and update current node. if cannot validate, save edge to toValidateMap + // check and update current node. if it cannot validate, save edge to toValidateMap if startNodeOutputType == nil && endNodeInputType == nil { // type of passthrough have not been inferred yet. defer checking to compile. g.toValidateMap[startNode] = append(g.toValidateMap[startNode], endNode) @@ -560,7 +597,7 @@ func (g *graph) validateAndInferType(startNode, endNode string) error { // start node is passthrough, propagate end node input type to it g.nodes[startNode].cr.inputType = endNodeInputType g.nodes[startNode].cr.outputType = g.nodes[startNode].cr.inputType - } else { + } else if len(mappings) == 0 || mappings[0].empty() { // common node check result := checkAssignable(startNodeOutputType, endNodeInputType) if result == assignableTypeMustNot { @@ -573,6 +610,32 @@ func (g *graph) validateAndInferType(startNode, endNode string) error { } g.runtimeCheckEdges[startNode][endNode] = true } + } else { + if startNodeOutputType == anyType || endNodeInputType == anyType { // input or output is any, can't do any check here, defer to request time + return nil + } + + for _, m := range mappings { + fromType := startNodeOutputType + toType := endNodeInputType + + if len(m.from) > 0 { + if fromType, err = checkAndExtractFieldType(m.from, fromType); err != nil { + return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%s] start node's output failed, %w", startNode, endNode, m, err) + } + } + + if len(m.to) > 0 { + if toType, err = checkAndExtractFieldType(m.to, toType); err != nil { + return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%s] end node's input failed, %w", startNode, endNode, m, err) + } + } + + if checkAssignable(fromType, toType) == assignableTypeMustNot { + return fmt.Errorf("graph edge[%s]-[%s]: after mapping[%s], start node's output type[%s] and end node's input type[%s] mismatch", + m, startNode, endNode, fromType.String(), toType.String()) + } + } } return nil } @@ -712,6 +775,7 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa preProcessor: node.nodeInfo.preProcessor, postProcessor: node.nodeInfo.postProcessor, + preConverter: r.preConverter, } branches := g.branches[name] @@ -723,7 +787,6 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa } chanSubscribeTo[name] = chCall - } invertedEdges := make(map[string][]string) @@ -768,9 +831,15 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa inputStreamConverter: g.inputStreamConverter, outputValueChecker: g.outputValueChecker, outputStreamConverter: g.outputStreamConverter, + preConverter: g.preConverter, + postConverter: g.postConverter, runtimeCheckEdges: g.runtimeCheckEdges, runtimeCheckBranches: g.runtimeCheckBranches, + + edge2FieldMapFn: g.edge2FieldMapFn, + edge2StreamFieldMapFn: g.edge2StreamFieldMapFn, + node2Mappings: g.node2Mappings, } if runType == runTypeDAG { @@ -869,6 +938,7 @@ func (g *graph) toGraphInfo(opt *graphCompileOptions, key2SubGraphs map[string]* Name: gNode.nodeInfo.name, InputKey: gNode.cr.nodeInfo.inputKey, OutputKey: gNode.cr.nodeInfo.outputKey, + Mappings: g.node2Mappings[key], } if gi, ok := key2SubGraphs[key]; ok { diff --git a/compose/graph_node.go b/compose/graph_node.go index 21cb068..d61f64d 100644 --- a/compose/graph_node.go +++ b/compose/graph_node.go @@ -116,6 +116,7 @@ func (gn *graphNode) compileIfNeeded(ctx context.Context) (*composableRunnable, r.meta = gn.executorMeta r.nodeInfo = gn.nodeInfo + r.preConverter = gn.cr.preConverter if gn.nodeInfo.outputKey != "" { r = outputKeyedComposableRunnable(gn.nodeInfo.outputKey, r) diff --git a/compose/graph_run.go b/compose/graph_run.go index db9cd99..92c07fb 100644 --- a/compose/graph_run.go +++ b/compose/graph_run.go @@ -56,6 +56,7 @@ type chanCall struct { writeToBranches []*GraphBranch preProcessor, postProcessor *composableRunnable + preConverter *composableRunnable } type channel interface { @@ -87,8 +88,15 @@ type runner struct { outputValueChecker valueChecker outputStreamConverter streamConverter + preConverter *composableRunnable + postConverter *composableRunnable + runtimeCheckEdges map[string]map[string]bool runtimeCheckBranches map[string][]bool + + edge2FieldMapFn map[string]map[string]fieldMapFn + edge2StreamFieldMapFn map[string]map[string]streamFieldMapFn + node2Mappings map[string][]*Mapping } func (r *runner) toComposableRunnable() *composableRunnable { @@ -113,6 +121,7 @@ func (r *runner) toComposableRunnable() *composableRunnable { inputStreamFilter: r.inputStreamFilter, inputValueChecker: r.inputValueChecker, inputStreamConverter: r.inputStreamConverter, + preConverter: r.preConverter, optionType: nil, // if option type is nil, graph will transmit all options. isPassthrough: false, @@ -226,6 +235,16 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti err error } + taskPreConverter := func(ctx context.Context, t *task) error { + if _, ok := r.node2Mappings[t.nodeKey]; !ok { + return nil + } + + var e error + t.input, e = runWrapper(ctx, t.call.preConverter, t.input, t.option...) + return e + } + taskPreProcessor := func(ctx context.Context, t *task) error { if t.call.preProcessor == nil { return nil @@ -321,6 +340,16 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti if err != nil { return nil, err } + + if _, ok := r.node2Mappings[END]; ok { + value, err := chs[END].get(ctx) + if err != nil { + return nil, err + } + + return runWrapper(ctx, r.postConverter, value) + } + break } @@ -355,6 +384,13 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti return nil, errors.New("no tasks to execute") } + for i := 0; i < len(nextTasks); i++ { + e := taskPreConverter(ctx, nextTasks[i]) + if e != nil { + return nil, fmt.Errorf("pre-convert[%s] input error: %w", nextTasks[i].nodeKey, e) + } + } + for i := 0; i < len(nextTasks); i++ { e := taskPreProcessor(ctx, nextTasks[i]) if e != nil { @@ -466,6 +502,15 @@ func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan } func (r *runner) parserOrValidateTypeIfNeeded(cur, next string, isStream bool, value any) (any, error) { + mapped, done, err := r.doFieldMap(cur, next, isStream, value) + if err != nil { + return nil, err + } + + if done { + return mapped, nil + } + if _, ok := r.runtimeCheckEdges[cur]; !ok { return value, nil } @@ -489,9 +534,40 @@ func (r *runner) parserOrValidateTypeIfNeeded(cur, next string, isStream bool, v value = r.chanSubscribeTo[next].action.inputStreamConverter(value.(streamReader)) return value, nil } - err := r.chanSubscribeTo[next].action.inputValueChecker(value) + err = r.chanSubscribeTo[next].action.inputValueChecker(value) if err != nil { return nil, fmt.Errorf("edge[%s]-[%s] runtime value check fail: %w", cur, next, err) } return value, nil } + +func (r *runner) doFieldMap(cur, next string, isStream bool, value any) (any, bool, error) { + if isStream { + if _, ok := r.edge2StreamFieldMapFn[cur]; !ok { + return value, false, nil + } + + f, ok := r.edge2StreamFieldMapFn[cur][next] + if !ok { + return value, false, nil + } + + return f(value.(streamReader)), true, nil + } + + if _, ok := r.edge2FieldMapFn[cur]; !ok { + return value, false, nil + } + + f, ok := r.edge2FieldMapFn[cur][next] + if !ok { + return value, false, nil + } + + mapped, err := f(value) + if err != nil { + return nil, false, err + } + + return mapped, true, nil +} diff --git a/compose/introspect.go b/compose/introspect.go index 3c29357..dddaa98 100644 --- a/compose/introspect.go +++ b/compose/introspect.go @@ -32,6 +32,8 @@ type GraphNodeInfo struct { Name string InputKey, OutputKey string GraphInfo *GraphInfo + + Mappings []*Mapping // for workflow input mapping } // GraphInfo the info which end users pass in when they are compiling a graph. diff --git a/compose/runnable.go b/compose/runnable.go index b8f8546..36190e8 100644 --- a/compose/runnable.go +++ b/compose/runnable.go @@ -56,6 +56,8 @@ type composableRunnable struct { inputStreamConverter streamConverter inputValueChecker valueChecker + preConverter *composableRunnable + inputType reflect.Type outputType reflect.Type optionType reflect.Type @@ -160,6 +162,7 @@ func (rp *runnablePacker[I, O, TOption]) toComposableRunnable() *composableRunna inputStreamFilter: defaultStreamMapFilter[I], inputStreamConverter: defaultStreamConverter[I], inputValueChecker: defaultValueChecker[I], + preConverter: buildConverter[I](), inputType: inputType, outputType: outputType, optionType: optionType, @@ -203,6 +206,30 @@ func (rp *runnablePacker[I, O, TOption]) toComposableRunnable() *composableRunna return c } +func buildConverter[I any]() *composableRunnable { + inputType := reflect.TypeOf(map[string]any{}) + outputType := generic.TypeOf[I]() + i := func(ctx context.Context, input any, opts ...any) (output any, err error) { + in, ok := input.(map[string]any) + if !ok { + panic(newUnexpectedInputTypeErr(inputType, reflect.TypeOf(input))) + } + + return mappingAssign[I](in) + } + + t := func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) { + return mappingStreamAssign[I](input), nil + } + + return &composableRunnable{ + i: i, + t: t, + inputType: inputType, + outputType: outputType, + } +} + // Invoke works like `ping => pong`. func (rp *runnablePacker[I, O, TOption]) Invoke(ctx context.Context, input I, opts ...TOption) (output O, err error) { @@ -545,6 +572,7 @@ func inputKeyedComposableRunnable(key string, r *composableRunnable) *composable wrapper := *r wrapper.inputValueChecker = defaultValueChecker[map[string]any] wrapper.inputStreamConverter = defaultStreamConverter[map[string]any] + wrapper.preConverter = buildConverter[map[string]any]() i := r.i wrapper.i = func(ctx context.Context, input any, opts ...any) (output any, err error) { v, ok := input.(map[string]any)[key] diff --git a/compose/stream_concat.go b/compose/stream_concat.go index 79fc9ce..d88acd2 100644 --- a/compose/stream_concat.go +++ b/compose/stream_concat.go @@ -263,9 +263,85 @@ func concatSliceValue(val reflect.Value) (reflect.Value, error) { } f := getConcatFunc(elmType) - if f == nil { - return reflect.Value{}, fmt.Errorf("cannot concat value of type %s", elmType) + if f != nil { + return f(val) } - return f(val) + var ( + structType reflect.Type + isStructPtr bool + ) + + if elmType.Kind() == reflect.Struct { + structType = elmType + } else if elmType.Kind() == reflect.Pointer && elmType.Elem().Kind() == reflect.Struct { + isStructPtr = true + structType = elmType.Elem() + } + + if structType != nil { + maps := make([]map[string]any, 0, val.Len()) + for i := 0; i < val.Len(); i++ { + sliceElem := val.Index(i) + m, err := structToMap(sliceElem) + if err != nil { + return reflect.Value{}, err + } + + maps = append(maps, m) + } + + result, err := concatMaps(reflect.ValueOf(maps)) + if err != nil { + return reflect.Value{}, err + } + + return mapToStruct(result.Interface().(map[string]any), structType, isStructPtr), nil + } + + var filtered reflect.Value + for i := 0; i < val.Len(); i++ { + oneVal := val.Index(i) + if !oneVal.IsZero() { + if filtered.IsValid() { + return reflect.Value{}, fmt.Errorf("cannot concat multiple non-zero value of type %s", elmType) + } + + filtered = oneVal + } + } + + return filtered, nil +} + +func structToMap(s reflect.Value) (map[string]any, error) { + if s.Kind() == reflect.Ptr { + s = s.Elem() + } + + ret := make(map[string]any, s.NumField()) + for i := 0; i < s.NumField(); i++ { + fieldType := s.Type().Field(i) + if !fieldType.IsExported() { + return nil, fmt.Errorf("structToMap: field %s is not exported", fieldType.Name) + } + + ret[fieldType.Name] = s.Field(i).Interface() + } + + return ret, nil +} + +func mapToStruct(m map[string]any, t reflect.Type, toPtr bool) reflect.Value { + ret := reflect.New(t).Elem() + for k, v := range m { + field := ret.FieldByName(k) + field.Set(reflect.ValueOf(v)) + } + + if toPtr { + ret = ret.Addr() + } + + return ret } diff --git a/compose/stream_concat_test.go b/compose/stream_concat_test.go index a081efa..25878f7 100644 --- a/compose/stream_concat_test.go +++ b/compose/stream_concat_test.go @@ -186,8 +186,9 @@ func TestConcatError(t *testing.T) { t.Run("not register type", func(t *testing.T) { type y struct{} - _, err := concatItems([]y{{}, {}}) - assert.NotNil(t, err) + out, err := concatItems([]y{{}, {}}) + assert.NoError(t, err) + assert.Equal(t, y{}, out) }) t.Run("map type not equal", func(t *testing.T) { diff --git a/compose/types.go b/compose/types.go index 1d1297f..6cd11f1 100644 --- a/compose/types.go +++ b/compose/types.go @@ -31,6 +31,7 @@ const ( ComponentOfPassthrough component = "Passthrough" ComponentOfToolsNode component = "ToolsNode" ComponentOfLambda component = "Lambda" + ComponentOfWorkflow component = "Workflow" ) // NodeTriggerMode controls the triggering mode of graph nodes. diff --git a/compose/utils.go b/compose/utils.go index c8e1ff0..ad95bba 100644 --- a/compose/utils.go +++ b/compose/utils.go @@ -60,7 +60,13 @@ func mergeValues(vs []any) (any, error) { } if s, ok := vs[0].(streamReader); ok { - if s.getChunkType().Kind() != reflect.Map { + chunkType := s.getChunkType() + if chunkType.Kind() == reflect.Ptr { + if chunkType.Elem().Kind() != reflect.Struct { + return nil, fmt.Errorf("(mergeValues | stream type)"+ + " unsupported chunk type: %v", s.getChunkType()) + } + } else if chunkType.Kind() != reflect.Map && chunkType.Kind() != reflect.Struct { return nil, fmt.Errorf("(mergeValues | stream type)"+ " unsupported chunk type: %v", s.getChunkType()) } diff --git a/compose/workflow.go b/compose/workflow.go new file mode 100644 index 0000000..60577f3 --- /dev/null +++ b/compose/workflow.go @@ -0,0 +1,375 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "errors" + "fmt" + "reflect" + "strings" + + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/embedding" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/utils/generic" +) + +// Mapping is the mapping from one node's output to current node's input. +type Mapping struct { + fromNodeKey string + from string + to string +} + +func (m *Mapping) empty() bool { + return len(m.from) == 0 && len(m.to) == 0 +} + +// From chooses a field value from fromNode's output struct or struct pointer with the specific field name, to serve as the source of the Mapping. +func (m *Mapping) From(name string) *Mapping { + m.from = name + return m +} + +// To chooses a field from currentNode's input struct or struct pointer with the specific field name, to serve as the destination of the Mapping. +func (m *Mapping) To(name string) *Mapping { + m.to = name + return m +} + +// String returns the string representation of the Mapping. +func (m *Mapping) String() string { + var sb strings.Builder + sb.WriteString("from ") + + if m.from != "" { + sb.WriteString(m.from) + sb.WriteString("(field) of ") + } + + sb.WriteString("node '") + sb.WriteString(m.fromNodeKey) + sb.WriteString("'") + + if m.to != "" { + sb.WriteString(" to ") + sb.WriteString(m.to) + sb.WriteString("(field)") + } + + sb.WriteString("; ") + return sb.String() +} + +// NewMapping creates a new Mapping with the specified fromNodeKey. +func NewMapping(fromNodeKey string) *Mapping { + return &Mapping{fromNodeKey: fromNodeKey} +} + +// WorkflowNode is the node of the Workflow. +type WorkflowNode struct { + key string + inputs []*Mapping +} + +// Workflow is wrapper of Graph, replacing AddEdge with declaring Mapping between one node's output and current node's input. +// Under the hood it uses NodeTriggerMode(AllPredecessor), so does not support branches or cycles. +type Workflow[I, O any] struct { + gg *Graph[I, O] + + nodes map[string]*WorkflowNode + end []*Mapping + err error +} + +// NewWorkflow creates a new Workflow. +func NewWorkflow[I, O any](opts ...NewGraphOption) *Workflow[I, O] { + wf := &Workflow[I, O]{ + gg: NewGraph[I, O](opts...), + nodes: make(map[string]*WorkflowNode), + } + + wf.gg.cmp = ComponentOfWorkflow + return wf +} + +func (wf *Workflow[I, O]) Compile(ctx context.Context, opts ...GraphCompileOption) (Runnable[I, O], error) { + if wf.err != nil { + return nil, wf.err + } + + opts = append(opts, WithNodeTriggerMode(AllPredecessor)) + + if err := wf.addEdgesWithMapping(); err != nil { + return nil, err + } + + return wf.gg.Compile(ctx, opts...) +} + +func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.ChatModel, opts ...GraphAddNodeOpt) *WorkflowNode { + node := &WorkflowNode{key: key} + if wf.err != nil { + return node + } + + wf.nodes[key] = node + + err := wf.gg.AddChatModelNode(key, chatModel, opts...) + if err != nil { + wf.err = err + return node + } + return node +} + +func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode { + node := &WorkflowNode{key: key} + if wf.err != nil { + return node + } + + wf.nodes[key] = node + + err := wf.gg.AddChatTemplateNode(key, chatTemplate, opts...) + if err != nil { + wf.err = err + return node + } + + return node +} + +func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode { + node := &WorkflowNode{key: key} + if wf.err != nil { + return node + } + + wf.nodes[key] = node + + err := wf.gg.AddToolsNode(key, tools, opts...) + if err != nil { + wf.err = err + return node + } + return node +} + +func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...GraphAddNodeOpt) *WorkflowNode { + node := &WorkflowNode{key: key} + if wf.err != nil { + return node + } + + wf.nodes[key] = node + + err := wf.gg.AddRetrieverNode(key, retriever, opts...) + if err != nil { + wf.err = err + return node + } + return node +} + +func (wf *Workflow[I, O]) AddEmbeddingNode(key string, embedding embedding.Embedder, opts ...GraphAddNodeOpt) *WorkflowNode { + node := &WorkflowNode{key: key} + if wf.err != nil { + return node + } + + wf.nodes[key] = node + + err := wf.gg.AddEmbeddingNode(key, embedding, opts...) + if err != nil { + wf.err = err + return node + } + return node +} + +func (wf *Workflow[I, O]) AddIndexerNode(key string, indexer indexer.Indexer, opts ...GraphAddNodeOpt) *WorkflowNode { + node := &WorkflowNode{key: key} + if wf.err != nil { + return node + } + + wf.nodes[key] = node + + err := wf.gg.AddIndexerNode(key, indexer, opts...) + if err != nil { + wf.err = err + return node + } + return node +} + +func (wf *Workflow[I, O]) AddLoaderNode(key string, loader document.Loader, opts ...GraphAddNodeOpt) *WorkflowNode { + node := &WorkflowNode{key: key} + if wf.err != nil { + return node + } + + wf.nodes[key] = node + + err := wf.gg.AddLoaderNode(key, loader, opts...) + if err != nil { + wf.err = err + return node + } + return node +} + +func (wf *Workflow[I, O]) AddDocumentTransformerNode(key string, transformer document.Transformer, opts ...GraphAddNodeOpt) *WorkflowNode { + node := &WorkflowNode{key: key} + if wf.err != nil { + return node + } + + wf.nodes[key] = node + + err := wf.gg.AddDocumentTransformerNode(key, transformer, opts...) + if err != nil { + wf.err = err + return node + } + return node +} + +func (wf *Workflow[I, O]) AddGraphNode(key string, graph AnyGraph, opts ...GraphAddNodeOpt) *WorkflowNode { + node := &WorkflowNode{key: key} + if wf.err != nil { + return node + } + + wf.nodes[key] = node + + err := wf.gg.AddGraphNode(key, graph, opts...) + if err != nil { + wf.err = err + return node + } + return node +} + +func (wf *Workflow[I, O]) AddLambdaNode(key string, lambda *Lambda, opts ...GraphAddNodeOpt) *WorkflowNode { + node := &WorkflowNode{key: key} + if wf.err != nil { + return node + } + + wf.nodes[key] = node + + err := wf.gg.AddLambdaNode(key, lambda, opts...) + if err != nil { + wf.err = err + return node + } + + return node +} + +func (n *WorkflowNode) AddInput(inputs ...*Mapping) *WorkflowNode { + n.inputs = append(n.inputs, inputs...) + return n +} + +func (wf *Workflow[I, O]) AddEnd(inputs ...*Mapping) { + wf.end = inputs +} + +func (wf *Workflow[I, O]) compile(ctx context.Context, options *graphCompileOptions) (*composableRunnable, error) { + options.nodeTriggerMode = AllPredecessor + if err := wf.addEdgesWithMapping(); err != nil { + return nil, err + } + return wf.gg.compile(ctx, options) +} + +func (wf *Workflow[I, O]) inputType() reflect.Type { + return generic.TypeOf[I]() +} + +func (wf *Workflow[I, O]) outputType() reflect.Type { + return generic.TypeOf[O]() +} + +func (wf *Workflow[I, O]) component() component { + return wf.gg.component() +} + +func (wf *Workflow[I, O]) addEdgesWithMapping() (err error) { + var toNode string + for _, node := range wf.nodes { + toNode = node.key + if len(node.inputs) == 0 { + return fmt.Errorf("workflow node = %s has no input", toNode) + } + + fromNode2Mappings := make(map[string][]*Mapping, len(node.inputs)) + for i := range node.inputs { + input := node.inputs[i] + fromNodeKey := input.fromNodeKey + fromNode2Mappings[fromNodeKey] = append(fromNode2Mappings[fromNodeKey], input) + } + + for fromNode, mappings := range fromNode2Mappings { + if err = checkMappingGroup(mappings); err != nil { + return err + } + + if mappings[0].empty() { + if err = wf.gg.AddEdge(fromNode, toNode); err != nil { + return err + } + } else if err = wf.gg.addEdgeWithMappings(fromNode, toNode, mappings...); err != nil { + return err + } + } + } + + if len(wf.end) == 0 { + return errors.New("workflow END has no input mapping") + } + + fromNode2EndMappings := make(map[string][]*Mapping, len(wf.end)) + for i := range wf.end { + input := wf.end[i] + fromNodeKey := input.fromNodeKey + fromNode2EndMappings[fromNodeKey] = append(fromNode2EndMappings[fromNodeKey], input) + } + + for fromNode, mappings := range fromNode2EndMappings { + if err = checkMappingGroup(mappings); err != nil { + return err + } + + if mappings[0].empty() { + if err = wf.gg.AddEdge(fromNode, END); err != nil { + return err + } + } else if err = wf.gg.addEdgeWithMappings(fromNode, END, mappings...); err != nil { + return err + } + } + + return nil +} diff --git a/compose/workflow_test.go b/compose/workflow_test.go new file mode 100644 index 0000000..79924cf --- /dev/null +++ b/compose/workflow_test.go @@ -0,0 +1,323 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "fmt" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/cloudwego/eino/internal/mock/components/embedding" + "github.com/cloudwego/eino/internal/mock/components/indexer" + "github.com/cloudwego/eino/internal/mock/components/model" + "github.com/cloudwego/eino/schema" +) + +func TestWorkflow(t *testing.T) { + ctx := context.Background() + + type structA struct { + Field1 string + Field2 int + Field3 []any + } + + type structB struct { + Field1 string + Field2 int + } + + type structC struct { + Field1 string + } + + type structE struct { + Field1 string + Field2 string + Field3 []any + } + + type structF struct { + Field1 string + Field2 string + Field3 []any + B int + StateTemp string + } + + type state struct { + temp string + } + + type structEnd struct { + Field1 string + } + + subGraph := NewGraph[string, *structB]() + _ = subGraph.AddLambdaNode( + "1", + InvokableLambda(func(ctx context.Context, input string) (*structB, error) { + return &structB{Field1: input, Field2: 33}, nil + }), + ) + _ = subGraph.AddEdge(START, "1") + _ = subGraph.AddEdge("1", END) + + subChain := NewChain[any, any](). + AppendLambda(InvokableLambda(func(_ context.Context, in any) (any, error) { + return &structC{Field1: fmt.Sprintf("%d", in)}, nil + })) + + subWorkflow := NewWorkflow[[]any, []any]() + subWorkflow.AddLambdaNode( + "1", + InvokableLambda(func(_ context.Context, in []any) ([]any, error) { + return in, nil + }), + WithOutputKey("key")). + AddInput(NewMapping(START)) + subWorkflow.AddLambdaNode( + "2", + InvokableLambda(func(_ context.Context, in []any) ([]any, error) { + return in, nil + }), + WithInputKey("key")). + AddInput(NewMapping("1")) + subWorkflow.AddEnd(NewMapping("2")) + + w := NewWorkflow[*structA, *structEnd](WithGenLocalState(func(context.Context) *state { return &state{} })) + + w. + AddGraphNode("B", subGraph, + WithStatePostHandler(func(ctx context.Context, out *structB, state *state) (*structB, error) { + state.temp = out.Field1 + return out, nil + })). + AddInput(NewMapping(START).From("Field1")) + + w. + AddGraphNode("C", subChain). + AddInput(NewMapping(START).From("Field2")) + + w. + AddGraphNode("D", subWorkflow). + AddInput(NewMapping(START).From("Field3")) + + w. + AddLambdaNode( + "E", + TransformableLambda(func(_ context.Context, in *schema.StreamReader[structE]) (*schema.StreamReader[structE], error) { + return schema.StreamReaderWithConvert(in, func(in structE) (structE, error) { + if len(in.Field1) > 0 { + in.Field1 = "E:" + in.Field1 + } + if len(in.Field2) > 0 { + in.Field2 = "E:" + in.Field2 + } + + return in, nil + }), nil + }), + WithStreamStatePreHandler(func(ctx context.Context, in *schema.StreamReader[structE], state *state) (*schema.StreamReader[structE], error) { + temp := state.temp + return schema.StreamReaderWithConvert(in, func(v structE) (structE, error) { + if len(v.Field3) > 0 { + v.Field3 = append(v.Field3, "Pre:"+temp) + } + + return v, nil + }), nil + }), + WithStreamStatePostHandler(func(ctx context.Context, out *schema.StreamReader[structE], state *state) (*schema.StreamReader[structE], error) { + return schema.StreamReaderWithConvert(out, func(v structE) (structE, error) { + if len(v.Field1) > 0 { + v.Field1 = v.Field1 + "+Post" + } + return v, nil + }), nil + })). + AddInput( + NewMapping("B").From("Field1").To("Field1"), + NewMapping("C").From("Field1").To("Field2"), + NewMapping("D").To("Field3"), + ) + + w. + AddLambdaNode( + "F", + InvokableLambda(func(ctx context.Context, in *structF) (string, error) { + return fmt.Sprintf("%v_%v_%v_%v_%v", in.Field1, in.Field2, in.Field3, in.B, in.StateTemp), nil + }), + WithStatePreHandler(func(ctx context.Context, in *structF, state *state) (*structF, error) { + in.StateTemp = state.temp + return in, nil + }), + ). + AddInput( + NewMapping("B").From("Field2").To("B"), + NewMapping("E").From("Field1").To("Field1"), + NewMapping("E").From("Field2").To("Field2"), + NewMapping("E").From("Field3").To("Field3"), + ) + + w.AddEnd(NewMapping("F").To("Field1")) + + compiled, err := w.Compile(ctx) + assert.NoError(t, err) + + input := &structA{ + Field1: "1", + Field2: 2, + Field3: []any{ + 1, "good", + }, + } + out, err := compiled.Invoke(ctx, input) + assert.NoError(t, err) + assert.Equal(t, &structEnd{"E:1+Post_E:2_[1 good Pre:1]_33_1"}, out) + + outStream, err := compiled.Stream(ctx, input) + assert.NoError(t, err) + defer outStream.Close() + for { + chunk, err := outStream.Recv() + if err != nil { + if err == io.EOF { + break + } + + t.Error(err) + return + } + + assert.Equal(t, &structEnd{"E:1+Post_E:2_[1 good Pre:1]_33_1"}, chunk) + } +} + +func TestWorkflowCompile(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + t.Run("a node has no input", func(t *testing.T) { + w := NewWorkflow[string, string]() + w.AddToolsNode("1", &ToolsNode{}) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "workflow node = 1 has no input") + }) + + t.Run("compile without add end", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput(NewMapping(START)) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "workflow END has no input mapping") + }) + + t.Run("type mismatch", func(t *testing.T) { + w := NewWorkflow[string, string]() + w.AddToolsNode("1", &ToolsNode{}).AddInput(NewMapping(START)) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "mismatch") + }) + + t.Run("upstream not struct/struct ptr, mapping has FromField", func(t *testing.T) { + w := NewWorkflow[[]*schema.Document, []string]() + + w.AddIndexerNode("indexer", indexer.NewMockIndexer(ctrl)).AddInput(NewMapping(START).From("F1")) + w.AddEnd(NewMapping("indexer")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[[]*schema.Document] is not a struct") + }) + + t.Run("downstream not struct/struct ptr, mapping has ToField", func(t *testing.T) { + w := NewWorkflow[[]string, [][]float64]() + w.AddEmbeddingNode("embedder", embedding.NewMockEmbedder(ctrl)).AddInput(NewMapping(START).To("F1")) + w.AddEnd(NewMapping("embedder")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[[]string] is not a struct") + }) + + t.Run("map to non existing field in upstream", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("tools_node", &ToolsNode{}).AddInput(NewMapping(START).From("non_exist")) + w.AddEnd(NewMapping("tools_node")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[schema.Message] has no field[non_exist]") + }) + + t.Run("map to not exported field in downstream", func(t *testing.T) { + w := NewWorkflow[string, *Mapping]() + w.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + })).AddInput(NewMapping(START)) + w.AddEnd(NewMapping("1").To("to")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[compose.Mapping] has an unexported field[to]") + }) + + t.Run("duplicate node key", func(t *testing.T) { + w := NewWorkflow[[]*schema.Message, []*schema.Message]() + w.AddChatModelNode("1", model.NewMockChatModel(ctrl)).AddInput(NewMapping(START)) + w.AddToolsNode("1", &ToolsNode{}).AddInput(NewMapping("1")) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "node '1' already present") + }) + + t.Run("from non-existing node", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput(NewMapping(START)) + w.AddEnd(NewMapping("2")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "edge start node '2' needs to be added to graph first") + }) + + t.Run("multiple mappings have an empty mapping", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput(NewMapping(START), NewMapping(START).From("Content").To("Content")) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "have an empty mapping") + }) + + t.Run("multiple mappings have mapping to entire output ", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput( + NewMapping(START).From("Role"), + NewMapping(START).From("Content"), + ) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "have a mapping to entire output") + }) + + t.Run("multiple mappings have duplicate ToField", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput( + NewMapping(START).From("Content").To("Content"), + NewMapping(START).From("Role").To("Content"), + ) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "have the same To") + }) +}