Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nexus callback and task validation fixes #6918

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 66 additions & 41 deletions components/callbacks/executors.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,37 +58,67 @@ func RegisterExecutor(
)
}

type (
TaskExecutorOptions struct {
fx.In

Config *Config
NamespaceRegistry namespace.Registry
MetricsHandler metrics.Handler
Logger log.Logger
HTTPCallerProvider HTTPCallerProvider
HistoryClient resource.HistoryClient
}
// TaskExecutorOptions bundles the dependencies held by the callbacks task executor.
// It embeds fx.In, so it is presumably populated through fx dependency injection.
type TaskExecutorOptions struct {
fx.In

Config *Config
NamespaceRegistry namespace.Registry
MetricsHandler metrics.Handler
Logger log.Logger
HTTPCallerProvider HTTPCallerProvider
HistoryClient resource.HistoryClient
}

taskExecutor struct {
TaskExecutorOptions
}
// taskExecutor embeds TaskExecutorOptions, making those dependencies directly accessible to its methods.
type taskExecutor struct {
TaskExecutorOptions
}

invocationResult int
// invocationResult is a marker for the callbackInvokable.Invoke result to indicate to the executor how to handle the
// invocation outcome. Implementations: invocationResultOK, invocationResultRetry, and invocationResultFail.
type invocationResult interface {
// A marker for all possible implementations.
mustImplementInvocationResult()
// error returns the error carried by the result, or nil when the result carries none.
error() error
}

callbackInvokable interface {
// Invoke executes the callback logic and returns a result, and the error to be logged in the state machine.
Invoke(ctx context.Context, ns *namespace.Namespace, e taskExecutor, task InvocationTask) (invocationResult, error)
// WrapError provides each variant the opportunity to return a different error up the call stack than the one logged.
WrapError(result invocationResult, err error) error
}
)
// invocationResultOK marks an invocation as successful.
type invocationResultOK struct{}

const (
ok invocationResult = iota
retry
failed
)
// mustImplementInvocationResult marks invocationResultOK as an invocationResult implementation.
func (invocationResultOK) mustImplementInvocationResult() {}
rodrigozhou marked this conversation as resolved.
Show resolved Hide resolved

// error always returns nil: a successful invocation carries no error.
func (invocationResultOK) error() error {
return nil
}

// invocationResultFail marks an invocation as permanently failed.
type invocationResultFail struct {
// err is the error that caused the failure.
err error
}

// mustImplementInvocationResult marks invocationResultFail as an invocationResult implementation.
func (invocationResultFail) mustImplementInvocationResult() {}

// error returns the error that caused the failure.
func (r invocationResultFail) error() error {
return r.err
}

// invocationResultRetry marks an invocation as failed with the intent to retry.
type invocationResultRetry struct {
// err is the error from the failed attempt.
err error
}

// mustImplementInvocationResult marks invocationResultRetry as an invocationResult implementation.
func (invocationResultRetry) mustImplementInvocationResult() {}

// error returns the error from the failed attempt.
func (r invocationResultRetry) error() error {
return r.err
}

// callbackInvokable is implemented by each callback variant that the task executor can invoke.
type callbackInvokable interface {
// Invoke executes the callback logic and returns the invocation result.
Invoke(ctx context.Context, ns *namespace.Namespace, e taskExecutor, task InvocationTask) invocationResult
// WrapError gives each variant the opportunity to wrap the error returned by the task executor, e.g. to
// trigger the circuit breaker.
WrapError(result invocationResult, err error) error
}

func (e taskExecutor) executeInvocationTask(
ctx context.Context,
Expand All @@ -112,13 +142,9 @@ func (e taskExecutor) executeInvocationTask(
)
defer cancel()

result, err := invokable.Invoke(callCtx, ns, e, task)

saveErr := e.saveResult(callCtx, env, ref, result, err)
if saveErr != nil {
return saveErr
}
return invokable.WrapError(result, err)
result := invokable.Invoke(callCtx, ns, e, task)
saveErr := e.saveResult(callCtx, env, ref, result)
return invokable.WrapError(result, saveErr)
}

