From bf9fe3d616b37dd18a3f67d7071e056de10ae628 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Mon, 23 Dec 2024 19:37:14 +0800 Subject: [PATCH] feat: expose graph name in GraphInfo (#17) Change-Id: I945424ec5fdbd0307330aa0bf0b746f71f11a3b8 --- compose/graph.go | 1 + compose/graph_test.go | 8 +++++--- compose/introspect.go | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/compose/graph.go b/compose/graph.go index 728ed7d..c6edb51 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -842,6 +842,7 @@ func (g *graph) toGraphInfo(opt *graphCompileOptions, key2SubGraphs map[string]* }), InputType: g.expectedInputType, OutputType: g.expectedOutputType, + Name: opt.graphName, } for key := range g.nodes { diff --git a/compose/graph_test.go b/compose/graph_test.go index baa1eb3..7cc5587 100644 --- a/compose/graph_test.go +++ b/compose/graph_test.go @@ -1038,7 +1038,7 @@ func TestGraphCompileCallback(t *testing.T) { t.Run("graph compile callback", func(t *testing.T) { type s struct{} - g := NewStateGraph[map[string]any, map[string]any, *s](func(ctx context.Context) *s { return &s{} }) + g := NewGraph[map[string]any, map[string]any](WithGenLocalState(func(ctx context.Context) *s { return &s{} })) lambda := InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "node1", nil @@ -1085,7 +1085,7 @@ func TestGraphCompileCallback(t *testing.T) { err = subGraph.AddEdge("sub_sub_1", END) assert.NoError(t, err) - subGraphCompileOpts := []GraphCompileOption{WithMaxRunSteps(2)} + subGraphCompileOpts := []GraphCompileOption{WithMaxRunSteps(2), WithGraphName("sub_graph")} subGraphOpts := []GraphAddNodeOpt{WithGraphCompileOptions(subGraphCompileOpts...)} err = g.AddGraphNode("sub_graph", subGraph, subGraphOpts...) assert.NoError(t, err) @@ -1119,7 +1119,7 @@ func TestGraphCompileCallback(t *testing.T) { assert.NoError(t, err) c := &cb{} - opt := []GraphCompileOption{WithGraphCompileCallbacks(c)} + opt := []GraphCompileOption{WithGraphCompileCallbacks(c), WithGraphName("top_level")} _, err = g.Compile(context.Background(), opt...) assert.NoError(t, err) expected := &GraphInfo{ @@ -1192,6 +1192,7 @@ func TestGraphCompileCallback(t *testing.T) { Branches: map[string][]GraphBranch{}, InputType: reflect.TypeOf(""), OutputType: reflect.TypeOf(""), + Name: "sub_graph", }, }, "node3": { @@ -1226,6 +1227,7 @@ func TestGraphCompileCallback(t *testing.T) { }, InputType: reflect.TypeOf(map[string]any{}), OutputType: reflect.TypeOf(map[string]any{}), + Name: "top_level", } stateFn := c.gInfo.GenStateFn diff --git a/compose/introspect.go b/compose/introspect.go index 12ae2b5..3c29357 100644 --- a/compose/introspect.go +++ b/compose/introspect.go @@ -43,6 +43,7 @@ type GraphInfo struct { Edges map[string][]string // edge start node key -> edge end node key Branches map[string][]GraphBranch // branch start node key -> branch InputType, OutputType reflect.Type + Name string GenStateFn func(context.Context) any }