Skip to content

Commit

Permalink
fix: designate call option to subgraph correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
meguminnnnnnnnn committed Dec 20, 2024
1 parent 313de1b commit 486a6fa
Show file tree
Hide file tree
Showing 5 changed files with 317 additions and 63 deletions.
48 changes: 37 additions & 11 deletions compose/graph_call_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,49 @@ type Option struct {
options []any
handler []callbacks.Handler

keys []string
paths [][]string

maxRunSteps int
}

func (o Option) deepCopy() Option {
nOptions := make([]any, len(o.options))
copy(nOptions, o.options)
nHandler := make([]callbacks.Handler, len(o.handler))
copy(nHandler, o.handler)
nPaths := make([][]string, len(o.paths))
for i, path := range o.paths {
nPaths[i] = make([]string, len(path))
copy(nPaths[i], path)
}
return Option{
options: nOptions,
handler: nHandler,
paths: nPaths,
maxRunSteps: o.maxRunSteps,
}
}

// DesignateNode set the key of the node which will be used to.
// eg.
// e.g.
//
// embeddingOption := compose.WithEmbeddingOption(embedding.WithModel("text-embedding-3-small"))
// runnable.Invoke(ctx, "input", embeddingOption.DesignateNode("my_embedding_node"))
func (o Option) DesignateNode(key ...string) Option {
o.keys = append(o.keys, key...)
nKeys := make([][]string, len(key))
for i, k := range key {
nKeys[i] = []string{k}
}
return o.DesignateNodeWithPath(nKeys...)
}

func (o Option) DesignateNodeWithPath(path ...[]string) Option {
o.paths = append(o.paths, path...)
return o
}

// WithEmbeddingOption is a functional option type for embedding component.
// eg.
// e.g.
//
// embeddingOption := compose.WithEmbeddingOption(embedding.WithModel("text-embedding-3-small"))
// runnable.Invoke(ctx, "input", embeddingOption)
Expand All @@ -59,7 +85,7 @@ func WithEmbeddingOption(opts ...embedding.Option) Option {
}

// WithRetrieverOption is a functional option type for retriever component.
// eg.
// e.g.
//
// retrieverOption := compose.WithRetrieverOption(retriever.WithIndex("my_index"))
// runnable.Invoke(ctx, "input", retrieverOption)
Expand All @@ -73,7 +99,7 @@ func WithLoaderSplitterOption(opts ...document.LoaderSplitterOption) Option {
}

// WithLoaderOption is a functional option type for loader component.
// eg.
// e.g.
//
// loaderOption := compose.WithLoaderOption(document.WithCollection("my_collection"))
// runnable.Invoke(ctx, "input", loaderOption)
Expand All @@ -87,7 +113,7 @@ func WithDocumentTransformerOption(opts ...document.TransformerOption) Option {
}

// WithIndexerOption is a functional option type for indexer component.
// eg.
// e.g.
//
// indexerOption := compose.WithIndexerOption(indexer.WithSubIndexes([]string{"my_sub_index"}))
// runnable.Invoke(ctx, "input", indexerOption)
Expand All @@ -96,7 +122,7 @@ func WithIndexerOption(opts ...indexer.Option) Option {
}

// WithChatModelOption is a functional option type for chat model component.
// eg.
// e.g.
//
// chatModelOption := compose.WithChatModelOption(model.WithTemperature(0.7))
// runnable.Invoke(ctx, "input", chatModelOption)
Expand All @@ -118,12 +144,12 @@ func WithToolsNodeOption(opts ...ToolsNodeOption) Option {
func WithLambdaOption(opts ...any) Option {
return Option{
options: opts,
keys: make([]string, 0),
paths: make([][]string, 0),
}
}

// WithCallbacks set callback handlers for all components in a single call.
// eg.
// e.g.
//
// runnable.Invoke(ctx, "input", compose.WithCallbacks(&myCallbacks{}))
func WithCallbacks(cbs ...callbacks.Handler) Option {
Expand Down Expand Up @@ -154,7 +180,7 @@ func withComponentOption[TOption any](opts ...TOption) Option {
}
return Option{
options: o,
keys: make([]string, 0),
paths: make([][]string, 0),
}
}

Expand Down
194 changes: 194 additions & 0 deletions compose/graph_call_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,197 @@ func TestCallOptionsOneByOne(t *testing.T) {
assert.Equal(t, int64(123), opt.uid)
})
}

func TestCallOptionInSubGraph(t *testing.T) {
ctx := context.Background()

type child1Option string
type child2Option string
type parentOption string
type grandparentOption string

child1 := NewGraph[string, string]()
err := child1.AddLambdaNode("1", InvokableLambdaWithOption(func(ctx context.Context, input string, opts ...child1Option) (output string, err error) {
if len(opts) != 1 || opts[0] != "child1-1" {
t.Fatal("child1-1 option error")
}
return input + " child1-1", nil
}), WithNodeName("child1-1"))
assert.NoError(t, err)
err = child1.AddEdge(START, "1")
assert.NoError(t, err)
err = child1.AddEdge("1", END)
assert.NoError(t, err)

child2 := NewGraph[string, string]()
err = child2.AddLambdaNode("1", InvokableLambdaWithOption(func(ctx context.Context, input string, opts ...child2Option) (output string, err error) {
if len(opts) != 1 || opts[0] != "child2-1" {
t.Fatal("child2-1 option error")
}
return input + " child2-1", nil
}), WithNodeName("child2-1"))
assert.NoError(t, err)
err = child2.AddEdge(START, "1")
assert.NoError(t, err)
err = child2.AddEdge("1", END)
assert.NoError(t, err)

parent := NewGraph[string, string]()
err = parent.AddLambdaNode("1", InvokableLambdaWithOption(func(ctx context.Context, input string, opts ...parentOption) (output string, err error) {
if len(opts) != 1 || opts[0] != "parent-1" {
t.Fatal("parent-1 option error")
}
return input + " parent-1", nil
}), WithNodeName("parent-1"))
assert.NoError(t, err)
err = parent.AddGraphNode("2", child1, WithNodeName("child1"))
assert.NoError(t, err)
err = parent.AddGraphNode("3", child2, WithNodeName("child2"))
assert.NoError(t, err)
err = parent.AddEdge(START, "1")
assert.NoError(t, err)
err = parent.AddEdge("1", "2")
assert.NoError(t, err)
err = parent.AddEdge("2", "3")
assert.NoError(t, err)
err = parent.AddEdge("3", END)
assert.NoError(t, err)

grandParent := NewGraph[string, string]()
err = grandParent.AddLambdaNode("1", InvokableLambdaWithOption(func(ctx context.Context, input string, opts ...grandparentOption) (output string, err error) {
if len(opts) != 1 || opts[0] != "grandparent-1" {
t.Fatal("grandparent-1 option error")
}
return input + " grandparent-1", nil
}), WithNodeName("grandparent-1"))
assert.NoError(t, err)
err = grandParent.AddGraphNode("2", parent, WithNodeName("parent"))
assert.NoError(t, err)
err = grandParent.AddEdge(START, "1")
assert.NoError(t, err)
err = grandParent.AddEdge("1", "2")
assert.NoError(t, err)
err = grandParent.AddEdge("2", END)
assert.NoError(t, err)

r, err := grandParent.Compile(ctx, WithGraphName("grandparent"))
assert.NoError(t, err)

grandCommonTimes := 0
grandCommonCB := callbacks.NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
switch grandCommonTimes {
case 0:
if info.Name != "grandparent" || info.Component != ComponentOfGraph {
t.Fatal("grandparent common callback 0 error")
}
case 1:
if info.Name != "grandparent-1" {
t.Fatal("grandparent common callback 1 error")
}
case 2:
if info.Name != "parent" {
t.Fatal("grandparent common callback 2 error")
}
case 3:
if info.Name != "parent-1" {
t.Fatal("grandparent common callback 3 error")
}
case 4:
if info.Name != "child1" {
t.Fatal("grandparent common callback 4 error")
}
case 5:
if info.Name != "child1-1" {
t.Fatal("grandparent common callback 5 error")
}
case 6:
if info.Name != "child2" {
t.Fatal("grandparent common callback 6 error")
}
case 7:
if info.Name != "child2-1" {
t.Fatal("grandparent common callback 7 error")
}
default:
t.Fatal("grandparent common callback too many")
}
grandCommonTimes++
return ctx
}).Build()
grand1CB := callbacks.NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
if info.Name != "grandparent-1" {
t.Fatal("grandparent common callback 0 error")
}
return ctx
}).Build()
parentCommonCBTimes := 0
parentCommonCB := callbacks.NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
switch parentCommonCBTimes {
case 0:
if info.Name != "parent" {
t.Fatal("parent common callback 0 error")
}
case 1:
if info.Name != "parent-1" {
t.Fatal("parent common callback 1 error")
}
case 2:
if info.Name != "child1" {
t.Fatal("parent common callback 2 error")
}
case 3:
if info.Name != "child1-1" {
t.Fatal("parent common callback 3 error")
}
case 4:
if info.Name != "child2" {
t.Fatal("parent common callback 4 error")
}
case 5:
if info.Name != "child2-1" {
t.Fatal("parent common callback 5 error")
}
default:
t.Fatal("parent common callback too many")
}
parentCommonCBTimes++
return ctx
}).Build()
child1CommonCBTimes := 0
child1CommonCB := callbacks.NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
switch child1CommonCBTimes {
case 0:
if info.Name != "child1" {
t.Fatal("child1 common callback 0 error")
}
case 1:
if info.Name != "child1-1" {
t.Fatal("child1 common callback 1 error")
}
default:
t.Fatal("child1 common callback too many")
}
child1CommonCBTimes++
return ctx
}).Build()
child2CB := callbacks.NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
if info.Name != "child2-1" {
t.Fatal("child2-1 common callback 0 error")
}
return ctx
}).Build()

result, err := r.Invoke(ctx, "input",
WithCallbacks(grandCommonCB),
WithCallbacks(parentCommonCB).DesignateNodeWithPath([]string{"2"}),
WithCallbacks(grand1CB).DesignateNode("1"),
WithCallbacks(child1CommonCB).DesignateNodeWithPath([]string{"2", "2"}),
WithCallbacks(child2CB).DesignateNodeWithPath([]string{"2", "3", "1"}),
WithLambdaOption(grandparentOption("grandparent-1")).DesignateNodeWithPath([]string{"1"}),
WithLambdaOption(parentOption("parent-1")).DesignateNodeWithPath([]string{"2", "1"}),
WithLambdaOption(child1Option("child1-1")).DesignateNodeWithPath([]string{"2", "2", "1"}),
WithLambdaOption(child2Option("child2-1")).DesignateNodeWithPath([]string{"2", "3", "1"}),
)
assert.NoError(t, err)
assert.Equal(t, result, "input grandparent-1 parent-1 child1-1 child2-1")
}
31 changes: 0 additions & 31 deletions compose/graph_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"runtime/debug"
"sync"

"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino/utils/safe"
)
Expand Down Expand Up @@ -496,33 +495,3 @@ func (r *runner) parserOrValidateTypeIfNeeded(cur, next string, isStream bool, v
}
return value, nil
}

func initNodeCallbacks(ctx context.Context, key string, info *nodeInfo, meta *executorMeta, opts ...Option) context.Context {
ri := &callbacks.RunInfo{}
if meta != nil {
ri.Component = meta.component
ri.Type = meta.componentImplType
}

if info != nil {
ri.Name = info.name
}

var cbs []callbacks.Handler
for i := range opts {
if len(opts[i].handler) != 0 {
if len(opts[i].keys) == 0 {
cbs = append(cbs, opts[i].handler...)
} else {
for _, k := range opts[i].keys {
if k == key {
cbs = append(cbs, opts[i].handler...)
break
}
}
}
}
}

return callbacks.InitCallbacks(ctx, ri, cbs...)
}
Loading

0 comments on commit 486a6fa

Please sign in to comment.