diff --git a/workflow/controller/retry_tweak.go b/workflow/controller/retry_tweak.go index e8e10a5f87fd..a2d1d3baceed 100644 --- a/workflow/controller/retry_tweak.go +++ b/workflow/controller/retry_tweak.go @@ -15,12 +15,23 @@ type RetryTweak = func(retryStrategy wfv1.RetryStrategy, nodes wfv1.Nodes, pod * func FindRetryNode(nodes wfv1.Nodes, nodeID string) *wfv1.NodeStatus { boundaryID := nodes[nodeID].BoundaryID boundaryNode := nodes[boundaryID] - templateName := boundaryNode.TemplateName - for _, node := range nodes { - if node.Type == wfv1.NodeTypeRetry && node.TemplateName == templateName { - return &node + if boundaryNode.TemplateName != "" { + templateName := boundaryNode.TemplateName + for _, node := range nodes { + if node.Type == wfv1.NodeTypeRetry && node.TemplateName == templateName { + return &node + } } } + if boundaryNode.TemplateRef != nil { + templateRef := boundaryNode.TemplateRef + for _, node := range nodes { + if node.Type == wfv1.NodeTypeRetry && node.TemplateRef != nil && node.TemplateRef.Name == templateRef.Name && node.TemplateRef.Template == templateRef.Template { + return &node + } + } + } + return nil } diff --git a/workflow/controller/retry_tweak_test.go b/workflow/controller/retry_tweak_test.go index 77403c146bbf..66e5d9d106c7 100644 --- a/workflow/controller/retry_tweak_test.go +++ b/workflow/controller/retry_tweak_test.go @@ -59,6 +59,25 @@ func TestFindRetryNode(t *testing.T) { Children: []string{}, TemplateName: "tmpl2", }, + "E1": wfv1.NodeStatus{ + ID: "E1", + Type: wfv1.NodeTypeRetry, + Phase: wfv1.NodeRunning, + BoundaryID: "A1", + Children: []string{}, + TemplateRef: &wfv1.TemplateRef{ + Name: "tmpl1", + Template: "tmpl3", + }, + }, + "E2": wfv1.NodeStatus{ + ID: "E2", + Type: wfv1.NodeTypePod, + Phase: wfv1.NodeRunning, + BoundaryID: "E1", + Children: []string{}, + TemplateName: "tmpl2", + }, } t.Run("Expect to find retry node", func(t *testing.T) { node := allNodes["B2"] @@ -68,4 +87,8 @@ func TestFindRetryNode(t *testing.T) { a := FindRetryNode(allNodes, "A1") assert.Nil(t, a) }) + t.Run("Expect to find retry node has TemplateRef", func(t *testing.T) { + node := allNodes["E1"] + assert.Equal(t, FindRetryNode(allNodes, "E2"), &node) + }) }