diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go
index 10a83cd2fc..1afc986287 100644
--- a/flytepropeller/pkg/controller/config/config.go
+++ b/flytepropeller/pkg/controller/config/config.go
@@ -112,9 +112,10 @@ var (
 		EventConfig: EventConfig{
 			RawOutputPolicy: RawOutputPolicyReference,
 		},
-		ClusterID:              "propeller",
-		CreateFlyteWorkflowCRD: false,
-		ArrayNodeEventVersion:  0,
+		ClusterID:                "propeller",
+		CreateFlyteWorkflowCRD:   false,
+		ArrayNodeEventVersion:    0,
+		NodeExecutionWorkerCount: 8,
 	}
 )
 
@@ -155,6 +156,7 @@ type Config struct {
 	ClusterID                string `json:"cluster-id" pflag:",Unique cluster id running this flytepropeller instance with which to annotate execution events"`
 	CreateFlyteWorkflowCRD   bool   `json:"create-flyteworkflow-crd" pflag:",Enable creation of the FlyteWorkflow CRD on startup"`
 	ArrayNodeEventVersion    int    `json:"array-node-event-version" pflag:",ArrayNode eventing version. 0 => legacy (drop-in replacement for maptask), 1 => new"`
+	NodeExecutionWorkerCount int    `json:"node-execution-worker-count" pflag:",Number of workers to evaluate node executions, currently only used for array nodes"`
 }
 
 // KubeClientConfig contains the configuration used by flytepropeller to configure its internal Kubernetes Client.
diff --git a/flytepropeller/pkg/controller/config/config_flags.go b/flytepropeller/pkg/controller/config/config_flags.go
index 8e9c71bcdb..07a4fba742 100755
--- a/flytepropeller/pkg/controller/config/config_flags.go
+++ b/flytepropeller/pkg/controller/config/config_flags.go
@@ -108,5 +108,6 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet {
 	cmdFlags.String(fmt.Sprintf("%v%v", prefix, "cluster-id"), defaultConfig.ClusterID, "Unique cluster id running this flytepropeller instance with which to annotate execution events")
 	cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "create-flyteworkflow-crd"), defaultConfig.CreateFlyteWorkflowCRD, "Enable creation of the FlyteWorkflow CRD on startup")
 	cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "array-node-event-version"), defaultConfig.ArrayNodeEventVersion, "ArrayNode eventing version. 0 => legacy (drop-in replacement for maptask), 1 => new")
+	cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "node-execution-worker-count"), defaultConfig.NodeExecutionWorkerCount, "Number of workers to evaluate node executions, currently only used for array nodes")
 	return cmdFlags
 }
diff --git a/flytepropeller/pkg/controller/config/config_flags_test.go b/flytepropeller/pkg/controller/config/config_flags_test.go
index f48d01ebea..54da9e9fe1 100755
--- a/flytepropeller/pkg/controller/config/config_flags_test.go
+++ b/flytepropeller/pkg/controller/config/config_flags_test.go
@@ -911,4 +911,18 @@ func TestConfig_SetFlags(t *testing.T) {
 			}
 		})
 	})
+	t.Run("Test_node-execution-worker-count", func(t *testing.T) {
+
+		t.Run("Override", func(t *testing.T) {
+			testValue := "1"
+
+			cmdFlags.Set("node-execution-worker-count", testValue)
+			if vInt, err := cmdFlags.GetInt("node-execution-worker-count"); err == nil {
+				testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.NodeExecutionWorkerCount)
+
+			} else {
+				assert.FailNow(t, err.Error())
+			}
+		})
+	})
 }
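The three config files above add a single knob, `node-execution-worker-count` (default 8), which bounds the shared pool of goroutines that evaluate ArrayNode subNodes. A minimal sketch of how such a pflag-backed integer flag behaves outside of propeller's config machinery; the `propeller.` prefix here is an assumption for illustration, not the actual section name:

```go
package main

import (
	"fmt"

	"github.com/spf13/pflag"
)

func main() {
	// Mirrors the generated GetPFlagSet code: an int flag registered
	// under a section prefix, with the same default of 8.
	fs := pflag.NewFlagSet("config", pflag.ExitOnError)
	fs.Int("propeller.node-execution-worker-count", 8,
		"Number of workers to evaluate node executions, currently only used for array nodes")

	// e.g. passing --propeller.node-execution-worker-count=16 on the command line
	if err := fs.Parse([]string{"--propeller.node-execution-worker-count=16"}); err != nil {
		panic(err)
	}

	count, err := fs.GetInt("propeller.node-execution-worker-count")
	if err != nil {
		panic(err)
	}
	fmt.Println("worker count:", count) // worker count: 16
}
```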
diff --git a/flytepropeller/pkg/controller/nodes/array/execution_context.go b/flytepropeller/pkg/controller/nodes/array/execution_context.go
index 2191b9c7d2..4fb5a8a214 100644
--- a/flytepropeller/pkg/controller/nodes/array/execution_context.go
+++ b/flytepropeller/pkg/controller/nodes/array/execution_context.go
@@ -14,35 +14,24 @@ const (
 
 type arrayExecutionContext struct {
 	executors.ExecutionContext
-	executionConfig    v1alpha1.ExecutionConfig
-	currentParallelism *uint32
+	executionConfig v1alpha1.ExecutionConfig
 }
 
 func (a *arrayExecutionContext) GetExecutionConfig() v1alpha1.ExecutionConfig {
 	return a.executionConfig
 }
 
-func (a *arrayExecutionContext) CurrentParallelism() uint32 {
-	return *a.currentParallelism
-}
-
-func (a *arrayExecutionContext) IncrementParallelism() uint32 {
-	*a.currentParallelism = *a.currentParallelism + 1
-	return *a.currentParallelism
-}
-
-func newArrayExecutionContext(executionContext executors.ExecutionContext, subNodeIndex int, currentParallelism *uint32, maxParallelism uint32) *arrayExecutionContext {
+func newArrayExecutionContext(executionContext executors.ExecutionContext, subNodeIndex int) *arrayExecutionContext {
 	executionConfig := executionContext.GetExecutionConfig()
 	if executionConfig.EnvironmentVariables == nil {
 		executionConfig.EnvironmentVariables = make(map[string]string)
 	}
 	executionConfig.EnvironmentVariables[JobIndexVarName] = FlyteK8sArrayIndexVarName
 	executionConfig.EnvironmentVariables[FlyteK8sArrayIndexVarName] = strconv.Itoa(subNodeIndex)
-	executionConfig.MaxParallelism = maxParallelism
+	executionConfig.MaxParallelism = 0 // hardcoded to 0 because parallelism is handled by the array node
 	return &arrayExecutionContext{
-		ExecutionContext:   executionContext,
-		executionConfig:    executionConfig,
-		currentParallelism: currentParallelism,
+		ExecutionContext: executionContext,
+		executionConfig:  executionConfig,
 	}
 }
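With the per-context parallelism counter gone, `MaxParallelism` is pinned to 0, which (as the hardcoded comment notes) the executor treats as "no limit at this level"; the throttle now lives in the ArrayNode handler itself, which seeds a budget from `arrayNode.GetParallelism()` before dispatching subNodes (see `Handle` below). A tiny sketch of that seeding rule, with hypothetical names:

```go
package main

import "fmt"

// dispatchBudget mirrors how Handle seeds its parallelism counter:
// a configured parallelism of 0 means "unbounded", so every subNode
// is dispatched in a single evaluation round.
func dispatchBudget(parallelism, subNodeCount int) int {
	if parallelism == 0 {
		return subNodeCount
	}
	return parallelism
}

func main() {
	fmt.Println(dispatchBudget(0, 10)) // 10: no limit, dispatch all subNodes
	fmt.Println(dispatchBudget(4, 10)) // 4: at most four requests this round
}
```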
diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go
index 00a9fc747e..f1e2ef64fc 100644
--- a/flytepropeller/pkg/controller/nodes/array/handler.go
+++ b/flytepropeller/pkg/controller/nodes/array/handler.go
@@ -42,11 +42,13 @@ var (
 
 // arrayNodeHandler is a handle implementation for processing array nodes
 type arrayNodeHandler struct {
-	eventConfig                *config.EventConfig
-	metrics                    metrics
-	nodeExecutor               interfaces.Node
-	pluginStateBytesNotStarted []byte
-	pluginStateBytesStarted    []byte
+	eventConfig                 *config.EventConfig
+	gatherOutputsRequestChannel chan *gatherOutputsRequest
+	metrics                     metrics
+	nodeExecutionRequestChannel chan *nodeExecutionRequest
+	nodeExecutor                interfaces.Node
+	pluginStateBytesNotStarted  []byte
+	pluginStateBytesStarted     []byte
 }
 
 // metrics encapsulates the prometheus metrics for this handler
@@ -70,7 +72,6 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut
 	messageCollector := errorcollector.NewErrorMessageCollector()
 	switch arrayNodeState.Phase {
 	case v1alpha1.ArrayNodePhaseExecuting, v1alpha1.ArrayNodePhaseFailing:
-		currentParallelism := uint32(0)
 		for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
 			nodePhase := v1alpha1.NodePhase(nodePhaseUint64)
 
@@ -81,7 +82,7 @@
 
 			// create array contexts
 			arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, err :=
-				a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, &currentParallelism, eventRecorder)
+				a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, eventRecorder)
 			if err != nil {
 				return err
 			}
@@ -124,7 +125,6 @@ func (a *arrayNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExe
 	messageCollector := errorcollector.NewErrorMessageCollector()
 	switch arrayNodeState.Phase {
 	case v1alpha1.ArrayNodePhaseExecuting, v1alpha1.ArrayNodePhaseFailing, v1alpha1.ArrayNodePhaseSucceeding:
-		currentParallelism := uint32(0)
 		for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
 			nodePhase := v1alpha1.NodePhase(nodePhaseUint64)
 
@@ -135,7 +135,7 @@
 
 			// create array contexts
 			arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, err :=
-				a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, &currentParallelism, eventRecorder)
+				a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, eventRecorder)
 			if err != nil {
 				return err
 			}
@@ -242,8 +242,12 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
 		arrayNodeState.Phase = v1alpha1.ArrayNodePhaseExecuting
 	case v1alpha1.ArrayNodePhaseExecuting:
 		// process array node subNodes
-		currentParallelism := uint32(0)
-		messageCollector := errorcollector.NewErrorMessageCollector()
+		currentParallelism := int(arrayNode.GetParallelism())
+		if currentParallelism == 0 {
+			currentParallelism = len(arrayNodeState.SubNodePhases.GetItems())
+		}
+
+		nodeExecutionRequests := make([]*nodeExecutionRequest, 0, currentParallelism)
 		for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
 			nodePhase := v1alpha1.NodePhase(nodePhaseUint64)
 			taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(i))
@@ -254,43 +258,97 @@
 			}
 
 			// create array contexts
+			subNodeEventRecorder := newArrayEventRecorder(nCtx.EventsRecorder())
 			arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, subNodeStatus, err :=
-				a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, &currentParallelism, eventRecorder)
+				a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, subNodeEventRecorder)
 			if err != nil {
 				return handler.UnknownTransition, err
 			}
 
-			// execute subNode
-			_, err = arrayNodeExecutor.RecursiveNodeHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec)
-			if err != nil {
-				return handler.UnknownTransition, err
+			nodeExecutionRequest := &nodeExecutionRequest{
+				ctx:                ctx,
+				index:              i,
+				nodePhase:          nodePhase,
+				taskPhase:          taskPhase,
+				nodeExecutor:       arrayNodeExecutor,
+				executionContext:   arrayExecutionContext,
+				dagStructure:       arrayDAGStructure,
+				nodeLookup:         arrayNodeLookup,
+				subNodeSpec:        subNodeSpec,
+				subNodeStatus:      subNodeStatus,
+				arrayEventRecorder: subNodeEventRecorder,
+				responseChannel: make(chan struct {
+					interfaces.NodeStatus
+					error
+				}, 1),
+			}
+
+			nodeExecutionRequests = append(nodeExecutionRequests, nodeExecutionRequest)
+			a.nodeExecutionRequestChannel <- nodeExecutionRequest
+
+			// TODO - this is a naive implementation of parallelism, if we want to support more
+			// complex subNodes (ie. dynamics / subworkflows) we need to revisit this so that
+			// parallelism is handled during subNode evaluations.
+			currentParallelism--
+			if currentParallelism == 0 {
+				break
+			}
+		}
+
+		workerErrorCollector := errorcollector.NewErrorMessageCollector()
+		subNodeFailureCollector := errorcollector.NewErrorMessageCollector()
+		for i, nodeExecutionRequest := range nodeExecutionRequests {
+			nodeExecutionResponse := <-nodeExecutionRequest.responseChannel
+			if nodeExecutionResponse.error != nil {
+				workerErrorCollector.Collect(i, nodeExecutionResponse.error.Error())
+				continue
 			}
 
+			index := nodeExecutionRequest.index
+			subNodeStatus := nodeExecutionRequest.subNodeStatus
+
 			// capture subNode error if exists
-			if subNodeStatus.Error != nil {
-				messageCollector.Collect(i, subNodeStatus.Error.Message)
+			if nodeExecutionRequest.subNodeStatus.Error != nil {
+				subNodeFailureCollector.Collect(index, subNodeStatus.Error.Message)
 			}
 
-			// process events
-			eventRecorder.process(ctx, nCtx, i, subNodeStatus.GetAttempts())
+			// process events by copying from internal event recorder
+			if arrayEventRecorder, ok := nodeExecutionRequest.arrayEventRecorder.(*externalResourcesEventRecorder); ok {
+				for _, event := range arrayEventRecorder.taskEvents {
+					if err := eventRecorder.RecordTaskEvent(ctx, event, a.eventConfig); err != nil {
+						return handler.UnknownTransition, err
+					}
+				}
+
+				for _, event := range arrayEventRecorder.nodeEvents {
+					if err := eventRecorder.RecordNodeEvent(ctx, event, a.eventConfig); err != nil {
+						return handler.UnknownTransition, err
+					}
+				}
+			}
+			eventRecorder.process(ctx, nCtx, index, subNodeStatus.GetAttempts())
 
 			// update subNode state
-			arrayNodeState.SubNodePhases.SetItem(i, uint64(subNodeStatus.GetPhase()))
+			arrayNodeState.SubNodePhases.SetItem(index, uint64(subNodeStatus.GetPhase()))
 			if subNodeStatus.GetTaskNodeStatus() == nil {
 				// resetting task phase because during retries we clear the GetTaskNodeStatus
-				arrayNodeState.SubNodeTaskPhases.SetItem(i, uint64(0))
+				arrayNodeState.SubNodeTaskPhases.SetItem(index, uint64(0))
 			} else {
-				arrayNodeState.SubNodeTaskPhases.SetItem(i, uint64(subNodeStatus.GetTaskNodeStatus().GetPhase()))
+				arrayNodeState.SubNodeTaskPhases.SetItem(index, uint64(subNodeStatus.GetTaskNodeStatus().GetPhase()))
 			}
-			arrayNodeState.SubNodeRetryAttempts.SetItem(i, uint64(subNodeStatus.GetAttempts()))
-			arrayNodeState.SubNodeSystemFailures.SetItem(i, uint64(subNodeStatus.GetSystemFailures()))
+			arrayNodeState.SubNodeRetryAttempts.SetItem(index, uint64(subNodeStatus.GetAttempts()))
+			arrayNodeState.SubNodeSystemFailures.SetItem(index, uint64(subNodeStatus.GetSystemFailures()))
 
 			// increment task phase version if subNode phase or task phase changed
-			if subNodeStatus.GetPhase() != nodePhase || subNodeStatus.GetTaskNodeStatus().GetPhase() != taskPhase {
+			if subNodeStatus.GetPhase() != nodeExecutionRequest.nodePhase || subNodeStatus.GetTaskNodeStatus().GetPhase() != nodeExecutionRequest.taskPhase {
 				incrementTaskPhaseVersion = true
 			}
 		}
 
+		// if any workers failed then return the error
+		if workerErrorCollector.Length() > 0 {
+			return handler.UnknownTransition, fmt.Errorf("worker error(s) encountered: %s", workerErrorCollector.Summary(events.MaxErrorMessageLength))
+		}
+
 		// process phases of subNodes to determine overall `ArrayNode` phase
 		successCount := 0
 		failedCount := 0
@@ -321,7 +379,7 @@
 		// if there is a failing node set the error message if it has not been previous set
 		if failingCount > 0 && arrayNodeState.Error == nil {
 			arrayNodeState.Error = &idlcore.ExecutionError{
-				Message: messageCollector.Summary(events.MaxErrorMessageLength),
+				Message: subNodeFailureCollector.Summary(events.MaxErrorMessageLength),
 			}
 		}
 
@@ -349,33 +407,51 @@
 			nil,
 		)), nil
 	case v1alpha1.ArrayNodePhaseSucceeding:
-		outputLiterals := make(map[string]*idlcore.Literal)
+		gatherOutputsRequests := make([]*gatherOutputsRequest, 0, len(arrayNodeState.SubNodePhases.GetItems()))
 		for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
 			nodePhase := v1alpha1.NodePhase(nodePhaseUint64)
 
+			gatherOutputsRequest := &gatherOutputsRequest{
+				ctx: ctx,
+				responseChannel: make(chan struct {
+					literalMap map[string]*idlcore.Literal
+					error
+				}, 1),
+			}
+
 			if nodePhase != v1alpha1.NodePhaseSucceeded {
 				// retrieve output variables from task template
-				var outputVariables map[string]*idlcore.Variable
+				outputLiterals := make(map[string]*idlcore.Literal)
 				task, err := nCtx.ExecutionContext().GetTask(*arrayNode.GetSubNodeSpec().TaskRef)
 				if err != nil {
 					// Should never happen
-					return handler.UnknownTransition, err
+					gatherOutputsRequest.responseChannel <- struct {
+						literalMap map[string]*idlcore.Literal
+						error
+					}{nil, err}
+					continue
 				}
 
 				if task.CoreTask() != nil && task.CoreTask().Interface != nil && task.CoreTask().Interface.Outputs != nil {
-					outputVariables = task.CoreTask().Interface.Outputs.Variables
+					for name := range task.CoreTask().Interface.Outputs.Variables {
+						outputLiterals[name] = nilLiteral
+					}
 				}
 
-				// append nil literal for all output variables
-				for name := range outputVariables {
-					appendLiteral(name, nilLiteral, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems()))
-				}
+				gatherOutputsRequest.responseChannel <- struct {
+					literalMap map[string]*idlcore.Literal
+					error
+				}{outputLiterals, nil}
 			} else {
 				// initialize subNode reader
-				currentAttempt := uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(i))
-				subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i), strconv.Itoa(int(currentAttempt)))
+				currentAttempt := int(arrayNodeState.SubNodeRetryAttempts.GetItem(i))
+				subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx,
+					strconv.Itoa(i), strconv.Itoa(currentAttempt))
 				if err != nil {
-					return handler.UnknownTransition, err
+					gatherOutputsRequest.responseChannel <- struct {
+						literalMap map[string]*idlcore.Literal
+						error
+					}{nil, err}
+					continue
 				}
 
 				// checkpoint paths are not computed here because this function is only called when writing
@@ -383,22 +459,33 @@
 				outputPaths := ioutils.NewCheckpointRemoteFilePaths(ctx, nCtx.DataStore(), subOutputDir, ioutils.NewRawOutputPaths(ctx, subDataDir), "")
 				reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, nCtx.MaxDatasetSizeBytes())
 
-				// read outputs
-				outputs, executionErr, err := reader.Read(ctx)
-				if err != nil {
-					return handler.UnknownTransition, err
-				} else if executionErr != nil {
-					return handler.UnknownTransition, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(),
-						"execution error ArrayNode output, bad state: %s", executionErr.String())
-				}
+				gatherOutputsRequest.reader = &reader
+				a.gatherOutputsRequestChannel <- gatherOutputsRequest
+			}
 
-				// copy individual subNode output literals into a collection of output literals
-				for name, literal := range outputs.GetLiterals() {
-					appendLiteral(name, literal, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems()))
-				}
+			gatherOutputsRequests = append(gatherOutputsRequests, gatherOutputsRequest)
+		}
+
+		outputLiterals := make(map[string]*idlcore.Literal)
+		workerErrorCollector := errorcollector.NewErrorMessageCollector()
+		for i, gatherOutputsRequest := range gatherOutputsRequests {
+			outputResponse := <-gatherOutputsRequest.responseChannel
+			if outputResponse.error != nil {
+				workerErrorCollector.Collect(i, outputResponse.error.Error())
+				continue
+			}
+
+			// append literal for all output variables
+			for name, literal := range outputResponse.literalMap {
+				appendLiteral(name, literal, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems()))
 			}
 		}
 
+		// if any workers failed then return the error
+		if workerErrorCollector.Length() > 0 {
+			return handler.UnknownTransition, fmt.Errorf("worker error(s) encountered: %s", workerErrorCollector.Summary(events.MaxErrorMessageLength))
+		}
+
 		outputLiteralMap := &idlcore.LiteralMap{
 			Literals: outputLiterals,
 		}
@@ -460,6 +547,18 @@
 
 // Setup handles any initialization requirements for this handler
 func (a *arrayNodeHandler) Setup(_ context.Context, _ interfaces.SetupContext) error {
+	// start workers
+	for i := 0; i < config.GetConfig().NodeExecutionWorkerCount; i++ {
+		worker := worker{
+			gatherOutputsRequestChannel: a.gatherOutputsRequestChannel,
+			nodeExecutionRequestChannel: a.nodeExecutionRequestChannel,
+		}
+
+		go func() {
+			worker.run()
+		}()
+	}
+
 	return nil
 }
 
@@ -478,11 +577,13 @@ func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope pr
 
 	arrayScope := scope.NewSubScope("array")
 	return &arrayNodeHandler{
-		eventConfig:                eventConfig,
-		metrics:                    newMetrics(arrayScope),
-		nodeExecutor:               nodeExecutor,
-		pluginStateBytesNotStarted: pluginStateBytesNotStarted,
-		pluginStateBytesStarted:    pluginStateBytesStarted,
+		eventConfig:                 eventConfig,
+		gatherOutputsRequestChannel: make(chan *gatherOutputsRequest),
+		metrics:                     newMetrics(arrayScope),
+		nodeExecutionRequestChannel: make(chan *nodeExecutionRequest),
+		nodeExecutor:                nodeExecutor,
+		pluginStateBytesNotStarted:  pluginStateBytesNotStarted,
+		pluginStateBytesStarted:     pluginStateBytesStarted,
 	}, nil
 }
 
@@ -491,7 +592,7 @@
 // but need many different execution details, for example setting input values as a singular item rather than a collection,
 // injecting environment variables for flytekit maptask execution, aggregating eventing so that rather than tracking state for
 // each subnode individually it sends a single event for the whole ArrayNode, and many more.
-func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx interfaces.NodeExecutionContext, arrayNodeState *handler.ArrayNodeState, arrayNode v1alpha1.ExecutableArrayNode, subNodeIndex int, currentParallelism *uint32, eventRecorder arrayEventRecorder) (
+func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx interfaces.NodeExecutionContext, arrayNodeState *handler.ArrayNodeState, arrayNode v1alpha1.ExecutableArrayNode, subNodeIndex int, eventRecorder arrayEventRecorder) (
 	interfaces.Node, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, *v1alpha1.NodeSpec, *v1alpha1.NodeStatus, error) {
 
 	nodePhase := v1alpha1.NodePhase(arrayNodeState.SubNodePhases.GetItem(subNodeIndex))
@@ -556,12 +657,10 @@
 	if err != nil {
 		return nil, nil, nil, nil, nil, nil, err
 	}
-	arrayExecutionContext := newArrayExecutionContext(
-		executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo),
-		subNodeIndex, currentParallelism, arrayNode.GetParallelism())
+	arrayExecutionContext := newArrayExecutionContext(executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo), subNodeIndex)
 
 	arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(),
-		subNodeID, subNodeIndex, subNodeStatus, inputReader, currentParallelism, arrayNode.GetParallelism(), eventRecorder)
+		subNodeID, subNodeIndex, subNodeStatus, inputReader, eventRecorder)
 	arrayNodeExecutor := a.nodeExecutor.WithNodeExecutionContextBuilder(arrayNodeExecutionContextBuilder)
 
 	return arrayNodeExecutor, arrayExecutionContext, &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec, subNodeStatus, nil
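The `Handle` rewrite above is a fan-out/fan-in pattern: each subNode evaluation becomes a request that owns a response channel (buffered to capacity 1, so a worker can always reply without blocking), requests fan out over a shared channel to the long-lived workers started in `Setup`, and the handler then drains responses in submission order so state updates stay deterministic. A self-contained sketch of that shape, using generic names rather than the actual Flyte types:

```go
package main

import "fmt"

// request stands in for nodeExecutionRequest: each work item owns its
// response channel, so replies cannot be mixed up between requests.
type request struct {
	index    int
	response chan result
}

type result struct {
	index int
	err   error
}

// worker mirrors worker.run: drain the shared channel as requests arrive.
func worker(requests <-chan *request) {
	for req := range requests {
		// ... evaluate the subNode here ...
		req.response <- result{index: req.index}
	}
}

func main() {
	requests := make(chan *request)
	for i := 0; i < 8; i++ { // NodeExecutionWorkerCount
		go worker(requests)
	}

	// fan out: one request per subNode, up to the parallelism budget
	pending := make([]*request, 0, 4)
	for i := 0; i < 4; i++ {
		req := &request{index: i, response: make(chan result, 1)}
		pending = append(pending, req)
		requests <- req
	}

	// fan in: collect responses in submission order, as Handle does
	for _, req := range pending {
		res := <-req.response
		fmt.Println("subNode", res.index, "done, err:", res.err)
	}
}
```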
diff --git a/flytepropeller/pkg/controller/nodes/array/handler_test.go b/flytepropeller/pkg/controller/nodes/array/handler_test.go
index f9086218c2..b2e85c0979 100644
--- a/flytepropeller/pkg/controller/nodes/array/handler_test.go
+++ b/flytepropeller/pkg/controller/nodes/array/handler_test.go
@@ -62,7 +62,13 @@ func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler inter
 	assert.NoError(t, err)
 
 	// return ArrayNodeHandler
-	return New(nodeExecutor, eventConfig, scope)
+	arrayNodeHandler, err := New(nodeExecutor, eventConfig, scope)
+	if err != nil {
+		return nil, err
+	}
+
+	err = arrayNodeHandler.Setup(ctx, nil)
+	return arrayNodeHandler, err
 }
 
 func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder interfaces.EventRecorder, outputVariables []string,
@@ -496,11 +502,10 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) {
 			},
 			subNodeTransitions: []handler.Transition{
 				handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})),
-				handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})),
 			},
 			expectedArrayNodePhase:  v1alpha1.ArrayNodePhaseExecuting,
 			expectedTransitionPhase: handler.EPhaseRunning,
-			expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_QUEUED},
+			expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING},
 		},
 		{
 			name: "AllSubNodesSuccedeed",
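The test-helper change above encodes a new invariant: `New` only wires the channels, and no workers exist until `Setup` runs, so a handler that skips `Setup` would block forever on its first send to `nodeExecutionRequestChannel`. A minimal sketch of that construct-then-Setup contract, with hypothetical types rather than propeller code:

```go
package main

import "fmt"

// handler stands in for arrayNodeHandler: the constructor wires channels,
// setup starts the receivers.
type handler struct{ requests chan chan string }

func newHandler() *handler { return &handler{requests: make(chan chan string)} }

func (h *handler) setup(workers int) {
	for i := 0; i < workers; i++ {
		go func() {
			for reply := range h.requests {
				reply <- "handled"
			}
		}()
	}
}

func main() {
	h := newHandler()
	h.setup(8) // without this, the send below would block forever

	reply := make(chan string, 1)
	h.requests <- reply
	fmt.Println(<-reply)
}
```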
diff --git a/flytepropeller/pkg/controller/nodes/array/node_execution_context.go b/flytepropeller/pkg/controller/nodes/array/node_execution_context.go
index 17d46d2944..6ef7bb01c1 100644
--- a/flytepropeller/pkg/controller/nodes/array/node_execution_context.go
+++ b/flytepropeller/pkg/controller/nodes/array/node_execution_context.go
@@ -104,8 +104,10 @@ func (a *arrayNodeExecutionContext) TaskReader() interfaces.TaskReader {
 	return a.taskReader
 }
 
-func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, eventRecorder arrayEventRecorder, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus, currentParallelism *uint32, maxParallelism uint32) *arrayNodeExecutionContext {
-	arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex, currentParallelism, maxParallelism)
+func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader,
+	eventRecorder arrayEventRecorder, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus) *arrayNodeExecutionContext {
+
+	arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex)
 	return &arrayNodeExecutionContext{
 		NodeExecutionContext: nodeExecutionContext,
 		eventRecorder:        eventRecorder,
diff --git a/flytepropeller/pkg/controller/nodes/array/node_execution_context_builder.go b/flytepropeller/pkg/controller/nodes/array/node_execution_context_builder.go
index 6d0cfd3bfb..b66ae4a54d 100644
--- a/flytepropeller/pkg/controller/nodes/array/node_execution_context_builder.go
+++ b/flytepropeller/pkg/controller/nodes/array/node_execution_context_builder.go
@@ -10,14 +10,12 @@ import (
 )
 
 type arrayNodeExecutionContextBuilder struct {
-	nCtxBuilder        interfaces.NodeExecutionContextBuilder
-	subNodeID          v1alpha1.NodeID
-	subNodeIndex       int
-	subNodeStatus      *v1alpha1.NodeStatus
-	inputReader        io.InputReader
-	currentParallelism *uint32
-	maxParallelism     uint32
-	eventRecorder      arrayEventRecorder
+	nCtxBuilder   interfaces.NodeExecutionContextBuilder
+	subNodeID     v1alpha1.NodeID
+	subNodeIndex  int
+	subNodeStatus *v1alpha1.NodeStatus
+	inputReader   io.InputReader
+	eventRecorder arrayEventRecorder
 }
 
 func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext,
@@ -31,23 +29,21 @@ func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context
 
 	if currentNodeID == a.subNodeID {
 		// overwrite NodeExecutionContext for ArrayNode execution
-		nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.eventRecorder, a.subNodeIndex, a.subNodeStatus, a.currentParallelism, a.maxParallelism)
+		nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.eventRecorder, a.subNodeIndex, a.subNodeStatus)
 	}
 
 	return nCtx, nil
 }
 
-func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, subNodeIndex int, subNodeStatus *v1alpha1.NodeStatus,
-	inputReader io.InputReader, currentParallelism *uint32, maxParallelism uint32, eventRecorder arrayEventRecorder) interfaces.NodeExecutionContextBuilder {
+func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, subNodeIndex int,
+	subNodeStatus *v1alpha1.NodeStatus, inputReader io.InputReader, eventRecorder arrayEventRecorder) interfaces.NodeExecutionContextBuilder {
 
 	return &arrayNodeExecutionContextBuilder{
-		nCtxBuilder:        nCtxBuilder,
-		subNodeID:          subNodeID,
-		subNodeIndex:       subNodeIndex,
-		subNodeStatus:      subNodeStatus,
-		inputReader:        inputReader,
-		currentParallelism: currentParallelism,
-		maxParallelism:     maxParallelism,
-		eventRecorder:      eventRecorder,
+		nCtxBuilder:   nCtxBuilder,
+		subNodeID:     subNodeID,
+		subNodeIndex:  subNodeIndex,
+		subNodeStatus: subNodeStatus,
+		inputReader:   inputReader,
+		eventRecorder: eventRecorder,
 	}
 }
diff --git a/flytepropeller/pkg/controller/nodes/array/worker.go b/flytepropeller/pkg/controller/nodes/array/worker.go
new file mode 100644
index 0000000000..b5b5db49da
--- /dev/null
+++ b/flytepropeller/pkg/controller/nodes/array/worker.go
@@ -0,0 +1,105 @@
+package array
+
+import (
+	"context"
+	"fmt"
+	"runtime/debug"
+
+	idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils"
+	"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
+	"github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors"
+	"github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces"
+	"github.com/flyteorg/flyte/flytestdlib/logger"
+)
+
+// nodeExecutionRequest is a request to execute an ArrayNode subNode
+type nodeExecutionRequest struct {
+	ctx                context.Context
+	index              int
+	nodePhase          v1alpha1.NodePhase
+	taskPhase          int
+	nodeExecutor       interfaces.Node
+	executionContext   executors.ExecutionContext
+	dagStructure       executors.DAGStructure
+	nodeLookup         executors.NodeLookup
+	subNodeSpec        *v1alpha1.NodeSpec
+	subNodeStatus      *v1alpha1.NodeStatus
+	arrayEventRecorder arrayEventRecorder
+	responseChannel    chan struct {
+		interfaces.NodeStatus
+		error
+	}
+}
+
+// gatherOutputsRequest is a request to read outputs from an ArrayNode subNode
+type gatherOutputsRequest struct {
+	ctx             context.Context
+	reader          *ioutils.RemoteFileOutputReader
+	responseChannel chan struct {
+		literalMap map[string]*idlcore.Literal
+		error
+	}
+}
+
+// worker is an entity that is used to parallelize I/O bound operations for ArrayNode execution
+type worker struct {
+	gatherOutputsRequestChannel chan *gatherOutputsRequest
+	nodeExecutionRequestChannel chan *nodeExecutionRequest
+}
+
+// run starts the main handle loop for the worker
+func (w *worker) run() {
+	for {
+		select {
+		case nodeExecutionRequest := <-w.nodeExecutionRequestChannel:
+			var nodeStatus interfaces.NodeStatus
+			var err error
+			func() {
+				defer func() {
+					if r := recover(); r != nil {
+						stack := debug.Stack()
+						err = fmt.Errorf("panic when executing ArrayNode subNode, Stack: [%s]", string(stack))
+						logger.Errorf(nodeExecutionRequest.ctx, err.Error())
+					}
+				}()
+
+				// execute RecursiveNodeHandler on the subNode
+				nodeStatus, err = nodeExecutionRequest.nodeExecutor.RecursiveNodeHandler(nodeExecutionRequest.ctx, nodeExecutionRequest.executionContext,
+					nodeExecutionRequest.dagStructure, nodeExecutionRequest.nodeLookup, nodeExecutionRequest.subNodeSpec)
+			}()
+
+			nodeExecutionRequest.responseChannel <- struct {
+				interfaces.NodeStatus
+				error
+			}{nodeStatus, err}
+		case gatherOutputsRequest := <-w.gatherOutputsRequestChannel:
+			var literalMap map[string]*idlcore.Literal
+			var err error
+			func() {
+				defer func() {
+					if r := recover(); r != nil {
+						stack := debug.Stack()
+						err = fmt.Errorf("panic when executing ArrayNode subNode, Stack: [%s]", string(stack))
+						logger.Errorf(gatherOutputsRequest.ctx, err.Error())
+					}
+				}()
+
+				// read outputs
+				outputs, executionErr, gatherErr := gatherOutputsRequest.reader.Read(gatherOutputsRequest.ctx)
+				if gatherErr != nil {
+					err = gatherErr
+				} else if executionErr != nil {
+					err = fmt.Errorf("%s", executionErr.String())
+				} else {
+					literalMap = outputs.GetLiterals()
+				}
+			}()
+
+			gatherOutputsRequest.responseChannel <- struct {
+				literalMap map[string]*idlcore.Literal
+				error
+			}{literalMap, err}
+		}
+	}
+}
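One detail in `worker.go` is worth calling out: each branch wraps the actual work in an immediately invoked function with a deferred `recover`, so a panic inside one subNode evaluation is converted into an ordinary error response instead of unwinding (and killing) the worker goroutine and, with it, the whole propeller process. Because the deferred closure captures the outer `err` variable, it can overwrite the result after the panic. A distilled sketch of the pattern (using a named return value for brevity, which behaves the same way as the captured variable above):

```go
package main

import (
	"errors"
	"fmt"
	"runtime/debug"
)

// safeEvaluate mirrors the recover pattern in worker.run: the deferred
// function assigns to the named return value, so a panic inside eval
// surfaces as an error rather than crashing the goroutine.
func safeEvaluate(eval func() error) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("panic when executing ArrayNode subNode, Stack: [%s]", debug.Stack())
		}
	}()
	return eval()
}

func main() {
	fmt.Println(safeEvaluate(func() error { return nil }))                // <nil>
	fmt.Println(safeEvaluate(func() error { return errors.New("boom") })) // boom
	err := safeEvaluate(func() error { panic("unexpected") })
	fmt.Println(err != nil) // true: the panic surfaced as an error
}
```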