Skip to content

Commit

Permalink
Add parallel steps
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r committed Aug 6, 2024
1 parent 20a5229 commit 5d29934
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 44 deletions.
65 changes: 31 additions & 34 deletions examples/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import (
"context"
"fmt"
"net/http"
"time"

"github.com/inngest/inngestgo"
"github.com/inngest/inngestgo/experimental/group"
"github.com/inngest/inngestgo/step"
)

Expand All @@ -32,45 +32,42 @@ func main() {
// Function state is automatically managed, and persists across server restarts,
// cloud migrations, and language changes.
func AccountCreated(ctx context.Context, input inngestgo.Input[AccountCreatedEvent]) (any, error) {
// Sleep for a second, minute, hour, week across server restarts.
step.Sleep(ctx, "initial-delay", time.Second)

// Run a step which emails the user. This automatically retries on error.
// This returns the fully typed result of the lambda.
result, err := step.Run(ctx, "on-user-created", func(ctx context.Context) (bool, error) {
// Run any code inside a step.
return false, nil
})
// `result` is fully typed from the lambda
_ = result

// Sample from the event stream for new events. The function will stop
// running and automatially resume when a matching event is found, or if
// the timeout is reached.
fn, err := step.WaitForEvent[FunctionCreatedEvent](
results := group.Parallel(
ctx,
"wait-for-activity",
step.WaitForEventOpts{
Name: "Wait for a function to be created",
Event: "api/function.created",
Timeout: time.Hour * 72,
// Match events where the user_id is the same in the async sampled event.
If: inngestgo.StrPtr("event.data.user_id == async.data.user_id"),
func(ctx context.Context) (any, error) {
return step.Run(ctx, "a", func(ctx context.Context) (string, error) {
fmt.Println("Running step a")
return "A", nil
})
},
func(ctx context.Context) (any, error) {
return step.Run(ctx, "b", func(ctx context.Context) (string, error) {
fmt.Println("Running step b")
return "B", nil
})
},
)
if err == step.ErrEventNotReceived {
// A function wasn't created within 3 days. Send a follow-up email.
_, _ = step.Run(ctx, "follow-up-email", func(ctx context.Context) (any, error) {
// ...
return true, nil
})
return nil, nil

parallelOutputs := make([]string, len(results))
for i, r := range results {
if r.Error != nil {
return nil, r.Error
}
parallelOutputs[i] = r.Value.(string)
}

// The event returned from `step.WaitForEvent` is fully typed.
fmt.Println(fn.Data.FunctionID)
singleOutput, err := step.Run(ctx, "c", func(ctx context.Context) (string, error) {
fmt.Println("Running step c")
return "C", nil
})
if err != nil {
return nil, err
}

return nil, nil
return map[string]any{
"parallelOutputs": parallelOutputs,
"singleOutput": singleOutput,
}, nil
}

// AccountCreatedEvent represents the fully defined event received when an account is created.
Expand Down
50 changes: 50 additions & 0 deletions experimental/group/group.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package group

import (
"context"
"sync"

"github.com/inngest/inngestgo/step"
)

type result struct {
Error error
Value any
}

func Parallel(
ctx context.Context,
fns ...func(ctx context.Context,
) (any, error)) []result {
ctx = context.WithValue(ctx, step.ParallelKey, true)

results := []result{}
isPlanned := false
wg := sync.WaitGroup{}
wg.Add(len(fns))
for _, fn := range fns {
fn := fn
go func(fn func(ctx context.Context) (any, error)) {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
if _, ok := r.(step.ControlHijack); ok {
isPlanned = true
} else {

Check failure on line 33 in experimental/group/group.go

View workflow job for this annotation

GitHub Actions / lint (ubuntu-latest)

SA9003: empty branch (staticcheck)
// TODO: What to do here?
}
}
}()

value, err := fn(ctx)
results = append(results, result{Error: err, Value: value})
}(fn)
}
wg.Wait()

if isPlanned {
panic(step.ControlHijack{})
}

return results
}
18 changes: 16 additions & 2 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,13 @@ func (h *handler) invoke(w http.ResponseWriter, r *http.Request) error {
}()
}

var stepID *string
if rawStepID := r.URL.Query().Get("stepId"); rawStepID != "" && rawStepID != "step" {
stepID = &rawStepID
}

// Invoke the function, then immediately stop the streaming buffer.
resp, ops, err := invoke(r.Context(), fn, request)
resp, ops, err := invoke(r.Context(), fn, request, stepID)
streamCancel()

// NOTE: When triggering step errors, we should have an OpcodeStepError
Expand Down Expand Up @@ -779,7 +784,12 @@ type StreamResponse struct {

// invoke calls a given servable function with the specified input event. The input event must
// be fully typed.
func invoke(ctx context.Context, sf ServableFunction, input *sdkrequest.Request) (any, []state.GeneratorOpcode, error) {
func invoke(
ctx context.Context,
sf ServableFunction,
input *sdkrequest.Request,
stepID *string,
) (any, []state.GeneratorOpcode, error) {
if sf.Func() == nil {
// This should never happen, but as sf.Func returns a nillable type we
// must check that the function exists.
Expand All @@ -790,6 +800,10 @@ func invoke(ctx context.Context, sf ServableFunction, input *sdkrequest.Request)
// within a step. This allows us to prevent any execution of future tools after a
// tool has run.
fCtx, cancel := context.WithCancel(context.Background())
if stepID != nil {
fCtx = step.SetTargetStepID(fCtx, *stepID)
}

// This must be a pointer so that it can be mutated from within function tools.
mgr := sdkrequest.NewManager(cancel, input)
fCtx = sdkrequest.SetManager(fCtx, mgr)
Expand Down
13 changes: 7 additions & 6 deletions internal/sdkrequest/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ type Request struct {
// CallCtx represents context for individual function calls. This logs the function ID, the
// specific run ID, and sep information.
type CallCtx struct {
Env string `json:"env"`
FunctionID string `json:"fn_id"`
RunID string `json:"run_id"`
StepID string `json:"step_id"`
Stack CallStack `json:"stack"`
Attempt int `json:"attempt"`
DisableImmediateExecution bool `json:"disable_immediate_execution"`
Env string `json:"env"`
FunctionID string `json:"fn_id"`
RunID string `json:"run_id"`
StepID string `json:"step_id"`
Stack CallStack `json:"stack"`
Attempt int `json:"attempt"`
}

type CallStack struct {
Expand Down
22 changes: 20 additions & 2 deletions step/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ func Run[T any](
id string,
f func(ctx context.Context) (T, error),
) (T, error) {
targetID := getTargetStepID(ctx)

mgr := preflight(ctx)
op := mgr.NewOp(enums.OpcodeStep, id, nil)
hashedID := op.MustHash()

if val, ok := mgr.Step(op); ok {
// Create a new empty type T in v
Expand Down Expand Up @@ -78,6 +81,21 @@ func Run[T any](
return val, nil
}

if targetID != nil && *targetID != hashedID {
panic(ControlHijack{})
}

planParallel := targetID == nil && isParallel(ctx)
planBeforeRun := targetID == nil && mgr.Request().CallCtx.DisableImmediateExecution
if planParallel || planBeforeRun {
mgr.AppendOp(state.GeneratorOpcode{
ID: hashedID,
Op: enums.OpcodeStepPlanned,
Name: id,
})
panic(ControlHijack{})
}

// We're calling a function, so always cancel the context afterwards so that no
// other tools run.
defer mgr.Cancel()
Expand All @@ -94,7 +112,7 @@ func Run[T any](

// Implement per-step errors.
mgr.AppendOp(state.GeneratorOpcode{
ID: op.MustHash(),
ID: hashedID,
Op: enums.OpcodeStepError,
Name: id,
Error: &state.UserError{
Expand All @@ -112,7 +130,7 @@ func Run[T any](
mgr.SetErr(fmt.Errorf("unable to marshal run respone for '%s': %w", id, err))
}
mgr.AppendOp(state.GeneratorOpcode{
ID: op.MustHash(),
ID: hashedID,
Op: enums.OpcodeStepRun,
Name: id,
Data: byt,
Expand Down
33 changes: 33 additions & 0 deletions step/step.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ import (

type ControlHijack struct{}

type ctxKey string

const (
targetStepIDKey = ctxKey("stepID")
ParallelKey = ctxKey("parallelKey")
)

var (
// ErrNotInFunction is called when a step tool is executed outside of an Inngest
// function call context.
Expand All @@ -23,6 +30,32 @@ func (errNotInFunction) Error() string {
return "step called without function context"
}

func getTargetStepID(ctx context.Context) *string {
if v := ctx.Value(targetStepIDKey); v != nil {
if c, ok := v.(string); ok {
return &c
}
}
return nil
}

func SetTargetStepID(ctx context.Context, id string) context.Context {
if id == "" || id == "step" {
return ctx
}

return context.WithValue(ctx, targetStepIDKey, id)
}

func isParallel(ctx context.Context) bool {
if v := ctx.Value(ParallelKey); v != nil {
if c, ok := v.(bool); ok {
return c
}
}
return false
}

func preflight(ctx context.Context) sdkrequest.InvocationManager {
if ctx.Err() != nil {
// Another tool has already ran and the context is closed. Return
Expand Down

0 comments on commit 5d29934

Please sign in to comment.