Skip to content

Commit

Permalink
input key & output key
Browse files Browse the repository at this point in the history
Change-Id: Idd6d633c877ae1b48ac0fda7047f9d32dd8da87c
  • Loading branch information
shentongmartin committed Jan 6, 2025
1 parent c7150b4 commit f0c5bf7
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 78 deletions.
151 changes: 84 additions & 67 deletions compose/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/cloudwego/eino/utils/generic"
)

// Mapping is the mapping from one node's output to current node's input.
type Mapping struct {
fromNodeKey string

Expand All @@ -47,26 +48,31 @@ func (m *Mapping) empty() bool {
return len(m.fromField) == 0 && len(m.fromMapKey) == 0 && len(m.toField) == 0 && len(m.toMapKey) == 0
}

// FromField chooses a field value from fromNode's output struct or struct pointer with the specific field name, to serve as the source of the Mapping.
func (m *Mapping) FromField(fieldName string) *Mapping {
m.fromField = fieldName
return m
}

// ToField chooses a field from currentNode's input struct or struct pointer with the specific field name, to serve as the destination of the Mapping.
func (m *Mapping) ToField(fieldName string) *Mapping {
m.toField = fieldName
return m
}

// FromMapKey chooses a map entry from fromNode's output map with the specific key, to serve as the source of the Mapping.
func (m *Mapping) FromMapKey(mapKey string) *Mapping {
m.fromMapKey = mapKey
return m
}

// ToMapKey chooses a map entry from currentNode's input map with the specific key, to serve as the destination of the Mapping.
func (m *Mapping) ToMapKey(mapKey string) *Mapping {
m.toMapKey = mapKey
return m
}

// String returns the string representation of the Mapping.
func (m *Mapping) String() string {
var sb strings.Builder
sb.WriteString("from ")
Expand Down Expand Up @@ -101,16 +107,20 @@ func (m *Mapping) String() string {
return sb.String()
}

// NewMapping creates a new Mapping with the specified fromNodeKey.
func NewMapping(fromNodeKey string) *Mapping {
return &Mapping{fromNodeKey: fromNodeKey}
}

// WorkflowNode is the node of the Workflow.
type WorkflowNode struct {
key string
inputs []*Mapping
fieldMapper fieldMapper
}

// Workflow is wrapper of Graph, replacing AddEdge with declaring Mapping between one node's output and current node's input.
// Under the hood it uses NodeTriggerMode(AllPredecessor), so does not support branches or cycles.
type Workflow[I, O any] struct {
gg *Graph[I, O]

Expand All @@ -120,6 +130,7 @@ type Workflow[I, O any] struct {
err error
}

// NewWorkflow creates a new Workflow.
func NewWorkflow[I, O any](opts ...NewGraphOption) *Workflow[I, O] {
wf := &Workflow[I, O]{
gg: NewGraph[I, O](opts...),
Expand All @@ -130,89 +141,55 @@ func NewWorkflow[I, O any](opts ...NewGraphOption) *Workflow[I, O] {
return wf
}

type WorkflowCompileOption GraphCompileOption

func WithWorkflowMaxRunStep(maxSteps int) WorkflowCompileOption {
return WorkflowCompileOption(WithMaxRunSteps(maxSteps))
}

func WithWorkflowName(name string) WorkflowCompileOption {
return WorkflowCompileOption(WithGraphName(name))
}

func (wf *Workflow[I, O]) Compile(ctx context.Context, opts ...WorkflowCompileOption) (Runnable[I, O], error) {
func (wf *Workflow[I, O]) Compile(ctx context.Context, opts ...GraphCompileOption) (Runnable[I, O], error) {
if wf.err != nil {
return nil, wf.err
}

gCompileOpts := make([]GraphCompileOption, 0, len(opts)+1)
for _, opt := range opts {
gCompileOpts = append(gCompileOpts, GraphCompileOption(opt))
}
gCompileOpts = append(gCompileOpts, WithNodeTriggerMode(AllPredecessor))
opts = append(opts, WithNodeTriggerMode(AllPredecessor))

if err := wf.addEdgesWithMapping(); err != nil {
return nil, err
}

return wf.gg.Compile(ctx, gCompileOpts...)
}

type WorkflowAddNodeOpt GraphAddNodeOpt

func WithWorkflowNodeName(name string) WorkflowAddNodeOpt {
return WorkflowAddNodeOpt(WithNodeName(name))
}

func WithWorkflowStatePreHandler[I, S any](pre StatePreHandler[I, S]) WorkflowAddNodeOpt {
return WorkflowAddNodeOpt(WithStatePreHandler(pre))
}

func WithWorkflowStatePostHandler[O, S any](post StatePostHandler[O, S]) WorkflowAddNodeOpt {
return WorkflowAddNodeOpt(WithStatePostHandler(post))
}

func WithWorkflowStreamStatePreHandler[I, S any](pre StreamStatePreHandler[I, S]) WorkflowAddNodeOpt {
return WorkflowAddNodeOpt(WithStreamStatePreHandler(pre))
}

func WithWorkflowStreamStatePostHandler[O, S any](post StreamStatePostHandler[O, S]) WorkflowAddNodeOpt {
return WorkflowAddNodeOpt(WithStreamStatePostHandler(post))
}

func convertAddNodeOpts(opts []WorkflowAddNodeOpt) []GraphAddNodeOpt {
graphOpts := make([]GraphAddNodeOpt, 0, len(opts))
for _, opt := range opts {
graphOpts = append(graphOpts, GraphAddNodeOpt(opt))
}
return graphOpts
return wf.gg.Compile(ctx, opts...)
}

func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.ChatModel, opts ...WorkflowAddNodeOpt) *WorkflowNode {
func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.ChatModel, opts ...GraphAddNodeOpt) *WorkflowNode {
node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Message]{}}
wf.nodes[key] = node

if wf.err != nil {
return node
}

err := wf.gg.AddChatModelNode(key, chatModel, convertAddNodeOpts(opts)...)
options := getGraphAddNodeOpts(opts...)
if len(options.nodeOptions.inputKey) > 0 {
node.fieldMapper = defaultFieldMapper[map[string]any]{}
}

err := wf.gg.AddChatModelNode(key, chatModel, opts...)
if err != nil {
wf.err = err
return node
}
return node
}

func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...WorkflowAddNodeOpt) *WorkflowNode {
func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode {
node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[map[string]any]{}}
wf.nodes[key] = node

if wf.err != nil {
return node
}

err := wf.gg.AddChatTemplateNode(key, chatTemplate, convertAddNodeOpts(opts)...)
options := getGraphAddNodeOpts(opts...)
if len(options.nodeOptions.inputKey) > 0 {
node.fieldMapper = defaultFieldMapper[map[string]any]{}
}

err := wf.gg.AddChatTemplateNode(key, chatTemplate, opts...)
if err != nil {
wf.err = err
return node
Expand All @@ -221,127 +198,167 @@ func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.Ch
return node
}

func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...WorkflowAddNodeOpt) *WorkflowNode {
func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode {
node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[*schema.Message]{}}
wf.nodes[key] = node

if wf.err != nil {
return node
}

err := wf.gg.AddToolsNode(key, tools, convertAddNodeOpts(opts)...)
options := getGraphAddNodeOpts(opts...)
if len(options.nodeOptions.inputKey) > 0 {
node.fieldMapper = defaultFieldMapper[map[string]any]{}
}

err := wf.gg.AddToolsNode(key, tools, opts...)
if err != nil {
wf.err = err
return node
}
return node
}

func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...WorkflowAddNodeOpt) *WorkflowNode {
func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...GraphAddNodeOpt) *WorkflowNode {
node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[string]{}}
wf.nodes[key] = node

if wf.err != nil {
return node
}

err := wf.gg.AddRetrieverNode(key, retriever, convertAddNodeOpts(opts)...)
options := getGraphAddNodeOpts(opts...)
if len(options.nodeOptions.inputKey) > 0 {
node.fieldMapper = defaultFieldMapper[map[string]any]{}
}

err := wf.gg.AddRetrieverNode(key, retriever, opts...)
if err != nil {
wf.err = err
return node
}
return node
}

