Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parallel steps #51

Merged
merged 3 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ jobs:
uses: actions/setup-go@v3
with:
go-version: '1.21'
- name: Test
run: go test -v -race -count=1
- name: Unit test
run: go test -v -race -count=1 -short
itest:
strategy:
matrix:
os: [ubuntu-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.21'

# Need npx to start the Dev Server
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: '18'

- name: Integration test
run: make itest
10 changes: 7 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
.PHONY: test
test:
go test -test.v
.PHONY: itest
itest:
go test ./tests -v -count=1

.PHONY: utest
utest:
go test -test.v -short

.PHONY: lint
lint:
Expand Down
50 changes: 50 additions & 0 deletions experimental/group/group.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package group

import (
"context"
"fmt"

"github.com/inngest/inngestgo/step"
)

type Result struct {
Error error
Value any
}

func Parallel(
ctx context.Context,
fns ...func(ctx context.Context,
) (any, error)) []Result {
ctx = context.WithValue(ctx, step.ParallelKey, true)

results := []Result{}
isPlanned := false
ch := make(chan struct{}, 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used as a waitgroup?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's effectively a waitgroup of 1. We want to run each step sequentially but also in a goroutine so that we can recover the control hijack

for _, fn := range fns {
fn := fn
go func(fn func(ctx context.Context) (any, error)) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(step.ControlHijack); ok {
isPlanned = true
} else {
// TODO: What to do here?
fmt.Println("TODO")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just repanic for now? I suppose this might be where we capture the non-Inngest panic and could return it as an error to Inngest in the future?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we should recover the panic and repanic it outside the goroutine? I think that'll let our normal non-control-flow panic recovery logic work

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooh yeah if we already have something that captures panics then that'd be awesome.

}
}
ch <- struct{}{}
}()

value, err := fn(ctx)
results = append(results, Result{Error: err, Value: value})
}(fn)
<-ch
}

if isPlanned {
panic(step.ControlHijack{})
}

return results
}
18 changes: 16 additions & 2 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,13 @@ func (h *handler) invoke(w http.ResponseWriter, r *http.Request) error {
}()
}

var stepID *string
if rawStepID := r.URL.Query().Get("stepId"); rawStepID != "" && rawStepID != "step" {
stepID = &rawStepID
}

// Invoke the function, then immediately stop the streaming buffer.
resp, ops, err := invoke(r.Context(), fn, request)
resp, ops, err := invoke(r.Context(), fn, request, stepID)
streamCancel()

