diff --git a/compose/graph.go b/compose/graph.go index 728ed7d..5e5fca5 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -213,7 +213,7 @@ func (g *graph) component() component { } func isChain(cmp component) bool { - return cmp == ComponentOfChain || cmp == ComponentOfStateChain + return cmp == ComponentOfChain } // ErrGraphCompiled is returned when attempting to modify a graph after it has been compiled diff --git a/compose/graph_test.go b/compose/graph_test.go index baa1eb3..19df65d 100644 --- a/compose/graph_test.go +++ b/compose/graph_test.go @@ -1038,7 +1038,7 @@ func TestGraphCompileCallback(t *testing.T) { t.Run("graph compile callback", func(t *testing.T) { type s struct{} - g := NewStateGraph[map[string]any, map[string]any, *s](func(ctx context.Context) *s { return &s{} }) + g := NewGraph[map[string]any, map[string]any](WithGenLocalState(func(ctx context.Context) *s { return &s{} })) lambda := InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "node1", nil diff --git a/compose/state.go b/compose/state.go index e317fee..5f376d5 100644 --- a/compose/state.go +++ b/compose/state.go @@ -25,72 +25,6 @@ import ( "github.com/cloudwego/eino/utils/generic" ) -// NewStateGraph creates a new state graph. It requires a func of GenLocalState to generate the state. -// -// Deprecated: NewStateGraph is deprecated and will be removed in a future version. -// Use NewGraph with WithGenLocalState option instead: -// -// // Instead of: -// graph := NewStateGraph[Input, Output, State](genStateFunc) -// -// // Use: -// graph := NewGraph[Input, Output](WithGenLocalState(genStateFunc)) -func NewStateGraph[I, O, S any](gen GenLocalState[S]) *StateGraph[I, O, S] { - sg := &StateGraph[I, O, S]{NewGraph[I, O](WithGenLocalState(gen))} - - sg.cmp = ComponentOfStateGraph - - return sg -} - -// StateGraph is a graph that shares state between nodes. It's useful when you want to share some data across nodes. -// -// Deprecated: StateGraph is deprecated and will be removed in a future version. -// Use Graph with WithGenLocalState option instead: -// -// // Instead of: -// graph := NewStateGraph[Input, Output, State](genStateFunc) -// -// // Use: -// graph := NewGraph[Input, Output](WithGenLocalState(genStateFunc)) -type StateGraph[I, O, S any] struct { - *Graph[I, O] -} - -// NewStateChain creates a new state chain. It requires a func of GenLocalState to generate the state. -// -// Deprecated: NewStateChain is deprecated and will be removed in a future version. -// Use NewChain with WithGenLocalState option instead: -// -// // Instead of: -// chain := NewStateChain[Input, Output, State](genStateFunc) -// -// // Use: -// chain := NewChain[Input, Output](WithGenLocalState(genStateFunc)) -func NewStateChain[I, O, S any](gen GenLocalState[S]) *StateChain[I, O, S] { - sc := &StateChain[I, O, S]{NewChain[I, O](WithGenLocalState(gen))} - - sc.gg.cmp = ComponentOfStateChain - - return sc -} - -// StateChain is a chain that shares state between nodes. State is shared between nodes in the chain. -// It's useful when you want to share some data across nodes in a chain. -// you can use WithPreHandler and WithPostHandler to do something with state of this chain. -// -// Deprecated: StateChain is deprecated and will be removed in a future version. -// Use Chain with WithGenLocalState option instead: -// -// // Instead of: -// chain := NewStateChain[Input, Output, State](genStateFunc) -// -// // Use: -// chain := NewChain[Input, Output](WithGenLocalState(genStateFunc)) -type StateChain[I, O, S any] struct { - *Chain[I, O] -} - // GenLocalState is a function that generates the state. type GenLocalState[S any] func(ctx context.Context) (state S) diff --git a/compose/state_test.go b/compose/state_test.go index 2f0fae8..203e2d9 100644 --- a/compose/state_test.go +++ b/compose/state_test.go @@ -48,7 +48,7 @@ func TestStateGraphWithEdge(t *testing.T) { return &testState{} } - sg := NewStateGraph[string, string, *testState](gen) + sg := NewGraph[string, string](WithGenLocalState(gen)) l1 := InvokableLambda(func(ctx context.Context, in string) (out midStr, err error) { return midStr("InvokableLambda: " + in), nil @@ -221,9 +221,9 @@ func TestStateChain(t *testing.T) { Field1 string Field2 string } - sc := NewStateChain[string, string, *testState](func(ctx context.Context) (state *testState) { + sc := NewChain[string, string](WithGenLocalState(func(ctx context.Context) (state *testState) { return &testState{} - }) + })) r, err := sc.AppendLambda(InvokableLambda(func(ctx context.Context, input string) (output string, err error) { s, err := GetState[*testState](ctx) @@ -259,7 +259,7 @@ func TestStreamState(t *testing.T) { } ctx := context.Background() s := &testState{Field1: "1"} - g := NewStateGraph[string, string, *testState](func(ctx context.Context) (state *testState) { return s }) + g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) (state *testState) { return s })) err := g.AddLambdaNode("1", TransformableLambda(func(ctx context.Context, input *schema.StreamReader[string]) (output *schema.StreamReader[string], err error) { return input, nil }), WithStreamStatePreHandler(func(ctx context.Context, in *schema.StreamReader[string], state *testState) (*schema.StreamReader[string], error) { diff --git a/compose/types.go b/compose/types.go index 1697966..1d1297f 100644 --- a/compose/types.go +++ b/compose/types.go @@ -27,9 +27,7 @@ type component = components.Component const ( ComponentOfUnknown component = "Unknown" ComponentOfGraph component = "Graph" - ComponentOfStateGraph component = "StateGraph" ComponentOfChain component = "Chain" - ComponentOfStateChain component = "StateChain" ComponentOfPassthrough component = "Passthrough" ComponentOfToolsNode component = "ToolsNode" ComponentOfLambda component = "Lambda" diff --git a/utils/callbacks/template.go b/utils/callbacks/template.go index 6b42459..3170150 100644 --- a/utils/callbacks/template.go +++ b/utils/callbacks/template.go @@ -296,7 +296,6 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, ctx = c.toolHandler.OnError(ctx, info, err) } case compose.ComponentOfGraph, - compose.ComponentOfStateGraph, compose.ComponentOfChain, compose.ComponentOfPassthrough, compose.ComponentOfToolsNode,