func (wf *Workflow[I, O]) AddEmbeddingNode(key string, embedding embedding.Embedder, opts ...WorkflowAddNodeOpt) *WorkflowNode {
func (wf *Workflow[I, O]) AddEmbeddingNode(key string, embedding embedding.Embedder, opts ...GraphAddNodeOpt) *WorkflowNode {
node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]string]{}}
wf.nodes[key] = node

if wf.err != nil {
return node
}

err := wf.gg.AddEmbeddingNode(key, embedding, convertAddNodeOpts(opts)...)
options := getGraphAddNodeOpts(opts...)
if len(options.nodeOptions.inputKey) > 0 {
node.fieldMapper = defaultFieldMapper[map[string]any]{}
}

err := wf.gg.AddEmbeddingNode(key, embedding, opts...)
if err != nil {
wf.err = err
return node
}
return node
}

func (wf *Workflow[I, O]) AddIndexerNode(key string, indexer indexer.Indexer, opts ...WorkflowAddNodeOpt) *WorkflowNode {
func (wf *Workflow[I, O]) AddIndexerNode(key string, indexer indexer.Indexer, opts ...GraphAddNodeOpt) *WorkflowNode {
node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Document]{}}
wf.nodes[key] = node

if wf.err != nil {
return node
}

err := wf.gg.AddIndexerNode(key, indexer, convertAddNodeOpts(opts)...)
options := getGraphAddNodeOpts(opts...)
if len(options.nodeOptions.inputKey) > 0 {
node.fieldMapper = defaultFieldMapper[map[string]any]{}
}

err := wf.gg.AddIndexerNode(key, indexer, opts...)
if err != nil {
wf.err = err
return node
}
return node
}

func (wf *Workflow[I, O]) AddLoaderNode(key string, loader document.Loader, opts ...WorkflowAddNodeOpt) *WorkflowNode {
func (wf *Workflow[I, O]) AddLoaderNode(key string, loader document.Loader, opts ...GraphAddNodeOpt) *WorkflowNode {
node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[document.Source]{}}
wf.nodes[key] = node

if wf.err != nil {
return node
}

err := wf.gg.AddLoaderNode(key, loader, convertAddNodeOpts(opts)...)
options := getGraphAddNodeOpts(opts...)
if len(options.nodeOptions.inputKey) > 0 {
node.fieldMapper = defaultFieldMapper[map[string]any]{}
}

err := wf.gg.AddLoaderNode(key, loader, opts...)
if err != nil {
wf.err = err
return node
}
return node
}

func (wf *Workflow[I, O]) AddDocumentTransformerNode(key string, transformer document.Transformer, opts ...WorkflowAddNodeOpt) *WorkflowNode {
func (wf *Workflow[I, O]) AddDocumentTransformerNode(key string, transformer document.Transformer, opts ...GraphAddNodeOpt) *WorkflowNode {
node := &WorkflowNode{key: key, fieldMapper: defaultFieldMapper[[]*schema.Document]{}}
wf.nodes[key] = node

if wf.err != nil {
return node
}

err := wf.gg.AddDocumentTransformerNode(key, transformer, convertAddNodeOpts(opts)...)
options := getGraphAddNodeOpts(opts...)
if len(options.nodeOptions.inputKey) > 0 {
node.fieldMapper = defaultFieldMapper[map[string]any]{}
}

err := wf.gg.AddDocumentTransformerNode(key, transformer, opts...)
if err != nil {
wf.err = err
return node
}
return node
}

func (wf *Workflow[I, O]) AddGraphNode(key string, graph AnyGraph, opts ...WorkflowAddNodeOpt) *WorkflowNode {
func (wf *Workflow[I, O]) AddGraphNode(key string, graph AnyGraph, opts ...GraphAddNodeOpt) *WorkflowNode {
node := &WorkflowNode{key: key, fieldMapper: graph.fieldMapper()}
wf.nodes[key] = node

if wf.err != nil {
return node
}

err := wf.gg.AddGraphNode(key, graph, convertAddNodeOpts(opts)...)
options := getGraphAddNodeOpts(opts...)
if len(options.nodeOptions.inputKey) > 0 {
node.fieldMapper = defaultFieldMapper[map[string]any]{}
}

err := wf.gg.AddGraphNode(key, graph, opts...)
if err != nil {
wf.err = err
return node
}
return node
}

func (wf *Workflow[I, O]) AddLambdaNode(key string, lambda *Lambda, opts ...WorkflowAddNodeOpt) *WorkflowNode {
func (wf *Workflow[I, O]) AddLambdaNode(key string, lambda *Lambda, opts ...GraphAddNodeOpt) *WorkflowNode {
node := &WorkflowNode{key: key, fieldMapper: lambda.fieldMapper}
wf.nodes[key] = node

if wf.err != nil {
return node
}

err := wf.gg.AddLambdaNode(key, lambda, convertAddNodeOpts(opts)...)
options := getGraphAddNodeOpts(opts...)
if len(options.nodeOptions.inputKey) > 0 {
node.fieldMapper = defaultFieldMapper[map[string]any]{}
}

err := wf.gg.AddLambdaNode(key, lambda, opts...)
if err != nil {
wf.err = err
return node
Expand Down
Loading

0 comments on commit f0c5bf7

Please sign in to comment.