// NOTE: When triggering step errors, we should have an OpcodeStepError
Expand Down Expand Up @@ -779,7 +784,12 @@ type StreamResponse struct {

// invoke calls a given servable function with the specified input event. The input event must
// be fully typed.
func invoke(ctx context.Context, sf ServableFunction, input *sdkrequest.Request) (any, []state.GeneratorOpcode, error) {
func invoke(
ctx context.Context,
sf ServableFunction,
input *sdkrequest.Request,
stepID *string,
) (any, []state.GeneratorOpcode, error) {
if sf.Func() == nil {
// This should never happen, but as sf.Func returns a nillable type we
// must check that the function exists.
Expand All @@ -790,6 +800,10 @@ func invoke(ctx context.Context, sf ServableFunction, input *sdkrequest.Request)
// within a step. This allows us to prevent any execution of future tools after a
// tool has run.
fCtx, cancel := context.WithCancel(context.Background())
if stepID != nil {
fCtx = step.SetTargetStepID(fCtx, *stepID)
}

// This must be a pointer so that it can be mutated from within function tools.
mgr := sdkrequest.NewManager(cancel, input)
fCtx = sdkrequest.SetManager(fCtx, mgr)
Expand Down
10 changes: 5 additions & 5 deletions handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestInvoke(t *testing.T) {
Register(a)

t.Run("it invokes the function with correct types", func(t *testing.T) {
actual, op, err := invoke(ctx, a, createRequest(t, input))
actual, op, err := invoke(ctx, a, createRequest(t, input), nil)
require.NoError(t, err)
require.Nil(t, op)
require.Equal(t, resp, actual)
Expand Down Expand Up @@ -131,7 +131,7 @@ func TestInvoke(t *testing.T) {
Register(a)

t.Run("it invokes the function with correct types", func(t *testing.T) {
actual, op, err := invoke(ctx, a, createBatchRequest(t, input, 5))
actual, op, err := invoke(ctx, a, createBatchRequest(t, input, 5), nil)
require.NoError(t, err)
require.Nil(t, op)
require.Equal(t, resp, actual)
Expand Down Expand Up @@ -166,7 +166,7 @@ func TestInvoke(t *testing.T) {
ctx := context.Background()

t.Run("it invokes the function with correct types", func(t *testing.T) {
actual, op, err := invoke(ctx, a, createRequest(t, input))
actual, op, err := invoke(ctx, a, createRequest(t, input), nil)
require.NoError(t, err)
require.Nil(t, op)
require.Equal(t, resp, actual)
Expand Down Expand Up @@ -204,7 +204,7 @@ func TestInvoke(t *testing.T) {

ctx := context.Background()
t.Run("it invokes the function with correct types", func(t *testing.T) {
actual, op, err := invoke(ctx, a, createRequest(t, input))
actual, op, err := invoke(ctx, a, createRequest(t, input), nil)
require.NoError(t, err)
require.Nil(t, op)
require.Equal(t, resp, actual)
Expand Down Expand Up @@ -241,7 +241,7 @@ func TestInvoke(t *testing.T) {

ctx := context.Background()
t.Run("it invokes the function with correct types", func(t *testing.T) {
actual, op, err := invoke(ctx, a, createRequest(t, input))
actual, op, err := invoke(ctx, a, createRequest(t, input), nil)
require.NoError(t, err)
require.Nil(t, op)
require.Equal(t, resp, actual)
Expand Down
13 changes: 7 additions & 6 deletions internal/sdkrequest/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ type Request struct {
// CallCtx represents context for individual function calls. This logs the function ID, the
// specific run ID, and sep information.
type CallCtx struct {
Env string `json:"env"`
FunctionID string `json:"fn_id"`
RunID string `json:"run_id"`
StepID string `json:"step_id"`
Stack CallStack `json:"stack"`
Attempt int `json:"attempt"`
DisableImmediateExecution bool `json:"disable_immediate_execution"`
Env string `json:"env"`
FunctionID string `json:"fn_id"`
RunID string `json:"run_id"`
StepID string `json:"step_id"`
Stack CallStack `json:"stack"`
Attempt int `json:"attempt"`
}

type CallStack struct {
Expand Down
21 changes: 19 additions & 2 deletions step/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ func Run[T any](
id string,
f func(ctx context.Context) (T, error),
) (T, error) {
targetID := getTargetStepID(ctx)
mgr := preflight(ctx)
op := mgr.NewOp(enums.OpcodeStep, id, nil)
hashedID := op.MustHash()

if val, ok := mgr.Step(op); ok {
// Create a new empty type T in v
Expand Down Expand Up @@ -78,6 +80,21 @@ func Run[T any](
return val, nil
}

if targetID != nil && *targetID != hashedID {
panic(ControlHijack{})
}

planParallel := targetID == nil && isParallel(ctx)
planBeforeRun := targetID == nil && mgr.Request().CallCtx.DisableImmediateExecution
if planParallel || planBeforeRun {
mgr.AppendOp(state.GeneratorOpcode{
ID: hashedID,
Op: enums.OpcodeStepPlanned,
Name: id,
})
panic(ControlHijack{})
}

// We're calling a function, so always cancel the context afterwards so that no
// other tools run.
defer mgr.Cancel()
Expand All @@ -94,7 +111,7 @@ func Run[T any](

// Implement per-step errors.
mgr.AppendOp(state.GeneratorOpcode{
ID: op.MustHash(),
ID: hashedID,
Op: enums.OpcodeStepError,
Name: id,
Error: &state.UserError{
Expand All @@ -112,7 +129,7 @@ func Run[T any](
mgr.SetErr(fmt.Errorf("unable to marshal run respone for '%s': %w", id, err))
}
mgr.AppendOp(state.GeneratorOpcode{
ID: op.MustHash(),
ID: hashedID,
Op: enums.OpcodeStepRun,
Name: id,
Data: byt,
Expand Down
33 changes: 33 additions & 0 deletions step/step.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ import (

type ControlHijack struct{}

type ctxKey string

const (
targetStepIDKey = ctxKey("stepID")
ParallelKey = ctxKey("parallelKey")
)

var (
// ErrNotInFunction is called when a step tool is executed outside of an Inngest
// function call context.
Expand All @@ -23,6 +30,32 @@ func (errNotInFunction) Error() string {
return "step called without function context"
}

func getTargetStepID(ctx context.Context) *string {
if v := ctx.Value(targetStepIDKey); v != nil {
if c, ok := v.(string); ok {
return &c
}
}
return nil
}

func SetTargetStepID(ctx context.Context, id string) context.Context {
if id == "" || id == "step" {
return ctx
}

return context.WithValue(ctx, targetStepIDKey, id)
}

func isParallel(ctx context.Context) bool {
if v := ctx.Value(ParallelKey); v != nil {
if c, ok := v.(bool); ok {
return c
}
}
return false
}

func preflight(ctx context.Context) sdkrequest.InvocationManager {
if ctx.Err() != nil {
// Another tool has already ran and the context is closed. Return
Expand Down
97 changes: 97 additions & 0 deletions tests/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package tests

import (
"fmt"
"net/http"
"os"
"os/exec"
"syscall"
"testing"
"time"

"github.com/inngest/inngestgo"
)

func TestMain(m *testing.M) {
teardown, err := setup()
if err != nil {
fmt.Fprintf(os.Stderr, "failed to setup: %v\n", err)
os.Exit(1)
}

code := m.Run()

err = teardown()
if err != nil {
fmt.Fprintf(os.Stderr, "failed to teardown: %v\n", err)
os.Exit(1)
}

os.Exit(code)
}

func setup() (func() error, error) {
os.Setenv("INNGEST_DEV", "1")

inngestgo.DefaultClient = inngestgo.NewClient(
inngestgo.ClientOpts{
EventKey: inngestgo.StrPtr("dev"),
},
)

// return func() error { return nil }, nil

stopDevServer, err := startDevServer()
if err != nil {
return nil, err
}

return stopDevServer, nil
}

func startDevServer() (func() error, error) {
fmt.Println("Starting Dev Server")
cmd := exec.Command(
"bash",
"-c",
"npx --yes inngest-cli@latest dev --no-discovery --no-poll",
)

// Run in a new process group so we can kill the process and its children
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}

err := cmd.Start()
if err != nil {
return nil, fmt.Errorf("failed to start command: %w", err)
}

// Wait for Dev Server to start
fmt.Println("Waiting for Dev Server to start")
httpClient := http.Client{Timeout: time.Second}
start := time.Now()
for {
resp, err := httpClient.Get("http://0.0.0.0:8288")
if err == nil && resp.StatusCode == 200 {
break
}
if time.Since(start) > 20*time.Second {
return nil, fmt.Errorf("timeout waiting for Dev Server to start: %w", err)
}
<-time.After(500 * time.Millisecond)
}

// Callback to stop the Dev Server
stop := func() error {
fmt.Println("Stopping Dev Server")
pgid, err := syscall.Getpgid(cmd.Process.Pid)
if err != nil {
return fmt.Errorf("failed to get process group ID: %w", err)
}
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
return fmt.Errorf("failed to kill process group: %w", err)
}
return nil
}

return stop, nil
}
Loading
Loading