Skip to content

Commit

Permalink
refactor: use GetState with ReAct agent
Browse files Browse the repository at this point in the history
Change-Id: Idcea4add7a1d60054089078002ba87856757b0d5
  • Loading branch information
shentongmartin committed Dec 19, 2024
1 parent eccf2c4 commit baa9290
Showing 1 changed file with 82 additions and 133 deletions.
215 changes: 82 additions & 133 deletions flow/agent/react/react.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ package react

import (
"context"
"errors"
"fmt"
"reflect"
"io"

"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/compose"
Expand All @@ -32,10 +31,15 @@ type nodeState struct {
Messages []*schema.Message
}

const (
nodeKeyTools = "tools"
nodeKeyChatModel = "chat"
)

// MessageModifier modify the input messages before the model is called.
type MessageModifier func(ctx context.Context, input []*schema.Message) []*schema.Message

// AgentConfig is the config for react agent.
// AgentConfig is the config for ReAct agent.
type AgentConfig struct {
// Model is the chat model to be used for handling user messages.
Model model.ChatModel
Expand Down Expand Up @@ -77,28 +81,34 @@ func NewPersonaModifier(persona string) MessageModifier {
}
}

// NewAgent creates a react agent.
// NewAgent creates a ReAct agent that feeds tool response into next round of Chat Model generation.
func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
agent := &Agent{}
if config.MessageModifier == nil {
config.MessageModifier = func(ctx context.Context, input []*schema.Message) []*schema.Message {
return input
}
}

a := &Agent{}

runnable, err := agent.build(ctx, config)
runnable, err := a.build(ctx, config)
if err != nil {
return nil, err
}

agent.runnable = runnable
a.runnable = runnable

return agent, nil
return a, nil
}

// Agent is the react agent.
// React agent is a simple agent that handles user messages with a chat model and tools.
// react will call the chat model, if the message contains tool calls, it will call the tools.
// if the tool is configured to return directly, react will return directly.
// otherwise, react will continue to call the chat model until the message contains no tool calls.
// eg.
// Agent is the ReAct agent.
// ReAct agent is a simple agent that handles user messages with a chat model and tools.
// ReAct will call the chat model, if the message contains tool calls, it will call the tools.
// if the tool is configured to return directly, ReAct will return directly.
// otherwise, ReAct will continue to call the chat model until the message contains no tool calls.
// e.g.
//
// agent, err := react.NewAgent(ctx, &react.AgentConfig{})
// agent, err := ReAct.NewAgent(ctx, &ReAct.AgentConfig{})
// if err != nil {...}
// msg, err := agent.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "how to build agent with eino"}})
// if err != nil {...}
Expand All @@ -108,17 +118,6 @@ type Agent struct {
}

func (r *Agent) build(ctx context.Context, config *AgentConfig) (compose.Runnable[[]*schema.Message, *schema.Message], error) {
var (
nodeKeyTools = "tools"
nodeKeyChatModel = "chat"
)

if config.MessageModifier == nil {
config.MessageModifier = func(ctx context.Context, input []*schema.Message) []*schema.Message {
return input
}
}

toolInfos := make([]*schema.ToolInfo, 0, len(config.ToolsConfig.Tools))
for _, t := range config.ToolsConfig.Tools {
tl, err := t.Info(ctx)
Expand All @@ -138,10 +137,9 @@ func (r *Agent) build(ctx context.Context, config *AgentConfig) (compose.Runnabl
graph := compose.NewGraph[[]*schema.Message, *schema.Message](
compose.WithGenLocalState(
func(ctx context.Context) *nodeState {
s := &nodeState{
return &nodeState{
Messages: make([]*schema.Message, 0, 3),
}
return s
}))

err = graph.AddChatModelNode(nodeKeyChatModel, config.Model,
Expand All @@ -167,13 +165,7 @@ func (r *Agent) build(ctx context.Context, config *AgentConfig) (compose.Runnabl
err = graph.AddToolsNode(nodeKeyTools, toolsNode, compose.WithStatePreHandler(func(ctx context.Context, input *schema.Message, state *nodeState) (*schema.Message, error) {
state.Messages = append(state.Messages, input)

if len(config.ToolReturnDirectly) > 0 {
if err := checkReturnDirectlyBeforeToolsNode(input, config); err != nil {
return nil, err
}
}

if err := cacheToolCallInfo(ctx, input.ToolCalls); err != nil {
if err := checkReturnDirectlyBeforeToolsNode(input, config); err != nil {
return nil, err
}

Expand All @@ -183,17 +175,17 @@ func (r *Agent) build(ctx context.Context, config *AgentConfig) (compose.Runnabl
return nil, err
}

err = graph.AddEdge(compose.START, nodeKeyChatModel)
if err != nil {
if err = graph.AddEdge(compose.START, nodeKeyChatModel); err != nil {
return nil, err
}

err = graph.AddBranch(nodeKeyChatModel, compose.NewStreamGraphBranch(func(ctx context.Context, sr *schema.StreamReader[*schema.Message]) (endNode string, err error) {
defer sr.Close()

msg, err := sr.Recv()
if err != nil {
return "", err
}
defer sr.Close()

if len(msg.ToolCalls) == 0 {
return compose.END, nil
Expand All @@ -206,57 +198,7 @@ func (r *Agent) build(ctx context.Context, config *AgentConfig) (compose.Runnabl
}

if len(config.ToolReturnDirectly) > 0 {
returnDirectlyConvertor := func(ctx context.Context, msgs *schema.StreamReader[[]*schema.Message]) (*schema.StreamReader[*schema.Message], error) {
flattened := schema.StreamReaderWithConvert(msgs, func(msgs []*schema.Message) (*schema.Message, error) {
if len(msgs) != 1 {
return nil, fmt.Errorf("return directly tools node output expected to have only one msg, but got %d", len(msgs))
}
return msgs[0], nil
})

return flattened, nil
}

nodeKeyConvertor := "convertor"
err = graph.AddLambdaNode(nodeKeyConvertor, compose.TransformableLambda(returnDirectlyConvertor))
if err != nil {
return nil, err
}

err = graph.AddBranch(nodeKeyTools, compose.NewStreamGraphBranch(func(ctx context.Context, msgsStream *schema.StreamReader[[]*schema.Message]) (endNode string, err error) {
defer msgsStream.Close()

msgs, err := msgsStream.Recv()
if err != nil {
return "", fmt.Errorf("receive first packet from tools node result returns err: %w", err)
}

if len(msgs) == 0 {
return "", errors.New("receive first package from tools node result returns empty msgs")
}

msg := msgs[0]
toolCallID := msg.ToolCallID
if len(toolCallID) == 0 {
return "", errors.New("receive first package from tools node result returns empty tool call id")
}

toolCall, err := getToolCallInfo(ctx, toolCallID)
if err != nil {
return "", fmt.Errorf("get tool call info for tool call id: %s returns err: %w", toolCallID, err)
}

if _, ok := config.ToolReturnDirectly[toolCall.Function.Name]; ok { // return directly will appear in first message
return nodeKeyConvertor, nil
}

return nodeKeyChatModel, nil
}, map[string]bool{nodeKeyChatModel: true, nodeKeyConvertor: true}))
if err != nil {
return nil, err
}

if err = graph.AddEdge(nodeKeyConvertor, compose.END); err != nil {
if err = r.buildReturnDirectly(graph, config); err != nil {
return nil, err
}
} else {
Expand All @@ -270,61 +212,74 @@ func (r *Agent) build(ctx context.Context, config *AgentConfig) (compose.Runnabl
opts = append(opts, compose.WithMaxRunSteps(config.MaxStep))
}

runnable, err := graph.Compile(ctx, opts...)
if err != nil {
return nil, err
}

return runnable, nil
return graph.Compile(ctx, opts...)
}

type toolCallInfoKey struct{}

func cacheToolCallInfo(ctx context.Context, toolCalls []schema.ToolCall) error {
info := ctx.Value(toolCallInfoKey{})
if info == nil {
return errors.New("tool call info not found in context")
func (r *Agent) buildReturnDirectly(graph *compose.Graph[[]*schema.Message, *schema.Message], config *AgentConfig) (err error) {
takeFirst := func(ctx context.Context, msgs *schema.StreamReader[[]*schema.Message]) (*schema.StreamReader[*schema.Message], error) {
return schema.StreamReaderWithConvert(msgs, func(msgs []*schema.Message) (*schema.Message, error) {
if len(msgs) != 1 {
return nil, fmt.Errorf("return directly tools node output expected to have only one msg, but got %d", len(msgs))
}
return msgs[0], nil
}), nil
}

toolCallInfo, ok := info.(*map[string]schema.ToolCall)
if !ok {
return fmt.Errorf("tool call info type error, not atomic.Value: %v", reflect.TypeOf(info))
nodeKeyTakeFirst := "convertor" // convert output of tools node ([]*schema.Message) to a single *schema.Message, so that it could be returned directly
if err = graph.AddLambdaNode(nodeKeyTakeFirst, compose.TransformableLambda(takeFirst)); err != nil {
return err
}

m := make(map[string]schema.ToolCall, len(toolCalls))
for i := range toolCalls {
m[toolCalls[i].ID] = toolCalls[i]
}
// this branch checks if the tool called should return directly. It either leads to END or back to ChatModel
err = graph.AddBranch(nodeKeyTools, compose.NewStreamGraphBranch(func(ctx context.Context, msgsStream *schema.StreamReader[[]*schema.Message]) (endNode string, err error) {
state, err := compose.GetState[*nodeState](ctx) // last msg stored in state should contain the tool call information
if err != nil {
return "", fmt.Errorf("get nodeState in branch failed: %w", err)
}

*toolCallInfo = m
defer msgsStream.Close()

return nil
}
for {
msgs, err := msgsStream.Recv()
if err != nil {
if err == io.EOF {
return nodeKeyChatModel, nil
}
return "", fmt.Errorf("receive first packet from tools node result returns err: %w", err)
}

func getToolCallInfo(ctx context.Context, toolCallID string) (*schema.ToolCall, error) {
info := ctx.Value(toolCallInfoKey{})
if info == nil {
return nil, errors.New("tool call info not found in context")
}
if len(msgs) == 0 {
continue
}

toolCallInfo, ok := info.(*map[string]schema.ToolCall)
if !ok {
return nil, fmt.Errorf("tool call info type error, not map[string]schema.ToolCall: %v", reflect.TypeOf(info))
}
toolCallID := msgs[0].ToolCallID
if len(toolCallID) == 0 {
continue
}

if toolCallInfo == nil {
return nil, errors.New("tool call info is nil")
}
for _, toolCall := range state.Messages[len(state.Messages)-1].ToolCalls {
if toolCall.ID == toolCallID {
if _, ok := config.ToolReturnDirectly[toolCall.Function.Name]; ok {
return nodeKeyTakeFirst, nil
}
}
}

toolCall, ok := (*toolCallInfo)[toolCallID]
if !ok {
return nil, fmt.Errorf("tool call info not found for tool call id: %s", toolCallID)
return nodeKeyChatModel, nil
}
}, map[string]bool{nodeKeyChatModel: true, nodeKeyTakeFirst: true}))
if err != nil {
return err
}

return &toolCall, nil
return graph.AddEdge(nodeKeyTakeFirst, compose.END)
}

func checkReturnDirectlyBeforeToolsNode(input *schema.Message, config *AgentConfig) error {
if len(config.ToolReturnDirectly) == 0 {
return nil
}

if len(input.ToolCalls) > 1 { // check if a return directly tool call belongs to a batch of parallel tool calls, which is not supported for now
var returnDirectly bool
toolCalls := input.ToolCalls
Expand All @@ -347,9 +302,6 @@ func checkReturnDirectlyBeforeToolsNode(input *schema.Message, config *AgentConf

// Generate generates a response from the agent.
func (r *Agent) Generate(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (output *schema.Message, err error) {
m := make(map[string]schema.ToolCall, 0)
ctx = context.WithValue(ctx, toolCallInfoKey{}, &m)

output, err = r.runnable.Invoke(ctx, input, agent.GetComposeOptions(opts...)...)
if err != nil {
return nil, err
Expand All @@ -361,9 +313,6 @@ func (r *Agent) Generate(ctx context.Context, input []*schema.Message, opts ...a
// Stream calls the agent and returns a stream response.
func (r *Agent) Stream(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (
output *schema.StreamReader[*schema.Message], err error) {
m := make(map[string]schema.ToolCall, 0)
ctx = context.WithValue(ctx, toolCallInfoKey{}, &m)

res, err := r.runnable.Stream(ctx, input, agent.GetComposeOptions(opts...)...)
if err != nil {
return nil, err
Expand Down

0 comments on commit baa9290

Please sign in to comment.