Skip to content

Commit

Permalink
refactor: callbacks
Browse files Browse the repository at this point in the history
-- move some code into internal/callbacks
-- rm deprecated code
-- refactor duplicated code

Change-Id: I443b3074426bf406c4ad68ad027f1f9e143b288f
  • Loading branch information
luohq-bytedance committed Dec 19, 2024
1 parent b44844b commit a01006a
Show file tree
Hide file tree
Showing 13 changed files with 415 additions and 723 deletions.
123 changes: 11 additions & 112 deletions callbacks/aspect_inject.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package callbacks
import (
"context"

"github.com/cloudwego/eino/internal/callbacks"
"github.com/cloudwego/eino/schema"
)

Expand Down Expand Up @@ -48,41 +49,20 @@ import (
// return resp, nil
// }
//

// OnStart invokes the OnStart logic for the particular context, ensuring that all registered
// handlers are executed in reverse order (compared to add order) when a process begins.
func OnStart(ctx context.Context, input CallbackInput) context.Context {
mgr, ok := managerFromCtx(ctx)
if !ok {
return ctx
}

for i := len(mgr.handlers) - 1; i >= 0; i-- {
handler := mgr.handlers[i]
timingChecker, ok := handler.(TimingChecker)
if !ok || timingChecker.Needed(ctx, mgr.runInfo, TimingOnStart) {
ctx = handler.OnStart(ctx, mgr.runInfo, input)
}
}
func OnStart[T any](ctx context.Context, input T) context.Context {
ctx, _ = callbacks.On(ctx, input, callbacks.OnStartHandle[T], TimingOnStart)

return ctx
}

// OnEnd invokes the OnEnd logic of the particular context, allowing for proper cleanup
// and finalization when a process ends.
// handlers are executed in normal order (compared to add order).
func OnEnd(ctx context.Context, output CallbackOutput) context.Context {
mgr, ok := managerFromCtx(ctx)
if !ok {
return ctx
}

for i := 0; i < len(mgr.handlers); i++ {
handler := mgr.handlers[i]
timingChecker, ok := handler.(TimingChecker)
if !ok || timingChecker.Needed(ctx, mgr.runInfo, TimingOnEnd) {
ctx = handler.OnEnd(ctx, mgr.runInfo, output)
}
}
func OnEnd[T any](ctx context.Context, output T) context.Context {
ctx, _ = callbacks.On(ctx, output, callbacks.OnEndHandle[T], TimingOnEnd)

return ctx
}
Expand All @@ -93,37 +73,7 @@ func OnEnd(ctx context.Context, output CallbackOutput) context.Context {
func OnStartWithStreamInput[T any](ctx context.Context, input *schema.StreamReader[T]) (
nextCtx context.Context, newStreamReader *schema.StreamReader[T]) {

mgr, ok := managerFromCtx(ctx)
if !ok {
return ctx, input
}

if len(mgr.handlers) == 0 {
return ctx, input
}

var neededHandlers []Handler
for i := range mgr.handlers {
h := mgr.handlers[i]
timingChecker, ok := h.(TimingChecker)
if !ok || timingChecker.Needed(ctx, mgr.runInfo, TimingOnStartWithStreamInput) {
neededHandlers = append(neededHandlers, h)
}
}

if len(neededHandlers) == 0 {
return ctx, input
}

cp := input.Copy(len(neededHandlers) + 1)
for i := len(neededHandlers) - 1; i >= 0; i-- {
h := neededHandlers[i]
ctx = h.OnStartWithStreamInput(ctx, mgr.runInfo, schema.StreamReaderWithConvert(cp[i], func(src T) (CallbackInput, error) {
return src, nil
}))
}

return ctx, cp[len(cp)-1]
return callbacks.On(ctx, input, callbacks.OnStartWithStreamInputHandle[T], TimingOnStartWithStreamInput)
}

// OnEndWithStreamOutput invokes the OnEndWithStreamOutput logic of the particular, ensuring that
Expand All @@ -132,75 +82,24 @@ func OnStartWithStreamInput[T any](ctx context.Context, input *schema.StreamRead
func OnEndWithStreamOutput[T any](ctx context.Context, output *schema.StreamReader[T]) (
nextCtx context.Context, newStreamReader *schema.StreamReader[T]) {

mgr, ok := managerFromCtx(ctx)
if !ok {
return ctx, output
}

if len(mgr.handlers) == 0 {
return ctx, output
}

var neededHandlers []Handler
for i := range mgr.handlers {
h := mgr.handlers[i]
timingChecker, ok := h.(TimingChecker)
if !ok || timingChecker.Needed(ctx, mgr.runInfo, TimingOnEndWithStreamOutput) {
neededHandlers = append(neededHandlers, h)
}
}

if len(neededHandlers) == 0 {
return ctx, output
}

cp := output.Copy(len(neededHandlers) + 1)
for i := 0; i < len(neededHandlers); i++ {
h := neededHandlers[i]
ctx = h.OnEndWithStreamOutput(ctx, mgr.runInfo, schema.StreamReaderWithConvert(cp[i], func(src T) (CallbackOutput, error) {
return src, nil
}))
}

return ctx, cp[len(cp)-1]
return callbacks.On(ctx, output, callbacks.OnEndWithStreamOutputHandle[T], TimingOnEndWithStreamOutput)
}

// OnError invokes the OnError logic of the particular, notice that error in stream will not represent here.
// handlers are executed in normal order (compared to add order).
func OnError(ctx context.Context, err error) context.Context {
mgr, ok := managerFromCtx(ctx)
if !ok {
return ctx
}

for i := 0; i < len(mgr.handlers); i++ {
handler := mgr.handlers[i]
timingChecker, ok := handler.(TimingChecker)
if !ok || timingChecker.Needed(ctx, mgr.runInfo, TimingOnError) {
ctx = handler.OnError(ctx, mgr.runInfo, err)
}
}
ctx, _ = callbacks.On(ctx, err, callbacks.OnErrorHandle, TimingOnError)

return ctx
}

// SetRunInfo sets the RunInfo to be passed to Handler.
func SetRunInfo(ctx context.Context, info *RunInfo) context.Context {
cbm, ok := managerFromCtx(ctx)
if !ok {
return ctx
}

return ctxWithManager(ctx, cbm.withRunInfo(info))
return callbacks.SetRunInfo(ctx, info)
}

// InitCallbacks initializes a new context with the provided RunInfo and handlers.
// Any previously set RunInfo and Handlers for this ctx will be overwritten.
func InitCallbacks(ctx context.Context, info *RunInfo, handlers ...Handler) context.Context {
mgr, ok := newManager(info, handlers...)
if ok {
return ctxWithManager(ctx, mgr)
}

return ctxWithManager(ctx, nil)
return callbacks.InitCallbacks(ctx, info, handlers...)
}
4 changes: 1 addition & 3 deletions callbacks/aspect_inject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,8 @@ func TestAspectInject(t *testing.T) {
return ctx
}).Build()

manager, ok := newManager(nil, hb)
assert.True(t, ok)
ctx = InitCallbacks(ctx, nil, hb)

ctx = ctxWithManager(ctx, manager)
ctx = OnStart(ctx, 1)
ctx = OnEnd(ctx, 2)
ctx = OnError(ctx, fmt.Errorf("3"))
Expand Down
145 changes: 39 additions & 106 deletions callbacks/handler_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,124 +22,43 @@ import (
"github.com/cloudwego/eino/schema"
)

// HandlerBuilder can be used to build a Handler with callback functions.
// e.g.
//
// handler := &HandlerBuilder{
// OnStartFn: func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context {} // self defined start callback function
// }
//
// graph := compose.NewGraph[inputType, outputType]()
// runnable, err := graph.Compile()
// if err != nil {...}
// runnable.Invoke(ctx, params, compose.WithCallback(handler)) // => only implement functions which you want to override
//
// Deprecated: In most situations, it is preferred to use callbacks.NewHandlerHelper. Otherwise, use NewHandlerBuilder().OnStartFn()...Build().
type HandlerBuilder struct {
OnStartFn func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context
OnEndFn func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context
OnErrorFn func(ctx context.Context, info *RunInfo, err error) context.Context
OnStartWithStreamInputFn func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context
OnEndWithStreamOutputFn func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context
}

func (h *HandlerBuilder) OnStart(ctx context.Context, info *RunInfo, input CallbackInput) context.Context {
if h.OnStartFn != nil {
return h.OnStartFn(ctx, info, input)
}

return ctx
}

func (h *HandlerBuilder) OnEnd(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context {
if h.OnEndFn != nil {
return h.OnEndFn(ctx, info, output)
}

return ctx
}

func (h *HandlerBuilder) OnError(ctx context.Context, info *RunInfo, err error) context.Context {
if h.OnErrorFn != nil {
return h.OnErrorFn(ctx, info, err)
}

return ctx
}

func (h *HandlerBuilder) OnStartWithStreamInput(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context {
if h.OnStartWithStreamInputFn != nil {
return h.OnStartWithStreamInputFn(ctx, info, input)
}

input.Close()

return ctx
}

func (h *HandlerBuilder) OnEndWithStreamOutput(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context {
if h.OnEndWithStreamOutputFn != nil {
return h.OnEndWithStreamOutputFn(ctx, info, output)
}

output.Close()

return ctx
}

type handlerBuilder struct {
onStartFn func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context
onEndFn func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context
onErrorFn func(ctx context.Context, info *RunInfo, err error) context.Context
onStartWithStreamInputFn func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context
onEndWithStreamOutputFn func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context
}

func (hb *handlerBuilder) OnStart(ctx context.Context, info *RunInfo, input CallbackInput) context.Context {
if hb.onStartFn != nil {
return hb.onStartFn(ctx, info, input)
}

return ctx
type handlerImpl struct {
HandlerBuilder
}

func (hb *handlerBuilder) OnEnd(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context {
if hb.onEndFn != nil {
return hb.onEndFn(ctx, info, output)
}

return ctx
func (hb *handlerImpl) OnStart(ctx context.Context, info *RunInfo, input CallbackInput) context.Context {
return hb.onStartFn(ctx, info, input)
}

func (hb *handlerBuilder) OnError(ctx context.Context, info *RunInfo, err error) context.Context {
if hb.onErrorFn != nil {
return hb.onErrorFn(ctx, info, err)
}

return ctx
func (hb *handlerImpl) OnEnd(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context {
return hb.onEndFn(ctx, info, output)
}

func (hb *handlerBuilder) OnStartWithStreamInput(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context {
if hb.onStartWithStreamInputFn != nil {
return hb.onStartWithStreamInputFn(ctx, info, input)
}
func (hb *handlerImpl) OnError(ctx context.Context, info *RunInfo, err error) context.Context {
return hb.onErrorFn(ctx, info, err)
}

input.Close()
func (hb *handlerImpl) OnStartWithStreamInput(ctx context.Context, info *RunInfo,
input *schema.StreamReader[CallbackInput]) context.Context {

return ctx
return hb.onStartWithStreamInputFn(ctx, info, input)
}

func (hb *handlerBuilder) OnEndWithStreamOutput(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context {
if hb.onEndWithStreamOutputFn != nil {
return hb.onEndWithStreamOutputFn(ctx, info, output)
}
func (hb *handlerImpl) OnEndWithStreamOutput(ctx context.Context, info *RunInfo,
output *schema.StreamReader[CallbackOutput]) context.Context {

output.Close()

return ctx
return hb.onEndWithStreamOutputFn(ctx, info, output)
}

func (hb *handlerBuilder) Needed(_ context.Context, _ *RunInfo, timing CallbackTiming) bool {
func (hb *handlerImpl) Needed(_ context.Context, _ *RunInfo, timing CallbackTiming) bool {
switch timing {
case TimingOnStart:
return hb.onStartFn != nil
Expand All @@ -156,36 +75,50 @@ func (hb *handlerBuilder) Needed(_ context.Context, _ *RunInfo, timing CallbackT
}
}

func NewHandlerBuilder() *handlerBuilder {
return &handlerBuilder{}
// NewHandlerBuilder creates and returns a new HandlerBuilder instance.
// HandlerBuilder is used to construct a Handler with custom callback functions
func NewHandlerBuilder() *HandlerBuilder {
return &HandlerBuilder{}
}

func (hb *handlerBuilder) OnStartFn(fn func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context) *handlerBuilder {
func (hb *HandlerBuilder) OnStartFn(
fn func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context) *HandlerBuilder {

hb.onStartFn = fn
return hb
}

func (hb *handlerBuilder) OnEndFn(fn func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context) *handlerBuilder {
func (hb *HandlerBuilder) OnEndFn(
fn func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context) *HandlerBuilder {

hb.onEndFn = fn
return hb
}

func (hb *handlerBuilder) OnErrorFn(fn func(ctx context.Context, info *RunInfo, err error) context.Context) *handlerBuilder {
func (hb *HandlerBuilder) OnErrorFn(
fn func(ctx context.Context, info *RunInfo, err error) context.Context) *HandlerBuilder {

hb.onErrorFn = fn
return hb
}

func (hb *handlerBuilder) OnStartWithStreamInputFn(fn func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context) *handlerBuilder {
// OnStartWithStreamInputFn sets the callback function to be called.
func (hb *HandlerBuilder) OnStartWithStreamInputFn(
fn func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context) *HandlerBuilder {

hb.onStartWithStreamInputFn = fn
return hb
}

func (hb *handlerBuilder) OnEndWithStreamOutputFn(fn func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context) *handlerBuilder {
// OnEndWithStreamOutputFn sets the callback function to be called.
func (hb *HandlerBuilder) OnEndWithStreamOutputFn(
fn func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context) *HandlerBuilder {

hb.onEndWithStreamOutputFn = fn
return hb
}

// Build returns a Handler with the functions set in the builder.
func (hb *handlerBuilder) Build() Handler {
return hb
func (hb *HandlerBuilder) Build() Handler {
return &handlerImpl{*hb}
}
Loading

0 comments on commit a01006a

Please sign in to comment.