func (e taskExecutor) loadInvocationArgs(
Expand Down Expand Up @@ -174,25 +200,24 @@ func (e taskExecutor) saveResult(
env hsm.Environment,
ref hsm.Ref,
result invocationResult,
callErr error,
) error {
return env.Access(ctx, ref, hsm.AccessWrite, func(node *hsm.Node) error {
return hsm.MachineTransition(node, func(callback Callback) (hsm.TransitionOutput, error) {
switch result {
case ok:
switch result.(type) {
case invocationResultOK:
return TransitionSucceeded.Apply(callback, EventSucceeded{
Time: env.Now(),
})
case retry:
case invocationResultRetry:
return TransitionAttemptFailed.Apply(callback, EventAttemptFailed{
Time: env.Now(),
Err: callErr,
Err: result.error(),
RetryPolicy: e.Config.RetryPolicy(),
})
case failed:
case invocationResultFail:
return TransitionFailed.Apply(callback, EventFailed{
Time: env.Now(),
Err: callErr,
Err: result.error(),
})
default:
return hsm.TransitionOutput{}, queues.NewUnprocessableTaskError(fmt.Sprintf("unrecognized callback result %v", result))
Expand Down
10 changes: 5 additions & 5 deletions components/callbacks/hsm_invocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ func (s hsmInvocation) WrapError(invocationResult, error) error {
return nil
}

func (s hsmInvocation) Invoke(ctx context.Context, ns *namespace.Namespace, e taskExecutor, task InvocationTask) (invocationResult, error) {
func (s hsmInvocation) Invoke(ctx context.Context, ns *namespace.Namespace, e taskExecutor, task InvocationTask) invocationResult {
// TODO(Tianyu): Will this ever be too big for an RPC call?
callbackArgSerialized, err := s.callbackArg.Marshal()
if err != nil {
return failed, fmt.Errorf("failed to serialize completion event: %v", err)
return invocationResultFail{fmt.Errorf("failed to serialize completion event: %w", err)}
}

request := historyservice.InvokeStateMachineMethodRequest{
Expand All @@ -107,9 +107,9 @@ func (s hsmInvocation) Invoke(ctx context.Context, ns *namespace.Namespace, e ta
if err != nil {
e.Logger.Error("Callback request failed", tag.Error(err))
if isRetryableRpcResponse(err) {
return retry, err
return invocationResultRetry{err}
}
return failed, err
return invocationResultFail{err}
}
return ok, nil
return invocationResultOK{}
}
24 changes: 10 additions & 14 deletions components/callbacks/nexus_invocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,18 @@ func outcomeTag(callCtx context.Context, response *http.Response, callErr error)
}

func (n nexusInvocation) WrapError(result invocationResult, err error) error {
// If the request permanently failed there is no need to raise the error
if result == failed {
return nil
if failure, ok := result.(invocationResultRetry); ok {
return queues.NewDestinationDownError(failure.err.Error(), err)
}
if err != nil {
return queues.NewDestinationDownError(err.Error(), err)
}
return nil
return err
}

func (n nexusInvocation) Invoke(ctx context.Context, ns *namespace.Namespace, e taskExecutor, task InvocationTask) (invocationResult, error) {
func (n nexusInvocation) Invoke(ctx context.Context, ns *namespace.Namespace, e taskExecutor, task InvocationTask) invocationResult {
request, err := nexus.NewCompletionHTTPRequest(ctx, n.nexus.Url, n.completion)
if err != nil {
return failed, queues.NewUnprocessableTaskError(
return invocationResultFail{queues.NewUnprocessableTaskError(
fmt.Sprintf("failed to construct Nexus request: %v", err),
)
)}
}
if request.Header == nil {
request.Header = make(http.Header)
Expand Down Expand Up @@ -116,17 +112,17 @@ func (n nexusInvocation) Invoke(ctx context.Context, ns *namespace.Namespace, e

if err != nil {
e.Logger.Error("Callback request failed with error", tag.Error(err))
return retry, err
return invocationResultRetry{err}
}
if response.StatusCode >= 200 && response.StatusCode < 300 {
return ok, nil
return invocationResultOK{}
}

retryable := isRetryableHTTPResponse(response)
err = fmt.Errorf("request failed with: %v", response.Status)
e.Logger.Error("Callback request failed", tag.Error(err), tag.NewStringTag("status", response.Status), tag.NewBoolTag("retryable", retryable))
if retryable {
return retry, err
return invocationResultRetry{err}
}
return failed, err
return invocationResultFail{err}
}
5 changes: 3 additions & 2 deletions components/callbacks/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ package callbacks
import (
"time"

enumsspb "go.temporal.io/server/api/enums/v1"
persistencespb "go.temporal.io/server/api/persistence/v1"
"go.temporal.io/server/service/history/hsm"
)
Expand Down Expand Up @@ -59,7 +60,7 @@ func (t InvocationTask) Deadline() time.Time {
}

func (InvocationTask) Validate(ref *persistencespb.StateMachineRef, node *hsm.Node) error {
return hsm.ValidateNotTransitioned(ref, node)
return hsm.ValidateState[enumsspb.CallbackState, Callback](node, enumsspb.CALLBACK_STATE_SCHEDULED)
}

type InvocationTaskSerializer struct{}
Expand Down Expand Up @@ -91,7 +92,7 @@ func (BackoffTask) Destination() string {
}

func (BackoffTask) Validate(ref *persistencespb.StateMachineRef, node *hsm.Node) error {
return hsm.ValidateNotTransitioned(ref, node)
return hsm.ValidateState[enumsspb.CallbackState, Callback](node, enumsspb.CALLBACK_STATE_BACKING_OFF)
}

type BackoffTaskSerializer struct{}
Expand Down
37 changes: 8 additions & 29 deletions components/nexusoperations/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ import (
"fmt"
"time"

enumspb "go.temporal.io/server/api/enums/v1"
enumspb "go.temporal.io/api/enums/v1"
enumsspb "go.temporal.io/server/api/enums/v1"
persistencespb "go.temporal.io/server/api/persistence/v1"
"go.temporal.io/server/service/history/consts"
"go.temporal.io/server/service/history/hsm"
Expand Down Expand Up @@ -110,18 +111,7 @@ func (InvocationTask) Validate(ref *persistencespb.StateMachineRef, node *hsm.No
if err := node.CheckRunning(); err != nil {
return err
}
op, err := hsm.MachineData[Operation](node)
if err != nil {
return err
}
if op.State() != enumspb.NEXUS_OPERATION_STATE_SCHEDULED {
return fmt.Errorf(
"%w: operation is not in Scheduled state, current state: %v",
consts.ErrStaleReference,
op.State(),
)
}
return nil
return hsm.ValidateState[enumsspb.NexusOperationState, Operation](node, enumsspb.NEXUS_OPERATION_STATE_SCHEDULED)
}

type InvocationTaskSerializer struct{}
Expand Down Expand Up @@ -156,18 +146,7 @@ func (t BackoffTask) Validate(_ *persistencespb.StateMachineRef, node *hsm.Node)
if err := node.CheckRunning(); err != nil {
return err
}
op, err := hsm.MachineData[Operation](node)
if err != nil {
return err
}
if op.State() != enumspb.NEXUS_OPERATION_STATE_BACKING_OFF {
return fmt.Errorf(
"%w: operation is not in BackingOff state, current state: %v",
consts.ErrStaleReference,
op.State(),
)
}
return nil
return hsm.ValidateState[enumsspb.NexusOperationState, Operation](node, enumsspb.NEXUS_OPERATION_STATE_BACKING_OFF)
}

type BackoffTaskSerializer struct{}
Expand Down Expand Up @@ -199,10 +178,10 @@ func (t CancelationTask) Destination() string {
}

func (CancelationTask) Validate(ref *persistencespb.StateMachineRef, node *hsm.Node) error {
if err := hsm.ValidateNotTransitioned(ref, node); err != nil {
if err := node.CheckRunning(); err != nil {
return err
}
return node.CheckRunning()
return hsm.ValidateState[enumspb.NexusOperationCancellationState, Cancelation](node, enumspb.NEXUS_OPERATION_CANCELLATION_STATE_SCHEDULED)
}

type CancelationTaskSerializer struct{}
Expand Down Expand Up @@ -234,10 +213,10 @@ func (CancelationBackoffTask) Destination() string {
}

func (CancelationBackoffTask) Validate(ref *persistencespb.StateMachineRef, node *hsm.Node) error {
if err := hsm.ValidateNotTransitioned(ref, node); err != nil {
if err := node.CheckRunning(); err != nil {
return err
}
return node.CheckRunning()
return hsm.ValidateState[enumspb.NexusOperationCancellationState, Cancelation](node, enumspb.NEXUS_OPERATION_CANCELLATION_STATE_BACKING_OFF)
}

type CancelationBackoffTaskSerializer struct{}
Expand Down
21 changes: 20 additions & 1 deletion service/history/hsm/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,26 @@ type TaskSerializer interface {
// generated.
func ValidateNotTransitioned(ref *persistencespb.StateMachineRef, node *Node) error {
if ref.MachineTransitionCount != node.InternalRepr().TransitionCount {
return fmt.Errorf("%w: state machine transitions != ref transitions", consts.ErrStaleReference)
return fmt.Errorf("%w: state machine transitions (%d) != ref transitions (%d)", consts.ErrStaleReference, node.InternalRepr().TransitionCount, ref.MachineTransitionCount)
}
return nil
}

// ValidateState verifies that the state machine stored in node is currently in the expected state.
//
// The machine data is loaded from node as type T; a failure to load it is returned unchanged. When the
// machine's state differs from expected, the returned error wraps both [consts.ErrStaleReference] and
// ErrInvalidTransition and reports the node's key ID, the expected state, and the actual state. It returns
// nil when the states match.
func ValidateState[S comparable, T StateMachine[S]](node *Node, expected S) error {
	machine, err := MachineData[T](node)
	if err != nil {
		return err
	}
	// Guard clause: nothing to report when the machine is already where the caller expects it.
	if machine.State() == expected {
		return nil
	}
	return fmt.Errorf(
		"%w: %w: expected a %s machine in %v state, got %v",
		consts.ErrStaleReference,
		ErrInvalidTransition,
		node.Key.ID,
		expected,
		machine.State(),
	)
}
Loading
Loading