From 7a31f1cfbc7d8d5d320b1da6a69a666ccc3784cb Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Mon, 6 Jan 2025 17:01:12 +0800 Subject: [PATCH] fixes and tests Change-Id: I9c7e952df2d8c4d30d149ca97a267672eacbea4a --- compose/field_mapping.go | 110 ++++++++++--------- compose/field_mapping_test.go | 192 +++++++++++++--------------------- compose/generic_graph.go | 2 +- compose/graph.go | 26 ++--- compose/workflow.go | 116 ++++++++++++++++---- compose/workflow_test.go | 122 +++++++++++++++------ 6 files changed, 335 insertions(+), 233 deletions(-) diff --git a/compose/field_mapping.go b/compose/field_mapping.go index 49c6f15..2beb0ec 100644 --- a/compose/field_mapping.go +++ b/compose/field_mapping.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package compose import ( @@ -10,18 +26,18 @@ import ( ) func takeOne(input any, m *Mapping) (any, error) { - if len(m.FromField) == 0 && len(m.FromMapKey) == 0 { + if len(m.fromField) == 0 && len(m.fromMapKey) == 0 { return input, nil } - if len(m.FromField) > 0 && len(m.FromMapKey) > 0 { - return nil, fmt.Errorf("mapping has both FromField and FromMapKey, m=%+v", m) + if len(m.fromField) > 0 && len(m.fromMapKey) > 0 { + return nil, fmt.Errorf("mapping has both fromField and fromMapKey, m=%s", m) } inputValue := reflect.ValueOf(input) - if len(m.FromField) > 0 { - f, err := checkAndExtractFromField(m.FromField, inputValue) + if len(m.fromField) > 0 { + f, err := checkAndExtractFromField(m.fromField, inputValue) if err != nil { return nil, err } @@ -29,7 +45,7 @@ func takeOne(input any, m *Mapping) (any, error) { return f.Interface(), nil } - v, err := checkAndExtractFromMapKey(m.FromMapKey, inputValue) + v, err := checkAndExtractFromMapKey(m.fromMapKey, inputValue) if err != nil { return nil, err } @@ -44,7 +60,7 @@ func assignOne[T any](dest T, taken any, m *Mapping) (T, error) { destValue = reflect.ValueOf(&dest).Elem() } - if len(m.ToField) == 0 && len(m.ToMapKey) == 0 { // assign to output directly + if len(m.toField) == 0 && len(m.toMapKey) == 0 { // assign to output directly toSet := reflect.ValueOf(taken) if !toSet.Type().AssignableTo(destValue.Type()) { return dest, fmt.Errorf("mapping entire value has a mismatched type. from=%v, to=%v", toSet.Type(), destValue.Type()) @@ -55,14 +71,14 @@ func assignOne[T any](dest T, taken any, m *Mapping) (T, error) { return destValue.Interface().(T), nil } - if len(m.ToField) > 0 && len(m.ToMapKey) > 0 { - return dest, fmt.Errorf("mapping has both ToField and ToMapKey, m=%+v", m) + if len(m.toField) > 0 && len(m.toMapKey) > 0 { + return dest, fmt.Errorf("mapping has both toField and toMapKey, m=%s", m) } toSet := reflect.ValueOf(taken) - if len(m.ToField) > 0 { - field, err := checkAndExtractToField(m.ToField, destValue, toSet) + if len(m.toField) > 0 { + field, err := checkAndExtractToField(m.toField, destValue, toSet) if err != nil { return dest, err } @@ -71,7 +87,7 @@ func assignOne[T any](dest T, taken any, m *Mapping) (T, error) { return destValue.Interface().(T), nil } - key, err := checkAndExtractToMapKey(m.ToMapKey, destValue, toSet) + key, err := checkAndExtractToMapKey(m.toMapKey, destValue, toSet) if err != nil { return dest, err } @@ -87,16 +103,16 @@ func mapFrom[T any](input any, mappings []*Mapping) (T, error) { return t, errors.New("mapper has no Mappings") } - from := mappings[0].From + from := mappings[0].fromNodeKey for _, mapping := range mappings { - if len(mapping.ToField) == 0 && len(mapping.ToMapKey) == 0 { + if len(mapping.toField) == 0 && len(mapping.toMapKey) == 0 { if len(mappings) > 1 { return t, fmt.Errorf("one of the mapping maps to entire input, conflict") } } - if mapping.From != from { - return t, fmt.Errorf("multiple mappings from the same node have different keys: %s, %s", mapping.From, from) + if mapping.fromNodeKey != from { + return t, fmt.Errorf("multiple mappings from the same node have different keys: %s, %s", mapping.fromNodeKey, from) } taken, err := takeOne(input, mapping) @@ -150,16 +166,16 @@ func checkAndExtractFromField(fromField string, input reflect.Value) (reflect.Va } if input.Kind() != reflect.Struct { - return reflect.Value{}, fmt.Errorf("mapping has FromField but input is not struct or struct ptr, type= %v", input.Type()) + return reflect.Value{}, fmt.Errorf("mapping has fromField but input is not struct or struct ptr, type= %v", input.Type()) } f := input.FieldByName(fromField) if !f.IsValid() { - return reflect.Value{}, fmt.Errorf("mapping has FromField not found. field=%v, inputType=%v", fromField, input.Type()) + return reflect.Value{}, fmt.Errorf("mapping has fromField not found. field=%v, inputType=%v", fromField, input.Type()) } if !f.CanInterface() { - return reflect.Value{}, fmt.Errorf("mapping has FromField not exported. field= %v, inputType=%v", fromField, input.Type()) + return reflect.Value{}, fmt.Errorf("mapping has fromField not exported. field= %v, inputType=%v", fromField, input.Type()) } return f, nil @@ -176,7 +192,7 @@ func checkAndExtractFromMapKey(fromMapKey string, input reflect.Value) (reflect. v := input.MapIndex(reflect.ValueOf(fromMapKey)) if !v.IsValid() { - return reflect.Value{}, fmt.Errorf("mapping FromMapKey not found in input. key=%s, inputType= %v", fromMapKey, input.Type()) + return reflect.Value{}, fmt.Errorf("mapping fromMapKey not found in input. key=%s, inputType= %v", fromMapKey, input.Type()) } return v, nil @@ -188,20 +204,20 @@ func checkAndExtractToField(toField string, output, toSet reflect.Value) (reflec } if output.Kind() != reflect.Struct { - return reflect.Value{}, fmt.Errorf("mapping has ToField but output is not a struct, type=%v", output.Type()) + return reflect.Value{}, fmt.Errorf("mapping has toField but output is not a struct, type=%v", output.Type()) } field := output.FieldByName(toField) if !field.IsValid() { - return reflect.Value{}, fmt.Errorf("mapping has ToField not found. field=%v, outputType=%v", toField, output.Type()) + return reflect.Value{}, fmt.Errorf("mapping has toField not found. field=%v, outputType=%v", toField, output.Type()) } if !field.CanSet() { - return reflect.Value{}, fmt.Errorf("mapping has ToField not exported. field=%v, outputType=%v", toField, output.Type()) + return reflect.Value{}, fmt.Errorf("mapping has toField not exported. field=%v, outputType=%v", toField, output.Type()) } if !toSet.Type().AssignableTo(field.Type()) { - return reflect.Value{}, fmt.Errorf("mapping ToField has a mismatched type. field=%s, from=%v, to=%v", toField, toSet.Type(), field.Type()) + return reflect.Value{}, fmt.Errorf("mapping toField has a mismatched type. field=%s, from=%v, to=%v", toField, toSet.Type(), field.Type()) } return field, nil @@ -209,21 +225,21 @@ func checkAndExtractToField(toField string, output, toSet reflect.Value) (reflec func checkAndExtractToMapKey(toMapKey string, output, toSet reflect.Value) (reflect.Value, error) { if output.Kind() != reflect.Map { - return reflect.Value{}, fmt.Errorf("mapping has ToMapKey but output is not a map, type=%v", output.Type()) + return reflect.Value{}, fmt.Errorf("mapping has toMapKey but output is not a map, type=%v", output.Type()) } if !reflect.TypeOf(toMapKey).AssignableTo(output.Type().Key()) { - return reflect.Value{}, fmt.Errorf("mapping has ToMapKey but output is not a map with string key, type=%v", output.Type()) + return reflect.Value{}, fmt.Errorf("mapping has toMapKey but output is not a map with string key, type=%v", output.Type()) } if !toSet.Type().AssignableTo(output.Type().Elem()) { - return reflect.Value{}, fmt.Errorf("mapping ToMapKey has a mismatched type. key=%s, from=%v, to=%v", toMapKey, toSet.Type(), output.Type().Elem()) + return reflect.Value{}, fmt.Errorf("mapping toMapKey has a mismatched type. key=%s, from=%v, to=%v", toMapKey, toSet.Type(), output.Type().Elem()) } return reflect.ValueOf(toMapKey), nil } -func checkAndExtractMapValueType(mapKey string, typ reflect.Type) (reflect.Type, error) { +func checkAndExtractMapValueType(typ reflect.Type) (reflect.Type, error) { if typ.Kind() != reflect.Map { return nil, fmt.Errorf("type[%v] is not a map", typ) } @@ -261,65 +277,65 @@ func checkMappingGroup(mappings []*Mapping) error { return nil } - from := mappings[0].From + from := mappings[0].fromNodeKey var fromMapKeyFlag, toMapKeyFlag, fromFieldFlag, toFieldFlag bool fromMap := make(map[string]bool, len(mappings)) toMap := make(map[string]bool, len(mappings)) for _, mapping := range mappings { - if mapping.From != from { - return fmt.Errorf("multiple mappings from the same group have from node keys: %s, %s", mapping.From, from) + if mapping.fromNodeKey != from { + return fmt.Errorf("multiple mappings from the same group have from node keys: %s, %s", mapping.fromNodeKey, from) } - if len(mapping.FromMapKey) > 0 { + if len(mapping.fromMapKey) > 0 { if fromFieldFlag { return fmt.Errorf("multiple mappings from the same group have both from field and from map key, mappings=%v", mappings) } - if _, ok := fromMap[mapping.FromMapKey]; ok { - return fmt.Errorf("multiple mappings from the same group have the same from map key = %s, mappings=%v", mapping.FromMapKey, mappings) + if _, ok := fromMap[mapping.fromMapKey]; ok { + return fmt.Errorf("multiple mappings from the same group have the same from map key = %s, mappings=%v", mapping.fromMapKey, mappings) } fromMapKeyFlag = true - fromMap[mapping.FromMapKey] = true + fromMap[mapping.fromMapKey] = true } - if len(mapping.FromField) > 0 { + if len(mapping.fromField) > 0 { if fromMapKeyFlag { return fmt.Errorf("multiple mappings from the same group have both from field and from map key, mappings=%v", mappings) } - if _, ok := fromMap[mapping.FromField]; ok { - return fmt.Errorf("multiple mappings from the same group have the same from field = %s, mappings=%v", mapping.FromField, mappings) + if _, ok := fromMap[mapping.fromField]; ok { + return fmt.Errorf("multiple mappings from the same group have the same from field = %s, mappings=%v", mapping.fromField, mappings) } fromFieldFlag = true - fromMap[mapping.FromField] = true + fromMap[mapping.fromField] = true } - if len(mapping.ToMapKey) > 0 { + if len(mapping.toMapKey) > 0 { if toFieldFlag { return fmt.Errorf("multiple mappings from the same group have both to field and to map key, mappings=%v", mappings) } - if _, ok := toMap[mapping.ToMapKey]; ok { - return fmt.Errorf("multiple mappings from the same group have the same to map key = %s, mappings=%v", mapping.ToMapKey, mappings) + if _, ok := toMap[mapping.toMapKey]; ok { + return fmt.Errorf("multiple mappings from the same group have the same to map key = %s, mappings=%v", mapping.toMapKey, mappings) } toMapKeyFlag = true - toMap[mapping.ToMapKey] = true + toMap[mapping.toMapKey] = true } - if len(mapping.ToField) > 0 { + if len(mapping.toField) > 0 { if toMapKeyFlag { return fmt.Errorf("multiple mappings from the same group have both to field and to map key, mappings=%v", mappings) } - if _, ok := toMap[mapping.ToField]; ok { - return fmt.Errorf("multiple mappings from the same group have the same to field = %s, mappings=%v", mapping.ToField, mappings) + if _, ok := toMap[mapping.toField]; ok { + return fmt.Errorf("multiple mappings from the same group have the same to field = %s, mappings=%v", mapping.toField, mappings) } toFieldFlag = true - toMap[mapping.ToField] = true + toMap[mapping.toField] = true } } diff --git a/compose/field_mapping_test.go b/compose/field_mapping_test.go index 9c55455..d26de66 100644 --- a/compose/field_mapping_test.go +++ b/compose/field_mapping_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package compose import ( @@ -6,15 +22,13 @@ import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/utils/generic" ) func TestFieldMapping(t *testing.T) { t.Run("whole mapped to whole", func(t *testing.T) { - m := []*Mapping{ - { - From: "upper", - }, - } + m := []*Mapping{NewMapping("1")} out, err := mapFrom[string]("correct input", m) assert.NoError(t, err) @@ -27,13 +41,7 @@ func TestFieldMapping(t *testing.T) { _, err = mapFrom[string](1, m) assert.ErrorContains(t, err, "mismatched type") - m1 := []*Mapping{ - { - From: "upper", - }, - } - - out1, err := mapFrom[any]("correct input", m1) + out1, err := mapFrom[any]("correct input", m) assert.NoError(t, err) assert.Equal(t, "correct input", out1) }) @@ -45,12 +53,7 @@ func TestFieldMapping(t *testing.T) { F3 int } - m := []*Mapping{ - { - From: "upper", - FromField: "F1", - }, - } + m := []*Mapping{NewMapping("1").FromField("F1")} out, err := mapFrom[string](&up{F1: "field1"}, m) assert.NoError(t, err) @@ -64,37 +67,26 @@ func TestFieldMapping(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "field1", out) - m[0].FromField = "f2" + m[0].fromField = "f2" _, err = mapFrom[string](&up{f2: "f2"}, m) assert.ErrorContains(t, err, "not exported") - m[0].FromField = "field3" + m[0].fromField = "field3" _, err = mapFrom[string](&up{F3: 3}, m) - assert.ErrorContains(t, err, "FromField not found") + assert.ErrorContains(t, err, "fromField not found") - m[0].FromField = "F3" + m[0].fromField = "F3" _, err = mapFrom[string](&up{F3: 3}, m) assert.ErrorContains(t, err, "mismatched type") - m1 := []*Mapping{ - { - From: "upper", - FromField: "F1", - }, - } - - out1, err := mapFrom[any](&up{F1: "field1"}, m1) + m = []*Mapping{NewMapping("1").FromField("F1")} + out1, err := mapFrom[any](&up{F1: "field1"}, m) assert.NoError(t, err) assert.Equal(t, "field1", out1) }) t.Run("map key mapped to whole", func(t *testing.T) { - m := []*Mapping{ - { - From: "upper", - FromMapKey: "key1", - }, - } + m := []*Mapping{NewMapping("1").FromMapKey("key1")} out, err := mapFrom[string](map[string]string{"key1": "value1"}, m) assert.NoError(t, err) @@ -105,7 +97,7 @@ func TestFieldMapping(t *testing.T) { assert.Equal(t, "", out) out, err = mapFrom[string](map[string]string{"key2": "value2"}, m) - assert.ErrorContains(t, err, "FromMapKey not found") + assert.ErrorContains(t, err, "fromMapKey not found") out, err = mapFrom[string](map[string]int{"key1": 1}, m) assert.ErrorContains(t, err, "mismatched type") @@ -114,14 +106,7 @@ func TestFieldMapping(t *testing.T) { out, err = mapFrom[string](map[mock]string{"key1": "value1"}, m) assert.ErrorContains(t, err, "not a map with string key") - m1 := []*Mapping{ - { - From: "upper", - FromMapKey: "key1", - }, - } - - out1, err := mapFrom[any](map[string]string{"key1": "value1"}, m1) + out1, err := mapFrom[any](map[string]string{"key1": "value1"}, m) assert.NoError(t, err) assert.Equal(t, "value1", out1) }) @@ -132,12 +117,7 @@ func TestFieldMapping(t *testing.T) { f3 string } - m := []*Mapping{ - { - From: "upper", - ToField: "F1", - }, - } + m := []*Mapping{NewMapping("1").ToField("F1")} out, err := mapFrom[down]("from", m) assert.NoError(t, err) @@ -146,33 +126,22 @@ func TestFieldMapping(t *testing.T) { out, err = mapFrom[down](1, m) assert.ErrorContains(t, err, "mismatched type") - m[0].ToField = "f2" + m[0].toField = "f2" _, err = mapFrom[down]("from", m) - assert.ErrorContains(t, err, "ToField not found") + assert.ErrorContains(t, err, "toField not found") - m[0].ToField = "f3" + m[0].toField = "f3" _, err = mapFrom[down]("from", m) assert.ErrorContains(t, err, "not exported") - m1 := []*Mapping{ - { - From: "upper", - ToField: "F1", - }, - } - - out1, err := mapFrom[*down]("from", m1) + m = []*Mapping{NewMapping("1").ToField("F1")} + out1, err := mapFrom[*down]("from", m) assert.NoError(t, err) assert.Equal(t, &down{F1: "from"}, out1) }) t.Run("whole mapped to map key", func(t *testing.T) { - m := []*Mapping{ - { - From: "upper", - ToMapKey: "key1", - }, - } + m := []*Mapping{NewMapping("1").ToMapKey("key1")} out, err := mapFrom[map[string]string]("from", m) assert.NoError(t, err) @@ -181,15 +150,8 @@ func TestFieldMapping(t *testing.T) { assert.ErrorContains(t, err, "mismatched type") type mockKey string - m1 := []*Mapping{ - { - From: "upper", - ToMapKey: "key1", - }, - } - - _, err = mapFrom[map[mockKey]string]("from", m1) - assert.ErrorContains(t, err, "mapping has ToMapKey but output is not a map with string key") + _, err = mapFrom[map[mockKey]string]("from", m) + assert.ErrorContains(t, err, "mapping has toMapKey but output is not a map with string key") }) t.Run("field mapped to field", func(t *testing.T) { @@ -205,13 +167,7 @@ func TestFieldMapping(t *testing.T) { F1 *inner } - m := []*Mapping{ - { - From: "upper", - FromField: "F1", - ToField: "F1", - }, - } + m := []*Mapping{NewMapping("1").FromField("F1").ToField("F1")} out, err := mapFrom[*down](&up{F1: &inner{in: "in"}}, m) assert.NoError(t, err) @@ -223,13 +179,7 @@ func TestFieldMapping(t *testing.T) { F1 []string } - m := []*Mapping{ - { - From: "upper", - FromField: "F1", - ToMapKey: "key1", - }, - } + m := []*Mapping{NewMapping("1").FromField("F1").ToMapKey("key1")} out, err := mapFrom[map[string]any](&up{F1: []string{"in"}}, m) assert.NoError(t, err) @@ -237,13 +187,7 @@ func TestFieldMapping(t *testing.T) { }) t.Run("map key mapped to map key", func(t *testing.T) { - m := []*Mapping{ - { - From: "upper", - FromMapKey: "key1", - ToMapKey: "key2", - }, - } + m := []*Mapping{NewMapping("1").FromMapKey("key1").ToMapKey("key2")} out, err := mapFrom[map[string]any](map[string]any{"key1": "value1"}, m) assert.NoError(t, err) @@ -255,13 +199,7 @@ func TestFieldMapping(t *testing.T) { F1 io.Reader } - m := []*Mapping{ - { - From: "upper", - FromMapKey: "key1", - ToField: "F1", - }, - } + m := []*Mapping{NewMapping("1").FromMapKey("key1").ToField("F1")} out, err := mapFrom[*down](map[string]any{"key1": &bytes.Buffer{}}, m) assert.NoError(t, err) @@ -275,28 +213,20 @@ func TestFieldMapping(t *testing.T) { } m := []*Mapping{ - { - From: "upper", - FromMapKey: "key1", - ToField: "F1", - }, - { - From: "upper", - FromMapKey: "key2", - ToField: "F2", - }, + NewMapping("1").FromMapKey("key1").ToField("F1"), + NewMapping("1").FromMapKey("key2").ToField("F2"), } out, err := mapFrom[*down](map[string]any{"key1": "v1", "key2": 2}, m) assert.NoError(t, err) assert.Equal(t, &down{F1: "v1", F2: 2}, out) - m[0].From = "different_upper" + m[0].fromNodeKey = "different_upper" out, err = mapFrom[*down](map[string]any{"key1": "v1", "key2": 2}, m) assert.ErrorContains(t, err, "multiple mappings from the same node have different keys") - m[0].From = "upper" - m[0].ToField = "" + m[0].fromNodeKey = "1" + m[0].toField = "" out, err = mapFrom[*down](map[string]any{"key1": "v1", "key2": 2}, m) assert.ErrorContains(t, err, "one of the mapping maps to entire input, conflict") @@ -304,4 +234,30 @@ func TestFieldMapping(t *testing.T) { out, err = mapFrom[*down](map[string]any{"key1": "v1", "key2": 2}, m) assert.ErrorContains(t, err, "mapper has no Mappings") }) + + t.Run("invalid mapping", func(t *testing.T) { + m := []*Mapping{NewMapping("1").FromMapKey("key1").FromField("F1")} + _, err := mapFrom[string](map[string]any{"key1": "v1", "key2": 2}, m) + assert.ErrorContains(t, err, "mapping has both fromField and fromMapKey") + + m = []*Mapping{NewMapping("1").ToMapKey("key1").ToField("F1")} + _, err = mapFrom[string]("input", m) + assert.ErrorContains(t, err, "mapping has both toField and toMapKey") + + m = []*Mapping{NewMapping("1").FromField("F1")} + _, err = mapFrom[string](generic.PtrOf("input"), m) + assert.ErrorContains(t, err, "mapping has fromField but input is not struct or struct ptr") + + m = []*Mapping{NewMapping("1").FromMapKey("key1")} + _, err = mapFrom[string](generic.PtrOf("input"), m) + assert.ErrorContains(t, err, "mapping has FromKey but input is not a map") + + m = []*Mapping{NewMapping("1").ToField("F1")} + _, err = mapFrom[string]("input", m) + assert.ErrorContains(t, err, "mapping has toField but output is not a struct") + + m = []*Mapping{NewMapping("1").ToMapKey("key1")} + _, err = mapFrom[string]("input", m) + assert.ErrorContains(t, err, "mapping has toMapKey but output is not a map") + }) } diff --git a/compose/generic_graph.go b/compose/generic_graph.go index 7cc5ba8..ef30b92 100644 --- a/compose/generic_graph.go +++ b/compose/generic_graph.go @@ -140,5 +140,5 @@ func (g *Graph[I, O]) Compile(ctx context.Context, opts ...GraphCompileOption) ( } func (g *Graph[I, O]) fieldMapper() fieldMapper { - return defaultFieldMapper[O]{} + return defaultFieldMapper[I]{} } diff --git a/compose/graph.go b/compose/graph.go index e5a8318..03d50cf 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -618,28 +618,28 @@ func (g *graph) validateAndInferType(startNode, endNode string, mappings ...*Map fromType := startNodeOutputType toType := endNodeInputType - if len(m.FromMapKey) > 0 { - if fromType, err = checkAndExtractMapValueType(m.FromMapKey, fromType); err != nil { - return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%v] start node's output failed, %w", startNode, endNode, m, err) + if len(m.fromMapKey) > 0 { + if fromType, err = checkAndExtractMapValueType(fromType); err != nil { + return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%s] start node's output failed, %w", startNode, endNode, m, err) } - } else if len(m.FromField) > 0 { - if fromType, err = checkAndExtractFieldType(m.FromField, fromType); err != nil { - return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%v] start node's output failed, %w", startNode, endNode, m, err) + } else if len(m.fromField) > 0 { + if fromType, err = checkAndExtractFieldType(m.fromField, fromType); err != nil { + return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%s] start node's output failed, %w", startNode, endNode, m, err) } } - if len(m.ToMapKey) > 0 { - if toType, err = checkAndExtractMapValueType(m.ToMapKey, toType); err != nil { - return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%v] end node's input failed, %w", startNode, endNode, m, err) + if len(m.toMapKey) > 0 { + if toType, err = checkAndExtractMapValueType(toType); err != nil { + return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%s] end node's input failed, %w", startNode, endNode, m, err) } - } else if len(m.ToField) > 0 { - if toType, err = checkAndExtractFieldType(m.ToField, toType); err != nil { - return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%v] end node's input failed, %w", startNode, endNode, m, err) + } else if len(m.toField) > 0 { + if toType, err = checkAndExtractFieldType(m.toField, toType); err != nil { + return fmt.Errorf("graph edge [%s]-[%s]: check mapping[%s] end node's input failed, %w", startNode, endNode, m, err) } } if checkAssignable(fromType, toType) == assignableTypeMustNot { - return fmt.Errorf("graph edge[%s]-[%s]: after mapping[%v], start node's output type[%s] and end node's input type[%s] mismatch", + return fmt.Errorf("graph edge[%s]-[%s]: after mapping[%s], start node's output type[%s] and end node's input type[%s] mismatch", m, startNode, endNode, fromType.String(), toType.String()) } } diff --git a/compose/workflow.go b/compose/workflow.go index 8821b42..7205015 100644 --- a/compose/workflow.go +++ b/compose/workflow.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package compose import ( @@ -5,6 +21,7 @@ import ( "errors" "fmt" "reflect" + "strings" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" @@ -17,17 +34,74 @@ import ( ) type Mapping struct { - From string + fromNodeKey string - FromField string - FromMapKey string + fromField string + fromMapKey string - ToField string - ToMapKey string + toField string + toMapKey string } func (m *Mapping) empty() bool { - return len(m.FromField) == 0 && len(m.FromMapKey) == 0 && len(m.ToField) == 0 && len(m.ToMapKey) == 0 + return len(m.fromField) == 0 && len(m.fromMapKey) == 0 && len(m.toField) == 0 && len(m.toMapKey) == 0 +} + +func (m *Mapping) FromField(fieldName string) *Mapping { + m.fromField = fieldName + return m +} + +func (m *Mapping) ToField(fieldName string) *Mapping { + m.toField = fieldName + return m +} + +func (m *Mapping) FromMapKey(mapKey string) *Mapping { + m.fromMapKey = mapKey + return m +} + +func (m *Mapping) ToMapKey(mapKey string) *Mapping { + m.toMapKey = mapKey + return m +} + +func (m *Mapping) String() string { + var sb strings.Builder + sb.WriteString("from ") + + if m.fromMapKey != "" { + sb.WriteString(m.fromMapKey) + sb.WriteString("(map key) of ") + } + + if m.fromField != "" { + sb.WriteString(m.fromField) + sb.WriteString("(field) of ") + } + + sb.WriteString("node '") + sb.WriteString(m.fromNodeKey) + sb.WriteString("'") + + if m.toField != "" { + sb.WriteString(" to ") + sb.WriteString(m.toField) + sb.WriteString("(field)") + } + + if m.toMapKey != "" { + sb.WriteString(" to ") + sb.WriteString(m.toMapKey) + sb.WriteString("(map key)") + } + + return sb.String() +} + +func NewMapping(fromNodeKey string) *Mapping { + return &Mapping{fromNodeKey: fromNodeKey} } type WorkflowNode struct { @@ -114,7 +188,7 @@ func convertAddNodeOpts(opts []WorkflowAddNodeOpt) []GraphAddNodeOpt { } func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.ChatModel, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[*schema.Message]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Message]{}} wf.nodes[key] = node if wf.err != nil { @@ -130,7 +204,7 @@ func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.ChatModel } func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Message]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[map[string]any]{}} wf.nodes[key] = node if wf.err != nil { @@ -147,7 +221,7 @@ func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.Ch } func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Message]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[*schema.Message]{}} wf.nodes[key] = node if wf.err != nil { @@ -163,7 +237,7 @@ func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...Wor } func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Document]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[string]{}} wf.nodes[key] = node if wf.err != nil { @@ -179,7 +253,7 @@ func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retri } func (wf *Workflow[I, O]) AddEmbeddingNode(key string, embedding embedding.Embedder, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[][]float64]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]string]{}} wf.nodes[key] = node if wf.err != nil { @@ -195,7 +269,7 @@ func (wf *Workflow[I, O]) AddEmbeddingNode(key string, embedding embedding.Embed } func (wf *Workflow[I, O]) AddIndexerNode(key string, indexer indexer.Indexer, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]string]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Document]{}} wf.nodes[key] = node if wf.err != nil { @@ -211,7 +285,7 @@ func (wf *Workflow[I, O]) AddIndexerNode(key string, indexer indexer.Indexer, op } func (wf *Workflow[I, O]) AddLoaderNode(key string, loader document.Loader, opts ...WorkflowAddNodeOpt) *WorkflowNode { - node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Document]{}} + node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[document.Source]{}} wf.nodes[key] = node if wf.err != nil { @@ -280,7 +354,7 @@ func (n *WorkflowNode) AddInput(inputs ...*Mapping) *WorkflowNode { return n } -func (wf *Workflow[I, O]) AddEnd(inputs []*Mapping) { +func (wf *Workflow[I, O]) AddEnd(inputs ...*Mapping) { wf.end = inputs wf.endFieldMapper = defaultFieldMapper[O]{} } @@ -314,9 +388,6 @@ func (wf *Workflow[I, O]) addEdgesWithMapping() (err error) { for _, node := range wf.nodes { toNode = node.key fm := node.fieldMapper - if fm == nil { - return fmt.Errorf("workflow has no field mapper, node = %s", toNode) - } if len(node.inputs) == 0 { return fmt.Errorf("workflow node = %s has no input", toNode) @@ -325,7 +396,7 @@ func (wf *Workflow[I, O]) addEdgesWithMapping() (err error) { fromNode2Mappings := make(map[string][]*Mapping, len(node.inputs)) for i := range node.inputs { input := node.inputs[i] - fromNodeKey := input.From + fromNodeKey := input.fromNodeKey fromNode2Mappings[fromNodeKey] = append(fromNode2Mappings[fromNodeKey], input) } @@ -344,15 +415,16 @@ func (wf *Workflow[I, O]) addEdgesWithMapping() (err error) { } } - fm := wf.endFieldMapper - if fm == nil { - return errors.New("workflow has no end field mapper") + if len(wf.end) == 0 { + return errors.New("workflow END has no input mapping") } + fm := wf.endFieldMapper + fromNode2EndMappings := make(map[string][]*Mapping, len(wf.end)) for i := range wf.end { input := wf.end[i] - fromNodeKey := input.From + fromNodeKey := input.fromNodeKey fromNode2EndMappings[fromNodeKey] = append(fromNode2EndMappings[fromNodeKey], input) } diff --git a/compose/workflow_test.go b/compose/workflow_test.go index a60b66d..67b4083 100644 --- a/compose/workflow_test.go +++ b/compose/workflow_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package compose import ( @@ -7,9 +23,18 @@ import ( "testing" "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/internal/mock/components/embedding" + "github.com/cloudwego/eino/internal/mock/components/indexer" + "github.com/cloudwego/eino/internal/mock/components/retriever" + "github.com/cloudwego/eino/schema" ) func TestWorkflow(t *testing.T) { + ctx := context.Background() + type structA struct { Field1 string Field2 int @@ -29,28 +54,25 @@ func TestWorkflow(t *testing.T) { w := NewWorkflow[*structA, string]() - w. + n := w. AddLambdaNode( "B", InvokableLambda(func(context.Context, string) (*structB, error) { return &structB{Field1: "1", Field2: 2}, nil }), - WithWorkflowNodeName("node B")). - AddInput(&Mapping{ - From: START, - FromField: "Field1", - }) + WithWorkflowNodeName("node B")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "workflow node = B has no input") + + n.AddInput(NewMapping(START).FromField("Field1")) w. AddLambdaNode( "C", - InvokableLambda(func(_ context.Context, in int) (map[string]string, error) { + InvokableLambda(func(_ context.Context, in any) (any, error) { return map[string]string{"key2": fmt.Sprintf("%d", in)}, nil })). - AddInput(&Mapping{ - From: START, - FromField: "Field2", - }) + AddInput(NewMapping(START).FromField("Field2")) w. AddLambdaNode( @@ -58,10 +80,7 @@ func TestWorkflow(t *testing.T) { InvokableLambda(func(_ context.Context, in []any) ([]any, error) { return in, nil })). - AddInput(&Mapping{ - From: START, - FromField: "Field3", - }) + AddInput(NewMapping(START).FromField("Field3")) w. AddLambdaNode( @@ -69,24 +88,21 @@ func TestWorkflow(t *testing.T) { InvokableLambda(func(_ context.Context, in *structE) (string, error) { return in.Field1 + "_" + in.Field2 + "_" + fmt.Sprintf("%v", in.Field3[0]), nil })). - AddInput(&Mapping{ - From: "B", - FromField: "Field1", - ToField: "Field1", - }, &Mapping{ - From: "C", - FromMapKey: "key2", - ToField: "Field2", - }, &Mapping{ - From: "D", - ToField: "Field3", - }) - - w.AddEnd([]*Mapping{{ - From: "E", - }}) + AddInput( + NewMapping("B").FromField("Field1").ToField("Field1"), + NewMapping("C").FromMapKey("key2").ToField("Field2"), + NewMapping("D").ToField("Field3"), + ) + + t.Run("compile without add end", func(t *testing.T) { + w1 := NewWorkflow[*structA, string]() + w1.gg.edges = make(map[string][]string) + _, err = w1.Compile(ctx) + assert.ErrorContains(t, err, "workflow END has no input mapping") + }) + + w.AddEnd(NewMapping("E")) - ctx := context.Background() compiled, err := w.Compile(ctx) assert.NoError(t, err) @@ -118,3 +134,45 @@ func TestWorkflow(t *testing.T) { assert.Equal(t, "1_2_1", chunk) } } + +func TestWorkflowCompile(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + t.Run("upstream not map, mapping has FromMapKey", func(t *testing.T) { + w := NewWorkflow[map[string]any, string]() + + w.AddChatTemplateNode("prompt", prompt.FromMessages(schema.Jinja2, schema.SystemMessage("your name is {{ name }}"))).AddInput(NewMapping(START)) + + w.AddEnd(NewMapping("prompt").FromMapKey("key1")) + + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[[]*schema.Message] is not a map") + }) + + t.Run("upstream not struct/struct ptr, mapping has FromField", func(t *testing.T) { + w := NewWorkflow[[]*schema.Document, []string]() + + w.AddIndexerNode("indexer", indexer.NewMockIndexer(ctrl)).AddInput(NewMapping(START).FromField("F1")) + w.AddEnd(NewMapping("indexer")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[[]*schema.Document] is not a struct") + }) + + t.Run("downstream not map, mapping has ToMapKey", func(t *testing.T) { + w := NewWorkflow[string, []*schema.Document]() + + w.AddRetrieverNode("retriever", retriever.NewMockRetriever(ctrl)).AddInput(NewMapping(START).ToMapKey("key1")) + w.AddEnd(NewMapping("retriever")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[string] is not a map") + }) + + t.Run("downstream not struct/struct ptr, mapping has ToField", func(t *testing.T) { + w := NewWorkflow[[]string, [][]float64]() + w.AddEmbeddingNode("embedder", embedding.NewMockEmbedder(ctrl)).AddInput(NewMapping(START).ToField("F1")) + w.AddEnd(NewMapping("embedder")) + _, err := w.Compile(ctx) + assert.ErrorContains(t, err, "type[[]string] is not a struct") + }) +}