diff --git a/compose/workflow.go b/compose/workflow.go index 11fa54a..1a14f8d 100644 --- a/compose/workflow.go +++ b/compose/workflow.go @@ -33,6 +33,7 @@ import ( "github.com/cloudwego/eino/utils/generic" ) +// Mapping is the mapping from one node's output to current node's input. type Mapping struct { fromNodeKey string @@ -47,26 +48,31 @@ func (m *Mapping) empty() bool { return len(m.fromField) == 0 && len(m.fromMapKey) == 0 && len(m.toField) == 0 && len(m.toMapKey) == 0 } +// FromField chooses a field value from fromNode's output struct or struct pointer with the specific field name, to serve as the source of the Mapping. func (m *Mapping) FromField(fieldName string) *Mapping { m.fromField = fieldName return m } +// ToField chooses a field from currentNode's input struct or struct pointer with the specific field name, to serve as the destination of the Mapping. func (m *Mapping) ToField(fieldName string) *Mapping { m.toField = fieldName return m } +// FromMapKey chooses a map entry from fromNode's output map with the specific key, to serve as the source of the Mapping. func (m *Mapping) FromMapKey(mapKey string) *Mapping { m.fromMapKey = mapKey return m } +// ToMapKey chooses a map entry from currentNode's input map with the specific key, to serve as the destination of the Mapping. func (m *Mapping) ToMapKey(mapKey string) *Mapping { m.toMapKey = mapKey return m } +// String returns the string representation of the Mapping. func (m *Mapping) String() string { var sb strings.Builder sb.WriteString("from ") @@ -101,16 +107,20 @@ func (m *Mapping) String() string { return sb.String() } +// NewMapping creates a new Mapping with the specified fromNodeKey. func NewMapping(fromNodeKey string) *Mapping { return &Mapping{fromNodeKey: fromNodeKey} } +// WorkflowNode is the node of the Workflow. type WorkflowNode struct { key string inputs []*Mapping fieldMapper fieldMapper } +// Workflow is wrapper of Graph, replacing AddEdge with declaring Mapping between one node's output and current node's input. +// Under the hood it uses NodeTriggerMode(AllPredecessor), so does not support branches or cycles. type Workflow[I, O any] struct { gg *Graph[I, O] @@ -120,6 +130,7 @@ type Workflow[I, O any] struct { err error } +// NewWorkflow creates a new Workflow. func NewWorkflow[I, O any](opts ...NewGraphOption) *Workflow[I, O] { wf := &Workflow[I, O]{ gg: NewGraph[I, O](opts...), @@ -130,65 +141,21 @@ func NewWorkflow[I, O any](opts ...NewGraphOption) *Workflow[I, O] { return wf } -type WorkflowCompileOption GraphCompileOption - -func WithWorkflowMaxRunStep(maxSteps int) WorkflowCompileOption { - return WorkflowCompileOption(WithMaxRunSteps(maxSteps)) -} - -func WithWorkflowName(name string) WorkflowCompileOption { - return WorkflowCompileOption(WithGraphName(name)) -} - -func (wf *Workflow[I, O]) Compile(ctx context.Context, opts ...WorkflowCompileOption) (Runnable[I, O], error) { +func (wf *Workflow[I, O]) Compile(ctx context.Context, opts ...GraphCompileOption) (Runnable[I, O], error) { if wf.err != nil { return nil, wf.err } - gCompileOpts := make([]GraphCompileOption, 0, len(opts)+1) - for _, opt := range opts { - gCompileOpts = append(gCompileOpts, GraphCompileOption(opt)) - } - gCompileOpts = append(gCompileOpts, WithNodeTriggerMode(AllPredecessor)) + opts = append(opts, WithNodeTriggerMode(AllPredecessor)) if err := wf.addEdgesWithMapping(); err != nil { return nil, err } - return wf.gg.Compile(ctx, gCompileOpts...) -} - -type WorkflowAddNodeOpt GraphAddNodeOpt - -func WithWorkflowNodeName(name string) WorkflowAddNodeOpt { - return WorkflowAddNodeOpt(WithNodeName(name)) -} - -func WithWorkflowStatePreHandler[I, S any](pre StatePreHandler[I, S]) WorkflowAddNodeOpt { - return WorkflowAddNodeOpt(WithStatePreHandler(pre)) -} - -func WithWorkflowStatePostHandler[O, S any](post StatePostHandler[O, S]) WorkflowAddNodeOpt { - return WorkflowAddNodeOpt(WithStatePostHandler(post)) -} - -func WithWorkflowStreamStatePreHandler[I, S any](pre StreamStatePreHandler[I, S]) WorkflowAddNodeOpt { - return WorkflowAddNodeOpt(WithStreamStatePreHandler(pre)) -} - -func WithWorkflowStreamStatePostHandler[O, S any](post StreamStatePostHandler[O, S]) WorkflowAddNodeOpt { - return WorkflowAddNodeOpt(WithStreamStatePostHandler(post)) -} - -func convertAddNodeOpts(opts []WorkflowAddNodeOpt) []GraphAddNodeOpt { - graphOpts := make([]GraphAddNodeOpt, 0, len(opts)) - for _, opt := range opts { - graphOpts = append(graphOpts, GraphAddNodeOpt(opt)) - } - return graphOpts + return wf.gg.Compile(ctx, opts...) } -func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.ChatModel, opts ...WorkflowAddNodeOpt) *WorkflowNode { +func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.ChatModel, opts ...GraphAddNodeOpt) *WorkflowNode { node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Message]{}} wf.nodes[key] = node @@ -196,7 +163,12 @@ func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.ChatModel return node } - err := wf.gg.AddChatModelNode(key, chatModel, convertAddNodeOpts(opts)...) + options := getGraphAddNodeOpts(opts...) + if len(options.nodeOptions.inputKey) > 0 { + node.fieldMapper = defaultFieldMapper[map[string]any]{} + } + + err := wf.gg.AddChatModelNode(key, chatModel, opts...) if err != nil { wf.err = err return node @@ -204,7 +176,7 @@ func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.ChatModel return node } -func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...WorkflowAddNodeOpt) *WorkflowNode { +func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode { node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[map[string]any]{}} wf.nodes[key] = node @@ -212,7 +184,12 @@ func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.Ch return node } - err := wf.gg.AddChatTemplateNode(key, chatTemplate, convertAddNodeOpts(opts)...) + options := getGraphAddNodeOpts(opts...) + if len(options.nodeOptions.inputKey) > 0 { + node.fieldMapper = defaultFieldMapper[map[string]any]{} + } + + err := wf.gg.AddChatTemplateNode(key, chatTemplate, opts...) if err != nil { wf.err = err return node @@ -221,7 +198,7 @@ func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.Ch return node } -func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...WorkflowAddNodeOpt) *WorkflowNode { +func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode { node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[*schema.Message]{}} wf.nodes[key] = node @@ -229,7 +206,12 @@ func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...Wor return node } - err := wf.gg.AddToolsNode(key, tools, convertAddNodeOpts(opts)...) + options := getGraphAddNodeOpts(opts...) + if len(options.nodeOptions.inputKey) > 0 { + node.fieldMapper = defaultFieldMapper[map[string]any]{} + } + + err := wf.gg.AddToolsNode(key, tools, opts...) if err != nil { wf.err = err return node @@ -237,7 +219,7 @@ func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...Wor return node } -func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...WorkflowAddNodeOpt) *WorkflowNode { +func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...GraphAddNodeOpt) *WorkflowNode { node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[string]{}} wf.nodes[key] = node @@ -245,7 +227,12 @@ func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retri return node } - err := wf.gg.AddRetrieverNode(key, retriever, convertAddNodeOpts(opts)...) + options := getGraphAddNodeOpts(opts...) + if len(options.nodeOptions.inputKey) > 0 { + node.fieldMapper = defaultFieldMapper[map[string]any]{} + } + + err := wf.gg.AddRetrieverNode(key, retriever, opts...) if err != nil { wf.err = err return node @@ -253,7 +240,7 @@ func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retri return node } -func (wf *Workflow[I, O]) AddEmbeddingNode(key string, embedding embedding.Embedder, opts ...WorkflowAddNodeOpt) *WorkflowNode { +func (wf *Workflow[I, O]) AddEmbeddingNode(key string, embedding embedding.Embedder, opts ...GraphAddNodeOpt) *WorkflowNode { node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]string]{}} wf.nodes[key] = node @@ -261,7 +248,12 @@ func (wf *Workflow[I, O]) AddEmbeddingNode(key string, embedding embedding.Embed return node } - err := wf.gg.AddEmbeddingNode(key, embedding, convertAddNodeOpts(opts)...) + options := getGraphAddNodeOpts(opts...) + if len(options.nodeOptions.inputKey) > 0 { + node.fieldMapper = defaultFieldMapper[map[string]any]{} + } + + err := wf.gg.AddEmbeddingNode(key, embedding, opts...) if err != nil { wf.err = err return node @@ -269,7 +261,7 @@ func (wf *Workflow[I, O]) AddEmbeddingNode(key string, embedding embedding.Embed return node } -func (wf *Workflow[I, O]) AddIndexerNode(key string, indexer indexer.Indexer, opts ...WorkflowAddNodeOpt) *WorkflowNode { +func (wf *Workflow[I, O]) AddIndexerNode(key string, indexer indexer.Indexer, opts ...GraphAddNodeOpt) *WorkflowNode { node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Document]{}} wf.nodes[key] = node @@ -277,7 +269,12 @@ func (wf *Workflow[I, O]) AddIndexerNode(key string, indexer indexer.Indexer, op return node } - err := wf.gg.AddIndexerNode(key, indexer, convertAddNodeOpts(opts)...) + options := getGraphAddNodeOpts(opts...) + if len(options.nodeOptions.inputKey) > 0 { + node.fieldMapper = defaultFieldMapper[map[string]any]{} + } + + err := wf.gg.AddIndexerNode(key, indexer, opts...) if err != nil { wf.err = err return node @@ -285,7 +282,7 @@ func (wf *Workflow[I, O]) AddIndexerNode(key string, indexer indexer.Indexer, op return node } -func (wf *Workflow[I, O]) AddLoaderNode(key string, loader document.Loader, opts ...WorkflowAddNodeOpt) *WorkflowNode { +func (wf *Workflow[I, O]) AddLoaderNode(key string, loader document.Loader, opts ...GraphAddNodeOpt) *WorkflowNode { node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[document.Source]{}} wf.nodes[key] = node @@ -293,7 +290,12 @@ func (wf *Workflow[I, O]) AddLoaderNode(key string, loader document.Loader, opts return node } - err := wf.gg.AddLoaderNode(key, loader, convertAddNodeOpts(opts)...) + options := getGraphAddNodeOpts(opts...) + if len(options.nodeOptions.inputKey) > 0 { + node.fieldMapper = defaultFieldMapper[map[string]any]{} + } + + err := wf.gg.AddLoaderNode(key, loader, opts...) if err != nil { wf.err = err return node @@ -301,7 +303,7 @@ func (wf *Workflow[I, O]) AddLoaderNode(key string, loader document.Loader, opts return node } -func (wf *Workflow[I, O]) AddDocumentTransformerNode(key string, transformer document.Transformer, opts ...WorkflowAddNodeOpt) *WorkflowNode { +func (wf *Workflow[I, O]) AddDocumentTransformerNode(key string, transformer document.Transformer, opts ...GraphAddNodeOpt) *WorkflowNode { node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Document]{}} wf.nodes[key] = node @@ -309,7 +311,12 @@ func (wf *Workflow[I, O]) AddDocumentTransformerNode(key string, transformer doc return node } - err := wf.gg.AddDocumentTransformerNode(key, transformer, convertAddNodeOpts(opts)...) + options := getGraphAddNodeOpts(opts...) + if len(options.nodeOptions.inputKey) > 0 { + node.fieldMapper = defaultFieldMapper[map[string]any]{} + } + + err := wf.gg.AddDocumentTransformerNode(key, transformer, opts...) if err != nil { wf.err = err return node @@ -317,7 +324,7 @@ func (wf *Workflow[I, O]) AddDocumentTransformerNode(key string, transformer doc return node } -func (wf *Workflow[I, O]) AddGraphNode(key string, graph AnyGraph, opts ...WorkflowAddNodeOpt) *WorkflowNode { +func (wf *Workflow[I, O]) AddGraphNode(key string, graph AnyGraph, opts ...GraphAddNodeOpt) *WorkflowNode { node := &WorkflowNode{key: key, fieldMapper: graph.fieldMapper()} wf.nodes[key] = node @@ -325,7 +332,12 @@ func (wf *Workflow[I, O]) AddGraphNode(key string, graph AnyGraph, opts ...Workf return node } - err := wf.gg.AddGraphNode(key, graph, convertAddNodeOpts(opts)...) + options := getGraphAddNodeOpts(opts...) + if len(options.nodeOptions.inputKey) > 0 { + node.fieldMapper = defaultFieldMapper[map[string]any]{} + } + + err := wf.gg.AddGraphNode(key, graph, opts...) if err != nil { wf.err = err return node @@ -333,7 +345,7 @@ func (wf *Workflow[I, O]) AddGraphNode(key string, graph AnyGraph, opts ...Workf return node } -func (wf *Workflow[I, O]) AddLambdaNode(key string, lambda *Lambda, opts ...WorkflowAddNodeOpt) *WorkflowNode { +func (wf *Workflow[I, O]) AddLambdaNode(key string, lambda *Lambda, opts ...GraphAddNodeOpt) *WorkflowNode { node := &WorkflowNode{key: key, fieldMapper: lambda.fieldMapper} wf.nodes[key] = node @@ -341,7 +353,12 @@ func (wf *Workflow[I, O]) AddLambdaNode(key string, lambda *Lambda, opts ...Work return node } - err := wf.gg.AddLambdaNode(key, lambda, convertAddNodeOpts(opts)...) + options := getGraphAddNodeOpts(opts...) + if len(options.nodeOptions.inputKey) > 0 { + node.fieldMapper = defaultFieldMapper[map[string]any]{} + } + + err := wf.gg.AddLambdaNode(key, lambda, opts...) if err != nil { wf.err = err return node diff --git a/compose/workflow_test.go b/compose/workflow_test.go index cdd250b..5f1a2dd 100644 --- a/compose/workflow_test.go +++ b/compose/workflow_test.go @@ -58,6 +58,10 @@ func TestWorkflow(t *testing.T) { temp string } + type structEnd struct { + Field1 string + } + subGraph := NewGraph[string, *structB]() _ = subGraph.AddLambdaNode( "1", @@ -78,15 +82,23 @@ func TestWorkflow(t *testing.T) { "1", InvokableLambda(func(_ context.Context, in []any) ([]any, error) { return in, nil - })). + }), + WithOutputKey("key")). AddInput(NewMapping(START)) - subWorkflow.AddEnd(NewMapping("1")) + subWorkflow.AddLambdaNode( + "2", + InvokableLambda(func(_ context.Context, in []any) ([]any, error) { + return in, nil + }), + WithInputKey("key")). + AddInput(NewMapping("1")) + subWorkflow.AddEnd(NewMapping("2")) - w := NewWorkflow[*structA, string](WithGenLocalState(func(context.Context) *state { return &state{} })) + w := NewWorkflow[*structA, *structEnd](WithGenLocalState(func(context.Context) *state { return &state{} })) w. - AddGraphNode("B", subGraph, WithWorkflowNodeName("node B"), - WithWorkflowStatePostHandler(func(ctx context.Context, out *structB, state *state) (*structB, error) { + AddGraphNode("B", subGraph, + WithStatePostHandler(func(ctx context.Context, out *structB, state *state) (*structB, error) { state.temp = out.Field1 return out, nil })). @@ -115,7 +127,7 @@ func TestWorkflow(t *testing.T) { return in, nil }), nil }), - WithWorkflowStreamStatePreHandler(func(ctx context.Context, in *schema.StreamReader[structE], state *state) (*schema.StreamReader[structE], error) { + WithStreamStatePreHandler(func(ctx context.Context, in *schema.StreamReader[structE], state *state) (*schema.StreamReader[structE], error) { temp := state.temp return schema.StreamReaderWithConvert(in, func(v structE) (structE, error) { if len(v.Field3) > 0 { @@ -125,7 +137,7 @@ func TestWorkflow(t *testing.T) { return v, nil }), nil }), - WithWorkflowStreamStatePostHandler(func(ctx context.Context, out *schema.StreamReader[structE], state *state) (*schema.StreamReader[structE], error) { + WithStreamStatePostHandler(func(ctx context.Context, out *schema.StreamReader[structE], state *state) (*schema.StreamReader[structE], error) { return schema.StreamReaderWithConvert(out, func(v structE) (structE, error) { if len(v.Field1) > 0 { v.Field1 = v.Field1 + "+Post" @@ -145,7 +157,7 @@ func TestWorkflow(t *testing.T) { InvokableLambda(func(ctx context.Context, in map[string]any) (string, error) { return fmt.Sprintf("%v_%v_%v_%v_%v", in["key1"], in["key2"], in["key3"], in["B"], in["state_temp"]), nil }), - WithWorkflowStatePreHandler(func(ctx context.Context, in map[string]any, state *state) (map[string]any, error) { + WithStatePreHandler(func(ctx context.Context, in map[string]any, state *state) (map[string]any, error) { in["state_temp"] = state.temp return in, nil }), @@ -157,7 +169,7 @@ func TestWorkflow(t *testing.T) { NewMapping("E").FromField("Field3").ToMapKey("key3"), ) - w.AddEnd(NewMapping("F")) + w.AddEnd(NewMapping("F").ToField("Field1")) compiled, err := w.Compile(ctx) assert.NoError(t, err) @@ -171,7 +183,7 @@ func TestWorkflow(t *testing.T) { } out, err := compiled.Invoke(ctx, input) assert.NoError(t, err) - assert.Equal(t, "E:1+Post_E:2_[1 good Pre:1]_33_1", out) + assert.Equal(t, &structEnd{"E:1+Post_E:2_[1 good Pre:1]_33_1"}, out) outStream, err := compiled.Stream(ctx, input) assert.NoError(t, err) @@ -187,7 +199,7 @@ func TestWorkflow(t *testing.T) { return } - assert.Equal(t, "E:1+Post_E:2_[1 good Pre:1]_33_1", chunk) + assert.Equal(t, &structEnd{"E:1+Post_E:2_[1 good Pre:1]_33_1"}, chunk) } }