diff --git a/compose/field_mapping.go b/compose/field_mapping.go index 49c6f15..e235f8f 100644 --- a/compose/field_mapping.go +++ b/compose/field_mapping.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package compose import ( @@ -10,18 +26,18 @@ import ( ) func takeOne(input any, m *Mapping) (any, error) { - if len(m.FromField) == 0 && len(m.FromMapKey) == 0 { + if len(m.fromField) == 0 && len(m.fromMapKey) == 0 { return input, nil } - if len(m.FromField) > 0 && len(m.FromMapKey) > 0 { - return nil, fmt.Errorf("mapping has both FromField and FromMapKey, m=%+v", m) + if len(m.fromField) > 0 && len(m.fromMapKey) > 0 { + return nil, fmt.Errorf("mapping has both fromField and fromMapKey, m=%s", m) } inputValue := reflect.ValueOf(input) - if len(m.FromField) > 0 { - f, err := checkAndExtractFromField(m.FromField, inputValue) + if len(m.fromField) > 0 { + f, err := checkAndExtractFromField(m.fromField, inputValue) if err != nil { return nil, err } @@ -29,7 +45,7 @@ func takeOne(input any, m *Mapping) (any, error) { return f.Interface(), nil } - v, err := checkAndExtractFromMapKey(m.FromMapKey, inputValue) + v, err := checkAndExtractFromMapKey(m.fromMapKey, inputValue) if err != nil { return nil, err } @@ -44,7 +60,7 @@ func assignOne[T any](dest T, taken any, m *Mapping) (T, error) { destValue = reflect.ValueOf(&dest).Elem() } - if len(m.ToField) == 0 && len(m.ToMapKey) == 0 { // assign to output directly + if len(m.toField) == 0 && len(m.toMapKey) == 0 { // assign to output directly toSet := reflect.ValueOf(taken) if !toSet.Type().AssignableTo(destValue.Type()) { return dest, fmt.Errorf("mapping entire value has a mismatched type. from=%v, to=%v", toSet.Type(), destValue.Type()) @@ -55,14 +71,14 @@ func assignOne[T any](dest T, taken any, m *Mapping) (T, error) { return destValue.Interface().(T), nil } - if len(m.ToField) > 0 && len(m.ToMapKey) > 0 { - return dest, fmt.Errorf("mapping has both ToField and ToMapKey, m=%+v", m) + if len(m.toField) > 0 && len(m.toMapKey) > 0 { + return dest, fmt.Errorf("mapping has both toField and toMapKey, m=%s", m) } toSet := reflect.ValueOf(taken) - if len(m.ToField) > 0 { - field, err := checkAndExtractToField(m.ToField, destValue, toSet) + if len(m.toField) > 0 { + field, err := checkAndExtractToField(m.toField, destValue, toSet) if err != nil { return dest, err } @@ -71,7 +87,7 @@ func assignOne[T any](dest T, taken any, m *Mapping) (T, error) { return destValue.Interface().(T), nil } - key, err := checkAndExtractToMapKey(m.ToMapKey, destValue, toSet) + key, err := checkAndExtractToMapKey(m.toMapKey, destValue, toSet) if err != nil { return dest, err } @@ -87,16 +103,16 @@ func mapFrom[T any](input any, mappings []*Mapping) (T, error) { return t, errors.New("mapper has no Mappings") } - from := mappings[0].From + from := mappings[0].fromNodeKey for _, mapping := range mappings { - if len(mapping.ToField) == 0 && len(mapping.ToMapKey) == 0 { + if len(mapping.toField) == 0 && len(mapping.toMapKey) == 0 { if len(mappings) > 1 { return t, fmt.Errorf("one of the mapping maps to entire input, conflict") } } - if mapping.From != from { - return t, fmt.Errorf("multiple mappings from the same node have different keys: %s, %s", mapping.From, from) + if mapping.fromNodeKey != from { + return t, fmt.Errorf("multiple mappings from the same node have different keys: %s, %s", mapping.fromNodeKey, from) } taken, err := takeOne(input, mapping) @@ -150,16 +166,16 @@ func checkAndExtractFromField(fromField string, input reflect.Value) (reflect.Va } if input.Kind() != reflect.Struct { - return reflect.Value{}, fmt.Errorf("mapping has FromField but input is not struct or struct ptr, type= %v", input.Type()) + return reflect.Value{}, fmt.Errorf("mapping has fromField but input is not struct or struct ptr, type= %v", input.Type()) } f := input.FieldByName(fromField) if !f.IsValid() { - return reflect.Value{}, fmt.Errorf("mapping has FromField not found. field=%v, inputType=%v", fromField, input.Type()) + return reflect.Value{}, fmt.Errorf("mapping has fromField not found. field=%v, inputType=%v", fromField, input.Type()) } if !f.CanInterface() { - return reflect.Value{}, fmt.Errorf("mapping has FromField not exported. field= %v, inputType=%v", fromField, input.Type()) + return reflect.Value{}, fmt.Errorf("mapping has fromField not exported. field= %v, inputType=%v", fromField, input.Type()) } return f, nil @@ -176,7 +192,7 @@ func checkAndExtractFromMapKey(fromMapKey string, input reflect.Value) (reflect. v := input.MapIndex(reflect.ValueOf(fromMapKey)) if !v.IsValid() { - return reflect.Value{}, fmt.Errorf("mapping FromMapKey not found in input. key=%s, inputType= %v", fromMapKey, input.Type()) + return reflect.Value{}, fmt.Errorf("mapping fromMapKey not found in input. key=%s, inputType= %v", fromMapKey, input.Type()) } return v, nil @@ -188,20 +204,20 @@ func checkAndExtractToField(toField string, output, toSet reflect.Value) (reflec } if output.Kind() != reflect.Struct { - return reflect.Value{}, fmt.Errorf("mapping has ToField but output is not a struct, type=%v", output.Type()) + return reflect.Value{}, fmt.Errorf("mapping has toField but output is not a struct, type=%v", output.Type()) } field := output.FieldByName(toField) if !field.IsValid() { - return reflect.Value{}, fmt.Errorf("mapping has ToField not found. field=%v, outputType=%v", toField, output.Type()) + return reflect.Value{}, fmt.Errorf("mapping has toField not found. field=%v, outputType=%v", toField, output.Type()) } if !field.CanSet() { - return reflect.Value{}, fmt.Errorf("mapping has ToField not exported. field=%v, outputType=%v", toField, output.Type()) + return reflect.Value{}, fmt.Errorf("mapping has toField not exported. field=%v, outputType=%v", toField, output.Type()) } if !toSet.Type().AssignableTo(field.Type()) { - return reflect.Value{}, fmt.Errorf("mapping ToField has a mismatched type. field=%s, from=%v, to=%v", toField, toSet.Type(), field.Type()) + return reflect.Value{}, fmt.Errorf("mapping toField has a mismatched type. field=%s, from=%v, to=%v", toField, toSet.Type(), field.Type()) } return field, nil @@ -209,21 +225,21 @@ func checkAndExtractToField(toField string, output, toSet reflect.Value) (reflec func checkAndExtractToMapKey(toMapKey string, output, toSet reflect.Value) (reflect.Value, error) { if output.Kind() != reflect.Map { - return reflect.Value{}, fmt.Errorf("mapping has ToMapKey but output is not a map, type=%v", output.Type()) + return reflect.Value{}, fmt.Errorf("mapping has toMapKey but output is not a map, type=%v", output.Type()) } if !reflect.TypeOf(toMapKey).AssignableTo(output.Type().Key()) { - return reflect.Value{}, fmt.Errorf("mapping has ToMapKey but output is not a map with string key, type=%v", output.Type()) + return reflect.Value{}, fmt.Errorf("mapping has toMapKey but output is not a map with string key, type=%v", output.Type()) } if !toSet.Type().AssignableTo(output.Type().Elem()) { - return reflect.Value{}, fmt.Errorf("mapping ToMapKey has a mismatched type. key=%s, from=%v, to=%v", toMapKey, toSet.Type(), output.Type().Elem()) + return reflect.Value{}, fmt.Errorf("mapping toMapKey has a mismatched type. key=%s, from=%v, to=%v", toMapKey, toSet.Type(), output.Type().Elem()) } return reflect.ValueOf(toMapKey), nil } -func checkAndExtractMapValueType(mapKey string, typ reflect.Type) (reflect.Type, error) { +func checkAndExtractMapValueType(typ reflect.Type) (reflect.Type, error) { if typ.Kind() != reflect.Map { return nil, fmt.Errorf("type[%v] is not a map", typ) } @@ -261,65 +277,75 @@ func checkMappingGroup(mappings []*Mapping) error { return nil } - from := mappings[0].From - var fromMapKeyFlag, toMapKeyFlag, fromFieldFlag, toFieldFlag bool - fromMap := make(map[string]bool, len(mappings)) - toMap := make(map[string]bool, len(mappings)) + var ( + fromMapKeyFlag, toMapKeyFlag, fromFieldFlag, toFieldFlag bool + fromMap = make(map[string]bool, len(mappings)) + toMap = make(map[string]bool, len(mappings)) + ) + for _, mapping := range mappings { - if mapping.From != from { - return fmt.Errorf("multiple mappings from the same group have from node keys: %s, %s", mapping.From, from) + if mapping.empty() { + return errors.New("multiple mappings have an empty mapping") + } + + if len(mapping.fromField) == 0 && len(mapping.fromMapKey) == 0 { + return fmt.Errorf("multiple mappings have a mapping from entire input, mapping= %s", mapping) + } + + if len(mapping.toField) == 0 && len(mapping.toMapKey) == 0 { + return fmt.Errorf("multiple mappings have a mapping to entire output, mapping= %s", mapping) } - if len(mapping.FromMapKey) > 0 { + if len(mapping.fromMapKey) > 0 { if fromFieldFlag { - return fmt.Errorf("multiple mappings from the same group have both from field and from map key, mappings=%v", mappings) + return fmt.Errorf("multiple mappings have both FromField and FromMapKey, mappings=%v", mappings) } - if _, ok := fromMap[mapping.FromMapKey]; ok { - return fmt.Errorf("multiple mappings from the same group have the same from map key = %s, mappings=%v", mapping.FromMapKey, mappings) + if _, ok := fromMap[mapping.fromMapKey]; ok { + return fmt.Errorf("multiple mappings have the same FromMapKey = %s, mappings=%v", mapping.fromMapKey, mappings) } fromMapKeyFlag = true - fromMap[mapping.FromMapKey] = true + fromMap[mapping.fromMapKey] = true } - if len(mapping.FromField) > 0 { + if len(mapping.fromField) > 0 { if fromMapKeyFlag { - return fmt.Errorf("multiple mappings from the same group have both from field and from map key, mappings=%v", mappings) + return fmt.Errorf("multiple mappings have both FromField and FromMapKey, mappings=%v", mappings) } - if _, ok := fromMap[mapping.FromField]; ok { - return fmt.Errorf("multiple mappings from the same group have the same from field = %s, mappings=%v", mapping.FromField, mappings) + if _, ok := fromMap[mapping.fromField]; ok { + return fmt.Errorf("multiple mappings have the same FromField = %s, mappings=%v", mapping.fromField, mappings) } fromFieldFlag = true - fromMap[mapping.FromField] = true + fromMap[mapping.fromField] = true } - if len(mapping.ToMapKey) > 0 { + if len(mapping.toMapKey) > 0 { if toFieldFlag { - return fmt.Errorf("multiple mappings from the same group have both to field and to map key, mappings=%v", mappings) + return fmt.Errorf("multiple mappings have both ToField and ToMapKey, mappings=%v", mappings) } - if _, ok := toMap[mapping.ToMapKey]; ok { - return fmt.Errorf("multiple mappings from the same group have the same to map key = %s, mappings=%v", mapping.ToMapKey, mappings) + if _, ok := toMap[mapping.toMapKey]; ok { + return fmt.Errorf("multiple mappings have the same ToMapKey = %s, mappings=%v", mapping.toMapKey, mappings) } toMapKeyFlag = true - toMap[mapping.ToMapKey] = true + toMap[mapping.toMapKey] = true } - if len(mapping.ToField) > 0 { + if len(mapping.toField) > 0 { if toMapKeyFlag { - return fmt.Errorf("multiple mappings from the same group have both to field and to map key, mappings=%v", mappings) + return fmt.Errorf("multiple mappings have both ToField and ToMapKey, mappings=%v", mappings) } - if _, ok := toMap[mapping.ToField]; ok { - return fmt.Errorf("multiple mappings from the same group have the same to field = %s, mappings=%v", mapping.ToField, mappings) + if _, ok := toMap[mapping.toField]; ok { + return fmt.Errorf("multiple mappings have the same ToField = %s, mappings=%v", mapping.toField, mappings) } toFieldFlag = true - toMap[mapping.ToField] = true + toMap[mapping.toField] = true } } diff --git a/compose/field_mapping_test.go b/compose/field_mapping_test.go index 9c55455..d26de66 100644 --- a/compose/field_mapping_test.go +++ b/compose/field_mapping_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package compose import ( @@ -6,15 +22,13 @@ import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/utils/generic" ) func TestFieldMapping(t *testing.T) { t.Run("whole mapped to whole", func(t *testing.T) { - m := []*Mapping{ - { - From: "upper", - }, - } + m := []*Mapping{NewMapping("1")} out, err := mapFrom[string]("correct input", m) assert.NoError(t, err) @@ -27,13 +41,7 @@ func TestFieldMapping(t *testing.T) { _, err = mapFrom[string](1, m) assert.ErrorContains(t, err, "mismatched type") - m1 := []*Mapping{ - { - From: "upper", - }, - } - - out1, err := mapFrom[any]("correct input", m1) + out1, err := mapFrom[any]("correct input", m) assert.NoError(t, err) assert.Equal(t, "correct input", out1) }) @@ -45,12 +53,7 @@ func TestFieldMapping(t *testing.T) { F3 int } - m := []*Mapping{ - { - From: "upper", - FromField: "F1", - }, - } + m := []*Mapping{NewMapping("1").FromField("F1")} out, err := mapFrom[string](&up{F1: "field1"}, m) assert.NoError(t, err) @@ -64,37 +67,26 @@ func TestFieldMapping(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "field1", out) - m[0].FromField = "f2" + m[0].fromField = "f2" _, err = mapFrom[string](&up{f2: "f2"}, m) assert.ErrorContains(t, err, "not exported") - m[0].FromField = "field3" + m[0].fromField = "field3" _, err = mapFrom[string](&up{F3: 3}, m) - assert.ErrorContains(t, err, "FromField not found") + assert.ErrorContains(t, err, "fromField not found") - m[0].FromField = "F3" + m[0].fromField = "F3" _, err = mapFrom[string](&up{F3: 3}, m) assert.ErrorContains(t, err, "mismatched type") - m1 := []*Mapping{ - { - From: "upper", - FromField: "F1", - }, - } - - out1, err := mapFrom[any](&up{F1: "field1"}, m1) + m = []*Mapping{NewMapping("1").FromField("F1")} + out1, err := mapFrom[any](&up{F1: "field1"}, m) assert.NoError(t, err) assert.Equal(t, "field1", out1) }) t.Run("map key mapped to whole", func(t *testing.T) { - m := []*Mapping{ - { - From: "upper", - FromMapKey: "key1", - }, - } + m := []*Mapping{NewMapping("1").FromMapKey("key1")} out, err := mapFrom[string](map[string]string{"key1": "value1"}, m) assert.NoError(t, err) @@ -105,7 +97,7 @@ func TestFieldMapping(t *testing.T) { assert.Equal(t, "", out) out, err = mapFrom[string](map[string]string{"key2": "value2"}, m) - assert.ErrorContains(t, err, "FromMapKey not found") + assert.ErrorContains(t, err, "fromMapKey not found") out, err = mapFrom[string](map[string]int{"key1": 1}, m) assert.ErrorContains(t, err, "mismatched type") @@ -114,14 +106,7 @@ func TestFieldMapping(t *testing.T) { out, err = mapFrom[string](map[mock]string{"key1": "value1"}, m) assert.ErrorContains(t, err, "not a map with string key") - m1 := []*Mapping{ - { - From: "upper", - FromMapKey: "key1", - }, - } - - out1, err := mapFrom[any](map[string]string{"key1": "value1"}, m1) + out1, err := mapFrom[any](map[string]string{"key1": "value1"}, m) assert.NoError(t, err) assert.Equal(t, "value1", out1) }) @@ -132,12 +117,7 @@ func TestFieldMapping(t *testing.T) { f3 string } - m := []*Mapping{ - { - From: "upper", - ToField: "F1", - }, - } + m := []*Mapping{NewMapping("1").ToField("F1")} out, err := mapFrom[down]("from", m) assert.NoError(t, err) @@ -146,33 +126,22 @@ func TestFieldMapping(t *testing.T) { out, err = mapFrom[down](1, m) assert.ErrorContains(t, err, "mismatched type") - m[0].ToField = "f2" + m[0].toField = "f2" _, err = mapFrom[down]("from", m) - assert.ErrorContains(t, err, "ToField not found") + assert.ErrorContains(t, err, "toField not found") - m[0].ToField = "f3" + m[0].toField = "f3" _, err = mapFrom[down]("from", m) assert.ErrorContains(t, err, "not exported") - m1 := []*Mapping{ - { - From: "upper", - ToField: "F1", - }, - } - - out1, err := mapFrom[*down]("from", m1) + m = []*Mapping{NewMapping("1").ToField("F1")} + out1, err := mapFrom[*down]("from", m) assert.NoError(t, err) assert.Equal(t, &down{F1: "from"}, out1) }) t.Run("whole mapped to map key", func(t *testing.T) { - m := []*Mapping{ - { - From: "upper", - ToMapKey: "key1", - }, - } + m := []*Mapping{NewMapping("1").ToMapKey("key1")} out, err := mapFrom[map[string]string]("from", m) assert.NoError(t, err) @@ -181,15 +150,8 @@ func TestFieldMapping(t *testing.T) { assert.ErrorContains(t, err, "mismatched type") type mockKey string - m1 := []*Mapping{ - { - From: "upper", - ToMapKey: "key1", - }, - } - - _, err = mapFrom[map[mockKey]string]("from", m1) - assert.ErrorContains(t, err, "mapping has ToMapKey but output is not a map with string key") + _, err = mapFrom[map[mockKey]string]("from", m) + assert.ErrorContains(t, err, "mapping has toMapKey but output is not a map with string key") }) t.Run("field mapped to field", func(t *testing.T) { @@ -205,13 +167,7 @@ func TestFieldMapping(t *testing.T) { F1 *inner } - m := []*Mapping{ - { - From: "upper", - FromField: "F1", - ToField: "F1", - }, - } + m := []*Mapping{NewMapping("1").FromField("F1").ToField("F1")} out, err := mapFrom[*down](&up{F1: &inner{in: "in"}}, m) assert.NoError(t, err) @@ -223,13 +179,7 @@ func TestFieldMapping(t *testing.T) { F1 []string } - m := []*Mapping{ - { - From: "upper", - FromField: "F1", - ToMapKey: "key1", - }, - } + m := []*Mapping{NewMapping("1").FromField("F1").ToMapKey("key1")} out, err := mapFrom[map[string]any](&up{F1: []string{"in"}}, m) assert.NoError(t, err) @@ -237,13 +187,7 @@ func TestFieldMapping(t *testing.T) { }) t.Run("map key mapped to map key", func(t *testing.T) { - m := []*Mapping{ - { - From: "upper", - FromMapKey: "key1", - ToMapKey: "key2", - }, - } + m := []*Mapping{NewMapping("1").FromMapKey("key1").ToMapKey("key2")} out, err := mapFrom[map[string]any](map[string]any{"key1": "value1"}, m) assert.NoError(t, err) @@ -255,13 +199,7 @@ func TestFieldMapping(t *testing.T) { F1 io.Reader } - m := []*Mapping{ - { - From: "upper", - FromMapKey: "key1", - ToField: "F1", - }, - } + m := []*Mapping{NewMapping("1").FromMapKey("key1").ToField("F1")} out, err := mapFrom[*down](map[string]any{"key1": &bytes.Buffer{}}, m) assert.NoError(t, err) @@ -275,28 +213,20 @@ func TestFieldMapping(t *testing.T) { } m := []*Mapping{ - { - From: "upper", - FromMapKey: "key1", - ToField: "F1", - }, - { - From: "upper", - FromMapKey: "key2", - ToField: "F2", - }, + NewMapping("1").FromMapKey("key1").ToField("F1"), + NewMapping("1").FromMapKey("key2").ToField("F2"), } out, err := mapFrom[*down](map[string]any{"key1": "v1", "key2": 2}, m) assert.NoError(t, err) assert.Equal(t, &down{F1: "v1", F2: 2}, out) - m[0].From = "different_upper" + m[0].fromNodeKey = "different_upper" out, err = mapFrom[*down](map[string]any{"key1": "v1", "key2": 2}, m) assert.ErrorContains(t, err, "multiple mappings from the same node have different keys") - m[0].From = "upper" - m[0].ToField = "" + m[0].fromNodeKey = "1" + m[0].toField = "" out, err = mapFrom[*down](map[string]any{"key1": "v1", "key2": 2}, m) assert.ErrorContains(t, err, "one of the mapping maps to entire input, conflict") @@ -304,4 +234,30 @@ func TestFieldMapping(t *testing.T) { out, err = mapFrom[*down](map[string]any{"key1": "v1", "key2": 2}, m) assert.ErrorContains(t, err, "mapper has no Mappings") }) + + t.Run("invalid mapping", func(t *testing.T) { + m := []*Mapping{NewMapping("1").FromMapKey("key1").FromField("F1")} + _, err := mapFrom[string](map[string]any{"key1": "v1", "key2": 2}, m) + assert.ErrorContains(t, err, "mapping has both fromField and fromMapKey") + + m = []*Mapping{NewMapping("1").ToMapKey("key1").ToField("F1")} + _, err = mapFrom[string]("input", m) + assert.ErrorContains(t, err, "mapping has both toField and toMapKey") + + m = []*Mapping{NewMapping("1").FromField("F1")} + _, err = mapFrom[string](generic.PtrOf("input"), m) + assert.ErrorContains(t, err, "mapping has fromField but input is not struct or struct ptr") + + m = []*Mapping{NewMapping("1").FromMapKey("key1")} + _, err = mapFrom[string](generic.PtrOf("input"), m) + assert.ErrorContains(t, err, "mapping has FromKey but input is not a map") + + m = []*Mapping{NewMapping("1").ToField("F1")} + _, err = mapFrom[string]("input", m) + assert.ErrorContains(t, err, "mapping has toField but output is not a struct") + + m = []*Mapping{NewMapping("1").ToMapKey("key1")} + _, err = mapFrom[string]("input", m) + assert.ErrorContains(t, err, "mapping has toMapKey but output is not a map") + }) } diff --git a/compose/generic_graph.go b/compose/generic_graph.go index 7cc5ba8..ef30b92 100644 --- a/compose/generic_graph.go +++ b/compose/generic_graph.go @@ -140,5 +140,5 @@ func (g *Graph[I, O]) Compile(ctx context.Context, opts ...GraphCompileOption) ( } func (g *Graph[I, O]) fieldMapper() fieldMapper { - return defaultFieldMapper[O]{} + return defaultFieldMapper[I]{} } diff --git a/compose/graph.go b/compose/graph.go index e5a8318..03d50cf 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -618,28 +618,28 @@ func (g *graph) validateAndInferType(startNode, endNode string, mappings ...*Map fromType := startNodeOutputType toType := endNodeInputType - if len(m.FromMapKey) > 0 { - if fromType, err = checkAndExtractMapValueType(m.FromMapKey, fromType); err != nil { - return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%v] start node's output failed, %w", startNode, endNode, m, err) + if len(m.fromMapKey) > 0 { + if fromType, err = checkAndExtractMapValueType(fromType); err != nil { + return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%s] start node's output failed, %w", startNode, endNode, m, err) } - } else if len(m.FromField) > 0 { - if fromType, err = checkAndExtractFieldType(m.FromField, fromType); err != nil { - return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%v] start node's output failed, %w", startNode, endNode, m, err) + } else if len(m.fromField) > 0 { + if fromType, err = checkAndExtractFieldType(m.fromField, fromType); err != nil { + return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%s] start node's output failed, %w", startNode, endNode, m, err) } } - if len(m.ToMapKey) > 0 { - if toType, err = checkAndExtractMapValueType(m.ToMapKey, toType); err != nil { - return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%v] end node's input failed, %w", startNode, endNode, m, err) + if len(m.toMapKey) > 0 { + if toType, err = checkAndExtractMapValueType(toType); err != nil { + return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%s] end node's input failed, %w", startNode, endNode, m, err) } - } else if len(m.ToField) > 0 { - if toType, err = checkAndExtractFieldType(m.ToField, toType); err != nil { - return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%v] end node's input failed, %w", startNode, endNode, m, err) + } else if len(m.toField) > 0 { + if toType, err = checkAndExtractFieldType(m.toField, toType); err != nil { + return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%s] end node's input failed, %w", startNode, endNode, m, err) } } if checkAssignable(fromType, toType) == assignableTypeMustNot { - return fmt.Errorf("graph edge[%s]-[%s]: after mapping[%v], start node's output type[%s] and end node's input type[%s] mismatch", + return fmt.Errorf("graph edge[%s]-[%s]: after mapping[%s], start node's output type[%s] and end node's input type[%s] mismatch", m, startNode, endNode, fromType.String(), toType.String()) } } diff --git a/compose/stream_concat.go b/compose/stream_concat.go index 5383578..bc01326 100644 --- a/compose/stream_concat.go +++ b/compose/stream_concat.go @@ -299,5 +299,17 @@ func concatSliceValue(val reflect.Value) (reflect.Value, error) { return reflect.ValueOf(merged), nil } - return reflect.Value{}, fmt.Errorf("cannot concat value of type %s", elmType) + var filtered reflect.Value + for i := 0; i < val.Len(); i++ { + oneVal := val.Index(i) + if !oneVal.IsZero() { + if filtered.IsValid() { + return reflect.Value{}, fmt.Errorf("cannot concat multiple non-zero value of type %s", elmType) + } + + filtered = oneVal + } + } + + return filtered, nil } diff --git a/compose/workflow.go b/compose/workflow.go index 8821b42..11fa54a 100644 --- a/compose/workflow.go +++ b/compose/workflow.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package compose import ( @@ -5,6 +21,7 @@ import ( "errors" "fmt" "reflect" + "strings" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" @@ -17,17 +34,75 @@ import ( ) type Mapping struct { - From string + fromNodeKey string - FromField string - FromMapKey string + fromField string + fromMapKey string - ToField string - ToMapKey string + toField string + toMapKey string } func (m *Mapping) empty() bool { - return len(m.FromField) == 0 && len(m.FromMapKey) == 0 && len(m.ToField) == 0 && len(m.ToMapKey) == 0 + return len(m.fromField) == 0 && len(m.fromMapKey) == 0 && len(m.toField) == 0 && len(m.toMapKey) == 0 +} + +func (m *Mapping) FromField(fieldName string) *Mapping { + m.fromField = fieldName + return m +} + +func (m *Mapping) ToField(fieldName string) *Mapping { + m.toField = fieldName + return m +} + +func (m *Mapping) FromMapKey(mapKey string) *Mapping { + m.fromMapKey = mapKey + return m +} + +func (m *Mapping) ToMapKey(mapKey string) *Mapping { + m.toMapKey = mapKey + return m +} + +func (m *Mapping) String() string { + var sb strings.Builder + sb.WriteString("from ") + + if m.fromMapKey != "" { + sb.WriteString(m.fromMapKey) + sb.WriteString("(map key) of ") + } + + if m.fromField != "" { + sb.WriteString(m.fromField) + sb.WriteString("(field) of ") + } + + sb.WriteString("node '") + sb.WriteString(m.fromNodeKey) + sb.WriteString("'") + + if m.toField != "" { + sb.WriteString(" to ") + sb.WriteString(m.toField) + sb.WriteString("(field)") + } + + if m.toMapKey != "" { + sb.WriteString(" to ") + sb.WriteString(m.toMapKey) + sb.WriteString("(map key)") + } + + sb.WriteString("; ") + return sb.String() +} + +func NewMapping(fromNodeKey string) *Mapping { + return &Mapping{fromNodeKey: fromNodeKey} } type WorkflowNode struct { @@ -114,7 +189,7 @@ func convertAddNodeOpts(opts []WorkflowAddNodeOpt) []GraphAddNodeOpt { } func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.ChatModel, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[*schema.Message]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Message]{}} wf.nodes[key] = node if wf.err != nil { @@ -130,7 +205,7 @@ func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.ChatModel } func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Message]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[map[string]any]{}} wf.nodes[key] = node if wf.err != nil { @@ -147,7 +222,7 @@ func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.Ch } func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Message]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[*schema.Message]{}} wf.nodes[key] = node if wf.err != nil { @@ -163,7 +238,7 @@ func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...Wor } func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Document]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[string]{}} wf.nodes[key] = node if wf.err != nil { @@ -179,7 +254,7 @@ func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retri } func (wf *Workflow[I, O]) AddEmbeddingNode(key string, embedding embedding.Embedder, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[][]float64]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]string]{}} wf.nodes[key] = node if wf.err != nil { @@ -195,7 +270,7 @@ func (wf *Workflow[I, O]) AddEmbeddingNode(key string, embedding embedding.Embed } func (wf *Workflow[I, O]) AddIndexerNode(key string, indexer indexer.Indexer, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]string]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Document]{}} wf.nodes[key] = node if wf.err != nil { @@ -211,7 +286,7 @@ func (wf *Workflow[I, O]) AddIndexerNode(key string, indexer indexer.Indexer, op } func (wf *Workflow[I, O]) AddLoaderNode(key string, loader document.Loader, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Document]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[document.Source]{}} wf.nodes[key] = node if wf.err != nil { @@ -280,7 +355,7 @@ func (n *WorkflowNode) AddInput(inputs ...*Mapping) *WorkflowNode { return n } -func (wf *Workflow[I, O]) AddEnd(inputs []*Mapping) { +func (wf *Workflow[I, O]) AddEnd(inputs ...*Mapping) { wf.end = inputs wf.endFieldMapper = defaultFieldMapper[O]{} } @@ -314,9 +389,6 @@ func (wf *Workflow[I, O]) addEdgesWithMapping() (err error) { for _, node := range wf.nodes { toNode = node.key fm := node.fieldMapper - if fm == nil { - return fmt.Errorf("workflow has no field mapper, node = %s", toNode) - } if len(node.inputs) == 0 { return fmt.Errorf("workflow node = %s has no input", toNode) @@ -325,7 +397,7 @@ func (wf *Workflow[I, O]) addEdgesWithMapping() (err error) { fromNode2Mappings := make(map[string][]*Mapping, len(node.inputs)) for i := range node.inputs { input := node.inputs[i] - fromNodeKey := input.From + fromNodeKey := input.fromNodeKey fromNode2Mappings[fromNodeKey] = append(fromNode2Mappings[fromNodeKey], input) } @@ -344,15 +416,16 @@ func (wf *Workflow[I, O]) addEdgesWithMapping() (err error) { } } - fm := wf.endFieldMapper - if fm == nil { - return errors.New("workflow has no end field mapper") + if len(wf.end) == 0 { + return errors.New("workflow END has no input mapping") } + fm := wf.endFieldMapper + fromNode2EndMappings := make(map[string][]*Mapping, len(wf.end)) for i := range wf.end { input := wf.end[i] - fromNodeKey := input.From + fromNodeKey := input.fromNodeKey fromNode2EndMappings[fromNodeKey] = append(fromNode2EndMappings[fromNodeKey], input) } diff --git a/compose/workflow_test.go b/compose/workflow_test.go index a60b66d..cdd250b 100644 --- a/compose/workflow_test.go +++ b/compose/workflow_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package compose import ( @@ -7,9 +23,20 @@ import ( "testing" "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/internal/mock/components/document" + "github.com/cloudwego/eino/internal/mock/components/embedding" + "github.com/cloudwego/eino/internal/mock/components/indexer" + "github.com/cloudwego/eino/internal/mock/components/model" + "github.com/cloudwego/eino/internal/mock/components/retriever" + "github.com/cloudwego/eino/schema" ) func TestWorkflow(t *testing.T) { + ctx := context.Background() + type structA struct { Field1 string Field2 int @@ -27,66 +54,111 @@ func TestWorkflow(t *testing.T) { Field3 []any } - w := NewWorkflow[*structA, string]() + type state struct { + temp string + } - w. - AddLambdaNode( - "B", - InvokableLambda(func(context.Context, string) (*structB, error) { - return &structB{Field1: "1", Field2: 2}, nil - }), - WithWorkflowNodeName("node B")). - AddInput(&Mapping{ - From: START, - FromField: "Field1", - }) + subGraph := NewGraph[string, *structB]() + _ = subGraph.AddLambdaNode( + "1", + InvokableLambda(func(ctx context.Context, input string) (*structB, error) { + return &structB{Field1: input, Field2: 33}, nil + }), + ) + _ = subGraph.AddEdge(START, "1") + _ = subGraph.AddEdge("1", END) + + subChain := NewChain[any, any](). + AppendLambda(InvokableLambda(func(_ context.Context, in any) (any, error) { + return map[string]string{"key2": fmt.Sprintf("%d", in)}, nil + })) + + subWorkflow := NewWorkflow[[]any, []any]() + subWorkflow.AddLambdaNode( + "1", + InvokableLambda(func(_ context.Context, in []any) ([]any, error) { + return in, nil + })). + AddInput(NewMapping(START)) + subWorkflow.AddEnd(NewMapping("1")) + + w := NewWorkflow[*structA, string](WithGenLocalState(func(context.Context) *state { return &state{} })) w. - AddLambdaNode( - "C", - InvokableLambda(func(_ context.Context, in int) (map[string]string, error) { - return map[string]string{"key2": fmt.Sprintf("%d", in)}, nil + AddGraphNode("B", subGraph, WithWorkflowNodeName("node B"), + WithWorkflowStatePostHandler(func(ctx context.Context, out *structB, state *state) (*structB, error) { + state.temp = out.Field1 + return out, nil })). - AddInput(&Mapping{ - From: START, - FromField: "Field2", - }) + AddInput(NewMapping(START).FromField("Field1")) w. - AddLambdaNode( - "D", - InvokableLambda(func(_ context.Context, in []any) ([]any, error) { - return in, nil - })). - AddInput(&Mapping{ - From: START, - FromField: "Field3", - }) + AddGraphNode("C", subChain). + AddInput(NewMapping(START).FromField("Field2")) + + w. + AddGraphNode("D", subWorkflow). + AddInput(NewMapping(START).FromField("Field3")) w. AddLambdaNode( "E", - InvokableLambda(func(_ context.Context, in *structE) (string, error) { - return in.Field1 + "_" + in.Field2 + "_" + fmt.Sprintf("%v", in.Field3[0]), nil + TransformableLambda(func(_ context.Context, in *schema.StreamReader[structE]) (*schema.StreamReader[structE], error) { + return schema.StreamReaderWithConvert(in, func(in structE) (structE, error) { + if len(in.Field1) > 0 { + in.Field1 = "E:" + in.Field1 + } + if len(in.Field2) > 0 { + in.Field2 = "E:" + in.Field2 + } + + return in, nil + }), nil + }), + WithWorkflowStreamStatePreHandler(func(ctx context.Context, in *schema.StreamReader[structE], state *state) (*schema.StreamReader[structE], error) { + temp := state.temp + return schema.StreamReaderWithConvert(in, func(v structE) (structE, error) { + if len(v.Field3) > 0 { + v.Field3 = append(v.Field3, "Pre:"+temp) + } + + return v, nil + }), nil + }), + WithWorkflowStreamStatePostHandler(func(ctx context.Context, out *schema.StreamReader[structE], state *state) (*schema.StreamReader[structE], error) { + return schema.StreamReaderWithConvert(out, func(v structE) (structE, error) { + if len(v.Field1) > 0 { + v.Field1 = v.Field1 + "+Post" + } + return v, nil + }), nil })). - AddInput(&Mapping{ - From: "B", - FromField: "Field1", - ToField: "Field1", - }, &Mapping{ - From: "C", - FromMapKey: "key2", - ToField: "Field2", - }, &Mapping{ - From: "D", - ToField: "Field3", - }) - - w.AddEnd([]*Mapping{{ - From: "E", - }}) + AddInput( + NewMapping("B").FromField("Field1").ToField("Field1"), + NewMapping("C").FromMapKey("key2").ToField("Field2"), + NewMapping("D").ToField("Field3"), + ) + + w. + AddLambdaNode( + "F", + InvokableLambda(func(ctx context.Context, in map[string]any) (string, error) { + return fmt.Sprintf("%v_%v_%v_%v_%v", in["key1"], in["key2"], in["key3"], in["B"], in["state_temp"]), nil + }), + WithWorkflowStatePreHandler(func(ctx context.Context, in map[string]any, state *state) (map[string]any, error) { + in["state_temp"] = state.temp + return in, nil + }), + ). + AddInput( + NewMapping("B").FromField("Field2").ToMapKey("B"), + NewMapping("E").FromField("Field1").ToMapKey("key1"), + NewMapping("E").FromField("Field2").ToMapKey("key2"), + NewMapping("E").FromField("Field3").ToMapKey("key3"), + ) + + w.AddEnd(NewMapping("F")) - ctx := context.Background() compiled, err := w.Compile(ctx) assert.NoError(t, err) @@ -99,7 +171,7 @@ func TestWorkflow(t *testing.T) { } out, err := compiled.Invoke(ctx, input) assert.NoError(t, err) - assert.Equal(t, "1_2_1", out) + assert.Equal(t, "E:1+Post_E:2_[1 good Pre:1]_33_1", out) outStream, err := compiled.Stream(ctx, input) assert.NoError(t, err) @@ -115,6 +187,236 @@ func TestWorkflow(t *testing.T) { return } - assert.Equal(t, "1_2_1", chunk) + assert.Equal(t, "E:1+Post_E:2_[1 good Pre:1]_33_1", chunk) } } + +func TestWorkflowCompile(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + t.Run("a node has no input", func(t *testing.T) { + w := NewWorkflow[string, string]() + w.AddToolsNode("1", &ToolsNode{}) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "workflow node = 1 has no input") + }) + + t.Run("compile without add end", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput(NewMapping(START)) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "workflow END has no input mapping") + }) + + t.Run("type mismatch", func(t *testing.T) { + w := NewWorkflow[string, string]() + w.AddToolsNode("1", &ToolsNode{}).AddInput(NewMapping(START)) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "mismatch") + }) + + t.Run("upstream not map, mapping has FromMapKey", func(t *testing.T) { + w := NewWorkflow[map[string]any, string]() + + w.AddChatTemplateNode("prompt", prompt.FromMessages(schema.Jinja2, schema.SystemMessage("your name is {{ name }}"))).AddInput(NewMapping(START)) + + w.AddEnd(NewMapping("prompt").FromMapKey("key1")) + + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[[]*schema.Message] is not a map") + }) + + t.Run("upstream not struct/struct ptr, mapping has FromField", func(t *testing.T) { + w := NewWorkflow[[]*schema.Document, []string]() + + w.AddIndexerNode("indexer", indexer.NewMockIndexer(ctrl)).AddInput(NewMapping(START).FromField("F1")) + w.AddEnd(NewMapping("indexer")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[[]*schema.Document] is not a struct") + }) + + t.Run("downstream not map, mapping has ToMapKey", func(t *testing.T) { + w := NewWorkflow[string, []*schema.Document]() + + w.AddRetrieverNode("retriever", retriever.NewMockRetriever(ctrl)).AddInput(NewMapping(START).ToMapKey("key1")) + w.AddEnd(NewMapping("retriever")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[string] is not a map") + }) + + t.Run("downstream not struct/struct ptr, mapping has ToField", func(t *testing.T) { + w := NewWorkflow[[]string, [][]float64]() + w.AddEmbeddingNode("embedder", embedding.NewMockEmbedder(ctrl)).AddInput(NewMapping(START).ToField("F1")) + w.AddEnd(NewMapping("embedder")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[[]string] is not a struct") + }) + + t.Run("upstream not map[string]T, mapping has FromField", func(t *testing.T) { + w := NewWorkflow[map[int]int, any]() + w.AddLoaderNode("loader", document.NewMockLoader(ctrl)).AddInput(NewMapping(START).FromMapKey("key1")) + w.AddEnd(NewMapping("loader")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[map[int]int] is not a map with string key") + }) + + t.Run("downstream not map[string]T, mapping has ToField", func(t *testing.T) { + w := NewWorkflow[any, map[int]int]() + w.AddDocumentTransformerNode("transformer", document.NewMockTransformer(ctrl)).AddInput(NewMapping(START)) + w.AddEnd(NewMapping("transformer").ToMapKey("key1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[map[int]int] is not a map with string key") + }) + + t.Run("map to non existing field in upstream", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("tools_node", &ToolsNode{}).AddInput(NewMapping(START).FromField("non_exist")) + w.AddEnd(NewMapping("tools_node")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[schema.Message] has no field[non_exist]") + }) + + t.Run("map to not exported field in downstream", func(t *testing.T) { + w := NewWorkflow[string, *Mapping]() + w.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + })).AddInput(NewMapping(START)) + w.AddEnd(NewMapping("1").ToField("toField")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[compose.Mapping] has an unexported field[toField]") + }) + + t.Run("duplicate node key", func(t *testing.T) { + w := NewWorkflow[[]*schema.Message, []*schema.Message]() + w.AddChatModelNode("1", model.NewMockChatModel(ctrl)).AddInput(NewMapping(START)) + w.AddToolsNode("1", &ToolsNode{}).AddInput(NewMapping("1")) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "node '1' already present") + }) + + t.Run("from non-existing node", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput(NewMapping(START)) + w.AddEnd(NewMapping("2")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "edge start node '2' needs to be added to graph first") + }) + + t.Run("multiple mappings have an empty mapping", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput(NewMapping(START), NewMapping(START).FromField("Content").ToField("Content")) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "have an empty mapping") + }) + + t.Run("multiple mappings have mapping to entire output ", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput( + NewMapping(START).FromField("Role"), + NewMapping(START).FromField("Content"), + ) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "have a mapping to entire output") + }) + + t.Run("multiple mappings have mapping from entire input ", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput( + NewMapping(START).ToField("Role"), + NewMapping(START).ToField("Content"), + ) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "have a mapping from entire input") + }) + + t.Run("multiple mappings set both FromMapKey and FromField", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput( + NewMapping(START).FromMapKey("Role").ToField("Role"), + NewMapping(START).FromField("Content").ToField("Content"), + ) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "have both FromField and FromMapKey") + + w = NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput( + NewMapping(START).FromField("Content").ToField("Content"), + NewMapping(START).FromMapKey("Role").ToField("Role"), + ) + w.AddEnd(NewMapping("1")) + _, err = w.Compile(ctx) + assert.ErrorContains(t, err, "have both FromField and FromMapKey") + }) + + t.Run("multiple mappings have duplicate FromMapKey", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput( + NewMapping(START).FromMapKey("Role").ToField("Role"), + NewMapping(START).FromMapKey("Role").ToField("Role"), + ) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "have the same FromMapKey") + }) + + t.Run("multiple mappings have duplicate FromField", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput( + NewMapping(START).FromField("Content").ToField("Content"), + NewMapping(START).FromField("Content").ToField("Content"), + ) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "have the same FromField") + }) + + t.Run("multiple mappings set both ToMapKey and ToField", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput( + NewMapping(START).FromField("Role").ToField("Role"), + NewMapping(START).FromField("Content").ToMapKey("Content"), + ) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "have both ToField and ToMapKey") + + w = NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput( + NewMapping(START).FromField("Content").ToMapKey("Content"), + NewMapping(START).FromField("Role").ToField("Role"), + ) + w.AddEnd(NewMapping("1")) + _, err = w.Compile(ctx) + assert.ErrorContains(t, err, "have both ToField and ToMapKey") + }) + + t.Run("multiple mappings have duplicate ToMapKey", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput( + NewMapping(START).FromMapKey("Role").ToMapKey("Role"), + NewMapping(START).FromMapKey("Content").ToMapKey("Role"), + ) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "have the same ToMapKey") + }) + + t.Run("multiple mappings have duplicate ToField", func(t *testing.T) { + w := NewWorkflow[*schema.Message, []*schema.Message]() + w.AddToolsNode("1", &ToolsNode{}).AddInput( + NewMapping(START).FromField("Content").ToField("Content"), + NewMapping(START).FromField("Role").ToField("Content"), + ) + w.AddEnd(NewMapping("1")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "have the same ToField") + }) +}