diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go
index 89122da3ed..419386eddd 100644
--- a/flytepropeller/pkg/controller/config/config.go
+++ b/flytepropeller/pkg/controller/config/config.go
@@ -115,8 +115,11 @@ var (
         },
         ClusterID:                "propeller",
         CreateFlyteWorkflowCRD:   false,
-        ArrayNodeEventVersion:    0,
         NodeExecutionWorkerCount: 8,
+        ArrayNode: ArrayNodeConfig{
+            EventVersion:               0,
+            DefaultParallelismBehavior: ParallelismBehaviorUnlimited,
+        },
     }
 )
 
@@ -156,8 +159,8 @@ type Config struct {
     ExcludeDomainLabel       []string        `json:"exclude-domain-label" pflag:",Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector"`
     ClusterID                string          `json:"cluster-id" pflag:",Unique cluster id running this flytepropeller instance with which to annotate execution events"`
     CreateFlyteWorkflowCRD   bool            `json:"create-flyteworkflow-crd" pflag:",Enable creation of the FlyteWorkflow CRD on startup"`
-    ArrayNodeEventVersion    int             `json:"array-node-event-version" pflag:",ArrayNode eventing version. 0 => legacy (drop-in replacement for maptask), 1 => new"`
     NodeExecutionWorkerCount int             `json:"node-execution-worker-count" pflag:",Number of workers to evaluate node executions, currently only used for array nodes"`
+    ArrayNode                ArrayNodeConfig `json:"array-node-config,omitempty" pflag:",Configuration for array nodes"`
 }
 
 // KubeClientConfig contains the configuration used by flytepropeller to configure its internal Kubernetes Client.
@@ -258,6 +261,31 @@ type EventConfig struct {
     FallbackToOutputReference bool `json:"fallback-to-output-reference" pflag:",Whether output data should be sent by reference when it is too large to be sent inline in execution events."`
 }
 
+// ParallelismBehavior defines how ArrayNode should handle subNode parallelism by default
+type ParallelismBehavior = string
+
+const (
+    // ParallelismBehaviorHybrid means that ArrayNode will adhere to the parallelism defined in the
+    // ArrayNode exactly. This means `nil` will use the workflow parallelism, and 0 will have
+    // unlimited parallelism.
+    ParallelismBehaviorHybrid ParallelismBehavior = "hybrid"
+
+    // ParallelismBehaviorUnlimited means that ArrayNode subNodes will be evaluated with unlimited
+    // parallelism for both nil and 0. If a non-default (ie. nil / zero) parallelism is set, then
+    // ArrayNode will adhere to that value.
+    ParallelismBehaviorUnlimited ParallelismBehavior = "unlimited"
+
+    // ParallelismBehaviorWorkflow means that ArrayNode subNodes will be evaluated using the
+    // configured workflow parallelism for both nil and 0. If a non-default (ie. nil / zero)
+    // parallelism is set, then ArrayNode will adhere to that value.
+    ParallelismBehaviorWorkflow ParallelismBehavior = "workflow"
+)
+
+type ArrayNodeConfig struct {
+    EventVersion               int                 `json:"event-version" pflag:",ArrayNode eventing version. 0 => legacy (drop-in replacement for maptask), 1 => new"`
+    DefaultParallelismBehavior ParallelismBehavior `json:"default-parallelism-behavior" pflag:",Default parallelism behavior for array nodes"`
+}
+
 // GetConfig extracts the Configuration from the global config module in flytestdlib and returns the corresponding type-casted object.
 func GetConfig() *Config {
     return configSection.GetConfig().(*Config)
diff --git a/flytepropeller/pkg/controller/config/config_flags.go b/flytepropeller/pkg/controller/config/config_flags.go
index b055aad558..ea0b428c2f 100755
--- a/flytepropeller/pkg/controller/config/config_flags.go
+++ b/flytepropeller/pkg/controller/config/config_flags.go
@@ -108,7 +108,8 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet {
     cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "exclude-domain-label"), defaultConfig.ExcludeDomainLabel, "Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector")
     cmdFlags.String(fmt.Sprintf("%v%v", prefix, "cluster-id"), defaultConfig.ClusterID, "Unique cluster id running this flytepropeller instance with which to annotate execution events")
     cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "create-flyteworkflow-crd"), defaultConfig.CreateFlyteWorkflowCRD, "Enable creation of the FlyteWorkflow CRD on startup")
-    cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "array-node-event-version"), defaultConfig.ArrayNodeEventVersion, "ArrayNode eventing version. 0 => legacy (drop-in replacement for maptask), 1 => new")
     cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "node-execution-worker-count"), defaultConfig.NodeExecutionWorkerCount, "Number of workers to evaluate node executions, currently only used for array nodes")
+    cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "array-node-config.event-version"), defaultConfig.ArrayNode.EventVersion, "ArrayNode eventing version. 0 => legacy (drop-in replacement for maptask), 1 => new")
+    cmdFlags.String(fmt.Sprintf("%v%v", prefix, "array-node-config.default-parallelism-behavior"), defaultConfig.ArrayNode.DefaultParallelismBehavior, "Default parallelism behavior for array nodes")
     return cmdFlags
 }
diff --git a/flytepropeller/pkg/controller/config/config_flags_test.go b/flytepropeller/pkg/controller/config/config_flags_test.go
index 6f3c67b652..bce7238f60 100755
--- a/flytepropeller/pkg/controller/config/config_flags_test.go
+++ b/flytepropeller/pkg/controller/config/config_flags_test.go
@@ -911,28 +911,42 @@ func TestConfig_SetFlags(t *testing.T) {
             }
         })
     })
-    t.Run("Test_array-node-event-version", func(t *testing.T) {
+    t.Run("Test_node-execution-worker-count", func(t *testing.T) {
 
         t.Run("Override", func(t *testing.T) {
             testValue := "1"
 
-            cmdFlags.Set("array-node-event-version", testValue)
-            if vInt, err := cmdFlags.GetInt("array-node-event-version"); err == nil {
-                testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ArrayNodeEventVersion)
+            cmdFlags.Set("node-execution-worker-count", testValue)
+            if vInt, err := cmdFlags.GetInt("node-execution-worker-count"); err == nil {
+                testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.NodeExecutionWorkerCount)
 
             } else {
                 assert.FailNow(t, err.Error())
             }
         })
     })
-    t.Run("Test_node-execution-worker-count", func(t *testing.T) {
+    t.Run("Test_array-node-config.event-version", func(t *testing.T) {
 
         t.Run("Override", func(t *testing.T) {
            testValue := "1"
 
-            cmdFlags.Set("node-execution-worker-count", testValue)
-            if vInt, err := cmdFlags.GetInt("node-execution-worker-count"); err == nil {
-                testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.NodeExecutionWorkerCount)
+            cmdFlags.Set("array-node-config.event-version", testValue)
+            if vInt, err := cmdFlags.GetInt("array-node-config.event-version"); err == nil {
+                testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ArrayNode.EventVersion)
+
+            } else {
+                assert.FailNow(t, err.Error())
+            }
+        })
+    })
+    t.Run("Test_array-node-config.default-parallelism-behavior", func(t *testing.T) {
+
+        t.Run("Override", func(t *testing.T) {
+            testValue := "1"
+
+            cmdFlags.Set("array-node-config.default-parallelism-behavior", testValue)
+            if vString, err := cmdFlags.GetString("array-node-config.default-parallelism-behavior"); err == nil {
+                testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.ArrayNode.DefaultParallelismBehavior)
 
             } else {
                 assert.FailNow(t, err.Error())
diff --git a/flytepropeller/pkg/controller/nodes/array/event_recorder.go b/flytepropeller/pkg/controller/nodes/array/event_recorder.go
index c2e0c96ed8..35120c069e 100644
--- a/flytepropeller/pkg/controller/nodes/array/event_recorder.go
+++ b/flytepropeller/pkg/controller/nodes/array/event_recorder.go
@@ -202,7 +202,7 @@ func (*passThroughEventRecorder) finalizeRequired(ctx context.Context) bool {
 }
 
 func newArrayEventRecorder(eventRecorder interfaces.EventRecorder) arrayEventRecorder {
-    if config.GetConfig().ArrayNodeEventVersion == 0 {
+    if config.GetConfig().ArrayNode.EventVersion == 0 {
         return &externalResourcesEventRecorder{
             EventRecorder: eventRecorder,
         }
diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go
index cb5787053d..0f9e95f19b 100644
--- a/flytepropeller/pkg/controller/nodes/array/handler.go
+++ b/flytepropeller/pkg/controller/nodes/array/handler.go
@@ -252,26 +252,14 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
         arrayNodeState.Phase = v1alpha1.ArrayNodePhaseExecuting
     case v1alpha1.ArrayNodePhaseExecuting:
         // process array node subNodes
+        remainingWorkflowParallelism := int(nCtx.ExecutionContext().GetExecutionConfig().MaxParallelism - nCtx.ExecutionContext().CurrentParallelism())
+        incrementWorkflowParallelism, maxParallelism := inferParallelism(ctx, arrayNode.GetParallelism(),
+            config.GetConfig().ArrayNode.DefaultParallelismBehavior, remainingWorkflowParallelism, len(arrayNodeState.SubNodePhases.GetItems()))
 
-        availableParallelism := 0
-        // using the workflow's parallelism if the array node parallelism is not set
-        useWorkflowParallelism := arrayNode.GetParallelism() == nil
-        if useWorkflowParallelism {
-            // greedily take all available slots
-            // TODO: This will need to be re-evaluated if we want to support dynamics & sub_workflows
-            currentParallelism := nCtx.ExecutionContext().CurrentParallelism()
-            maxParallelism := nCtx.ExecutionContext().GetExecutionConfig().MaxParallelism
-            availableParallelism = int(maxParallelism - currentParallelism)
-        } else {
-            availableParallelism = int(*arrayNode.GetParallelism())
-            if availableParallelism == 0 {
-                availableParallelism = len(arrayNodeState.SubNodePhases.GetItems())
-            }
-        }
-
-        nodeExecutionRequests := make([]*nodeExecutionRequest, 0, availableParallelism)
+        nodeExecutionRequests := make([]*nodeExecutionRequest, 0, maxParallelism)
+        currentParallelism := 0
         for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
-            if availableParallelism == 0 {
+            if currentParallelism >= maxParallelism {
                 break
             }
 
@@ -315,10 +303,10 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
             // TODO - this is a naive implementation of parallelism, if we want to support more
             // complex subNodes (ie. dynamics / subworkflows) we need to revisit this so that
             // parallelism is handled during subNode evaluations + avoid deadlocks
-            if useWorkflowParallelism {
+            if incrementWorkflowParallelism {
                 nCtx.ExecutionContext().IncrementParallelism()
             }
-            availableParallelism--
+            currentParallelism++
         }
 
         workerErrorCollector := errorcollector.NewErrorMessageCollector()
@@ -418,6 +406,12 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
             // wait until all tasks have completed before declaring success
             arrayNodeState.Phase = v1alpha1.ArrayNodePhaseSucceeding
         }
+
+        // if incrementWorkflowParallelism is not set then we need to increment the parallelism by one
+        // to indicate that the overall ArrayNode is still running
+        if !incrementWorkflowParallelism && arrayNodeState.Phase == v1alpha1.ArrayNodePhaseExecuting {
+            nCtx.ExecutionContext().IncrementParallelism()
+        }
     case v1alpha1.ArrayNodePhaseFailing:
         if err := a.Abort(ctx, nCtx, "ArrayNodeFailing"); err != nil {
             return handler.UnknownTransition, err
diff --git a/flytepropeller/pkg/controller/nodes/array/handler_test.go b/flytepropeller/pkg/controller/nodes/array/handler_test.go
index f7863731b4..ba0815fee6 100644
--- a/flytepropeller/pkg/controller/nodes/array/handler_test.go
+++ b/flytepropeller/pkg/controller/nodes/array/handler_test.go
@@ -461,6 +461,11 @@ func uint32Ptr(v uint32) *uint32 {
 
 func TestHandleArrayNodePhaseExecuting(t *testing.T) {
     ctx := context.Background()
+
+    // setting default parallelism behavior on ArrayNode to "hybrid" to test the largest scope of functionality
+    flyteConfig := config.GetConfig()
+    flyteConfig.ArrayNode.DefaultParallelismBehavior = config.ParallelismBehaviorHybrid
+
     minSuccessRatio := float32(0.5)
 
     // initialize universal variables
@@ -511,6 +516,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) {
             expectedArrayNodePhase:         v1alpha1.ArrayNodePhaseExecuting,
             expectedTransitionPhase:        handler.EPhaseRunning,
             expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING},
+            incrementParallelismCount:      1,
         },
         {
             name: "StartOneSubNodeParallelism",
@@ -529,12 +535,11 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) {
             expectedArrayNodePhase:         v1alpha1.ArrayNodePhaseExecuting,
             expectedTransitionPhase:        handler.EPhaseRunning,
             expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING},
+            incrementParallelismCount:      1,
         },
         {
-            name:                      "UtilizeWfParallelismAllSubNodes",
-            parallelism:               nil,
-            currentWfParallelism:      0,
-            incrementParallelismCount: 2,
+            name:        "UtilizeWfParallelismAllSubNodes",
+            parallelism: nil,
             subNodePhases: []v1alpha1.NodePhase{
                 v1alpha1.NodePhaseQueued,
                 v1alpha1.NodePhaseQueued,
@@ -550,12 +555,12 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) {
             expectedArrayNodePhase:         v1alpha1.ArrayNodePhaseExecuting,
             expectedTransitionPhase:        handler.EPhaseRunning,
             expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING},
+            currentWfParallelism:           0,
+            incrementParallelismCount:      2,
         },
         {
-            name:                      "UtilizeWfParallelismSomeSubNodes",
-            parallelism:               nil,
-            currentWfParallelism:      workflowMaxParallelism - 1,
-            incrementParallelismCount: 1,
+            name:        "UtilizeWfParallelismSomeSubNodes",
+            parallelism: nil,
             subNodePhases: []v1alpha1.NodePhase{
                 v1alpha1.NodePhaseQueued,
                 v1alpha1.NodePhaseQueued,
@@ -570,12 +575,12 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) {
             expectedArrayNodePhase:         v1alpha1.ArrayNodePhaseExecuting,
             expectedTransitionPhase:        handler.EPhaseRunning,
             expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING},
+            currentWfParallelism:           workflowMaxParallelism - 1,
+            incrementParallelismCount:      1,
         },
         {
-            name:                      "UtilizeWfParallelismNoSubNodes",
-            parallelism:               nil,
-            currentWfParallelism:      workflowMaxParallelism,
-            incrementParallelismCount: 0,
+            name:        "UtilizeWfParallelismNoSubNodes",
+            parallelism: nil,
             subNodePhases: []v1alpha1.NodePhase{
                 v1alpha1.NodePhaseQueued,
                 v1alpha1.NodePhaseQueued,
@@ -588,6 +593,8 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) {
             expectedArrayNodePhase:         v1alpha1.ArrayNodePhaseExecuting,
             expectedTransitionPhase:        handler.EPhaseRunning,
             expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{},
+            currentWfParallelism:           workflowMaxParallelism,
+            incrementParallelismCount:      0,
         },
         {
             name: "StartSubNodesNewAttempts",
@@ -607,6 +614,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) {
             expectedArrayNodePhase:         v1alpha1.ArrayNodePhaseExecuting,
             expectedTransitionPhase:        handler.EPhaseRunning,
             expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING},
+            incrementParallelismCount:      1,
         },
         {
             name: "AllSubNodesSuccedeed",
diff --git a/flytepropeller/pkg/controller/nodes/array/utils.go b/flytepropeller/pkg/controller/nodes/array/utils.go
index 342ccca3f6..34e304b662 100644
--- a/flytepropeller/pkg/controller/nodes/array/utils.go
+++ b/flytepropeller/pkg/controller/nodes/array/utils.go
@@ -7,10 +7,12 @@ import (
 
     idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
 
     "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
+    "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config"
     "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces"
     "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task"
     "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/codex"
     "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/k8s"
+    "github.com/flyteorg/flyte/flytestdlib/logger"
     "github.com/flyteorg/flyte/flytestdlib/storage"
 )
@@ -62,6 +64,24 @@ func constructOutputReferences(ctx context.Context, nCtx interfaces.NodeExecutio
     return subDataDir, subOutputDir, nil
 }
 
+func inferParallelism(ctx context.Context, parallelism *uint32, parallelismBehavior string, remainingWorkflowParallelism, arrayNodeSize int) (bool, int) {
+    if parallelism != nil && *parallelism > 0 {
+        // if parallelism is not defaulted - use it
+        return false, int(*parallelism)
+    } else if parallelismBehavior == config.ParallelismBehaviorWorkflow || (parallelism == nil && parallelismBehavior == config.ParallelismBehaviorHybrid) {
+        // if workflow level parallelism
+        return true, remainingWorkflowParallelism
+    } else if parallelismBehavior == config.ParallelismBehaviorUnlimited ||
+        (parallelism != nil && *parallelism == 0 && parallelismBehavior == config.ParallelismBehaviorHybrid) {
+        // if unlimited parallelism
+        return false, arrayNodeSize
+    }
+
+    logger.Warnf(ctx, "unable to infer ArrayNode parallelism configuration for parallelism:%v behavior:%v, defaulting to unlimited parallelism",
+        parallelism, parallelismBehavior)
+    return false, arrayNodeSize
+}
+
 func isTerminalNodePhase(nodePhase v1alpha1.NodePhase) bool {
     return nodePhase == v1alpha1.NodePhaseSucceeded || nodePhase == v1alpha1.NodePhaseFailed ||
         nodePhase == v1alpha1.NodePhaseTimedOut || nodePhase == v1alpha1.NodePhaseSkipped || nodePhase == v1alpha1.NodePhaseRecovered
diff --git a/flytepropeller/pkg/controller/nodes/array/utils_test.go b/flytepropeller/pkg/controller/nodes/array/utils_test.go
index fde3d0fa80..2b2c030cd6 100644
--- a/flytepropeller/pkg/controller/nodes/array/utils_test.go
+++ b/flytepropeller/pkg/controller/nodes/array/utils_test.go
@@ -1,6 +1,7 @@
 package array
 
 import (
+    "context"
     "testing"
 
     "github.com/stretchr/testify/assert"
@@ -34,3 +35,109 @@ func TestAppendLiteral(t *testing.T) {
         assert.Equal(t, 2, len(collection.Collection.Literals))
     }
 }
+
+func TestInferParallelism(t *testing.T) {
+    ctx := context.TODO()
+    zero := uint32(0)
+    one := uint32(1)
+
+    tests := []struct {
+        name                   string
+        parallelism            *uint32
+        parallelismBehavior    string
+        remainingParallelism   int
+        arrayNodeSize          int
+        expectedIncrement      bool
+        expectedMaxParallelism int
+    }{
+        {
+            name:                   "NilParallelismWorkflowBehavior",
+            parallelism:            nil,
+            parallelismBehavior:    "workflow",
+            remainingParallelism:   2,
+            arrayNodeSize:          3,
+            expectedIncrement:      true,
+            expectedMaxParallelism: 2,
+        },
+        {
+            name:                   "NilParallelismHybridBehavior",
+            parallelism:            nil,
+            parallelismBehavior:    "hybrid",
+            remainingParallelism:   2,
+            arrayNodeSize:          3,
+            expectedIncrement:      true,
+            expectedMaxParallelism: 2,
+        },
+        {
+            name:                   "NilParallelismUnlimitedBehavior",
+            parallelism:            nil,
+            parallelismBehavior:    "unlimited",
+            remainingParallelism:   2,
+            arrayNodeSize:          3,
+            expectedIncrement:      false,
+            expectedMaxParallelism: 3,
+        },
+        {
+            name:                   "ZeroParallelismWorkflowBehavior",
+            parallelism:            &zero,
+            parallelismBehavior:    "workflow",
+            remainingParallelism:   2,
+            arrayNodeSize:          3,
+            expectedIncrement:      true,
+            expectedMaxParallelism: 2,
+        },
+        {
+            name:                   "ZeroParallelismHybridBehavior",
+            parallelism:            &zero,
+            parallelismBehavior:    "hybrid",
+            remainingParallelism:   2,
+            arrayNodeSize:          3,
+            expectedIncrement:      false,
+            expectedMaxParallelism: 3,
+        },
+        {
+            name:                   "ZeroParallelismUnlimitedBehavior",
+            parallelism:            &zero,
+            parallelismBehavior:    "unlimited",
+            remainingParallelism:   2,
+            arrayNodeSize:          3,
+            expectedIncrement:      false,
+            expectedMaxParallelism: 3,
+        },
+        {
+            name:                   "OneParallelismWorkflowBehavior",
+            parallelism:            &one,
+            parallelismBehavior:    "workflow",
+            remainingParallelism:   2,
+            arrayNodeSize:          3,
+            expectedIncrement:      false,
+            expectedMaxParallelism: 1,
+        },
+        {
+            name:                   "OneParallelismHybridBehavior",
+            parallelism:            &one,
+            parallelismBehavior:    "hybrid",
+            remainingParallelism:   2,
+            arrayNodeSize:          3,
+            expectedIncrement:      false,
+            expectedMaxParallelism: 1,
+        },
+        {
+            name:                   "OneParallelismUnlimitedBehavior",
+            parallelism:            &one,
+            parallelismBehavior:    "unlimited",
+            remainingParallelism:   2,
+            arrayNodeSize:          3,
+            expectedIncrement:      false,
+            expectedMaxParallelism: 1,
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            increment, maxParallelism := inferParallelism(ctx, tt.parallelism, tt.parallelismBehavior, tt.remainingParallelism, tt.arrayNodeSize)
+            assert.Equal(t, tt.expectedIncrement, increment)
+            assert.Equal(t, tt.expectedMaxParallelism, maxParallelism)
+        })
+    }
+}
diff --git a/flytepropeller/pkg/controller/nodes/transformers.go b/flytepropeller/pkg/controller/nodes/transformers.go
index b034c5b90f..1c911b44f3 100644
--- a/flytepropeller/pkg/controller/nodes/transformers.go
+++ b/flytepropeller/pkg/controller/nodes/transformers.go
@@ -179,7 +179,7 @@ func ToNodeExecutionEvent(nodeExecID *core.NodeExecutionIdentifier,
         nev.IsParent = true
     } else if node.GetKind() == v1alpha1.NodeKindArray {
         nev.IsArray = true
-        if config.GetConfig().ArrayNodeEventVersion == 1 {
+        if config.GetConfig().ArrayNode.EventVersion == 1 {
            nev.IsParent = true
         }
     } else if dynamicNodePhase != v1alpha1.DynamicNodePhaseNone {
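
Note (not part of the patch): the sketch below is a standalone, minimal illustration of the parallelism-inference table that the new `inferParallelism` helper in `utils.go` implements and `TestInferParallelism` verifies. It re-creates the branching with local stand-in constants (`behaviorHybrid`, `behaviorUnlimited`, `behaviorWorkflow`) and a hypothetical `ptrStr` helper so it can run outside the `array` package; it is not the package's API.

```go
package main

import "fmt"

// Local stand-ins for config.ParallelismBehavior* in the patch.
const (
	behaviorHybrid    = "hybrid"
	behaviorUnlimited = "unlimited"
	behaviorWorkflow  = "workflow"
)

// inferParallelism mirrors the patch's decision table and returns
// (incrementWorkflowParallelism, maxParallelism).
func inferParallelism(parallelism *uint32, behavior string, remainingWorkflowParallelism, arrayNodeSize int) (bool, int) {
	switch {
	case parallelism != nil && *parallelism > 0:
		// an explicit parallelism > 0 always wins
		return false, int(*parallelism)
	case behavior == behaviorWorkflow || (parallelism == nil && behavior == behaviorHybrid):
		// borrow from the workflow-level parallelism budget
		return true, remainingWorkflowParallelism
	default:
		// "unlimited", hybrid with an explicit 0, or anything unrecognized
		return false, arrayNodeSize
	}
}

func ptrStr(p *uint32) string {
	if p == nil {
		return "nil"
	}
	return fmt.Sprintf("%d", *p)
}

func main() {
	zero, one := uint32(0), uint32(1)
	cases := []struct {
		parallelism *uint32
		behavior    string
	}{
		{nil, behaviorWorkflow},
		{nil, behaviorHybrid},
		{nil, behaviorUnlimited},
		{&zero, behaviorHybrid},
		{&one, behaviorWorkflow},
	}
	for _, c := range cases {
		// remainingWorkflowParallelism=2, arrayNodeSize=3, matching the values used in TestInferParallelism
		inc, maxP := inferParallelism(c.parallelism, c.behavior, 2, 3)
		fmt.Printf("parallelism=%s behavior=%-9s => incrementWorkflowParallelism=%v maxParallelism=%d\n",
			ptrStr(c.parallelism), c.behavior, inc, maxP)
	}
}
```

The property this demonstrates is the one the patch relies on in `handler.go`: an explicit per-node parallelism greater than zero is always honored, and the configured `default-parallelism-behavior` only decides between workflow-scoped and unlimited evaluation when the ArrayNode's parallelism is nil or 0.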