From 2b76a91611267b6eacf4c88c4274139e61d8a506 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Thu, 21 Dec 2023 19:14:36 +0000 Subject: [PATCH] Avoid making nodeExecContext public Signed-off-by: Thomas Newton --- .../pkg/controller/nodes/node_exec_context.go | 40 +++++++++---------- .../pkg/controller/workflow/executor_test.go | 6 ++- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/node_exec_context.go b/flytepropeller/pkg/controller/nodes/node_exec_context.go index e29972b0297..f42f8b0324c 100644 --- a/flytepropeller/pkg/controller/nodes/node_exec_context.go +++ b/flytepropeller/pkg/controller/nodes/node_exec_context.go @@ -118,7 +118,7 @@ func (e nodeExecMetadata) GetLabels() map[string]string { return e.nodeLabels } -type NodeExecContext struct { +type nodeExecContext struct { store *storage.DataStore tr interfaces.TaskReader md interfaces.NodeExecutionMetadata @@ -135,78 +135,78 @@ type NodeExecContext struct { ic executors.ExecutionContext } -func (e NodeExecContext) ExecutionContext() executors.ExecutionContext { +func (e nodeExecContext) ExecutionContext() executors.ExecutionContext { return e.ic } -func (e NodeExecContext) ContextualNodeLookup() executors.NodeLookup { +func (e nodeExecContext) ContextualNodeLookup() executors.NodeLookup { return e.nl } -func (e NodeExecContext) OutputShardSelector() ioutils.ShardSelector { +func (e nodeExecContext) OutputShardSelector() ioutils.ShardSelector { return e.shardSelector } -func (e NodeExecContext) RawOutputPrefix() storage.DataReference { +func (e nodeExecContext) RawOutputPrefix() storage.DataReference { return e.rawOutputPrefix } -func (e NodeExecContext) EnqueueOwnerFunc() func() error { +func (e nodeExecContext) EnqueueOwnerFunc() func() error { return e.enqueueOwner } -func (e NodeExecContext) TaskReader() interfaces.TaskReader { +func (e nodeExecContext) TaskReader() interfaces.TaskReader { return e.tr } -func (e NodeExecContext) NodeStateReader() interfaces.NodeStateReader { +func (e nodeExecContext) NodeStateReader() interfaces.NodeStateReader { return e.nsm } -func (e NodeExecContext) NodeStateWriter() interfaces.NodeStateWriter { +func (e nodeExecContext) NodeStateWriter() interfaces.NodeStateWriter { return e.nsm } -func (e NodeExecContext) DataStore() *storage.DataStore { +func (e nodeExecContext) DataStore() *storage.DataStore { return e.store } -func (e NodeExecContext) InputReader() io.InputReader { +func (e nodeExecContext) InputReader() io.InputReader { return e.inputs } -func (e NodeExecContext) EventsRecorder() interfaces.EventRecorder { +func (e nodeExecContext) EventsRecorder() interfaces.EventRecorder { return e.eventRecorder } -func (e NodeExecContext) NodeID() v1alpha1.NodeID { +func (e nodeExecContext) NodeID() v1alpha1.NodeID { return e.node.GetID() } -func (e NodeExecContext) Node() v1alpha1.ExecutableNode { +func (e nodeExecContext) Node() v1alpha1.ExecutableNode { return e.node } -func (e NodeExecContext) CurrentAttempt() uint32 { +func (e nodeExecContext) CurrentAttempt() uint32 { return e.nodeStatus.GetAttempts() } -func (e NodeExecContext) NodeStatus() v1alpha1.ExecutableNodeStatus { +func (e nodeExecContext) NodeStatus() v1alpha1.ExecutableNodeStatus { return e.nodeStatus } -func (e NodeExecContext) NodeExecutionMetadata() interfaces.NodeExecutionMetadata { +func (e nodeExecContext) NodeExecutionMetadata() interfaces.NodeExecutionMetadata { return e.md } -func (e NodeExecContext) MaxDatasetSizeBytes() int64 { +func (e nodeExecContext) MaxDatasetSizeBytes() int64 { return e.maxDatasetSizeBytes } func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext executors.ExecutionContext, nl executors.NodeLookup, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, inputs io.InputReader, interruptible bool, interruptibleFailureThreshold int32, maxDatasetSize int64, taskEventRecorder events.TaskEventRecorder, nodeEventRecorder events.NodeEventRecorder, tr interfaces.TaskReader, nsm *nodeStateManager, - enqueueOwner func() error, rawOutputPrefix storage.DataReference, outputShardSelector ioutils.ShardSelector) *NodeExecContext { + enqueueOwner func() error, rawOutputPrefix storage.DataReference, outputShardSelector ioutils.ShardSelector) *nodeExecContext { md := nodeExecMetadata{ Meta: execContext, @@ -230,7 +230,7 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext nodeLabels[NodeInterruptibleLabel] = strconv.FormatBool(interruptible) md.nodeLabels = nodeLabels - return &NodeExecContext{ + return &nodeExecContext{ md: md, store: store, node: node, diff --git a/flytepropeller/pkg/controller/workflow/executor_test.go b/flytepropeller/pkg/controller/workflow/executor_test.go index 0b46c2edfd7..16285899856 100644 --- a/flytepropeller/pkg/controller/workflow/executor_test.go +++ b/flytepropeller/pkg/controller/workflow/executor_test.go @@ -72,6 +72,10 @@ type fakeRemoteWritePlugin struct { t assert.TestingT } +type fakeNodeExecContext interface { + Node() v1alpha1.ExecutableNode +} + func (f fakeRemoteWritePlugin) Handle(ctx context.Context, tCtx pluginCore.TaskExecutionContext) (pluginCore.Transition, error) { logger.Infof(ctx, "----------------------------------------------------------------------------------------------") logger.Infof(ctx, "Handle called for %s", tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) @@ -517,7 +521,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { h := &nodemocks.NodeHandler{} h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) - startNodeMatcher := mock.MatchedBy(func(nodeExecContext *nodes.NodeExecContext) bool { + startNodeMatcher := mock.MatchedBy(func(nodeExecContext fakeNodeExecContext) bool { return nodeExecContext.Node().IsStartNode() }) h.OnHandleMatch(mock.Anything, startNodeMatcher).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil)