Skip to content

Commit

Permalink
fix: address all feedback
Browse files Browse the repository at this point in the history
Signed-off-by: isubasinghe <[email protected]>
  • Loading branch information
isubasinghe committed Oct 28, 2024
1 parent 9ed648e commit 074de52
Showing 1 changed file with 47 additions and 47 deletions.
94 changes: 47 additions & 47 deletions workflow/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -932,17 +932,20 @@ func getChildren(n *dagNode) map[string]bool {
return children
}

type resetFn func(string, bool)
type deleteFn func(string, bool)
type tillFn func(*dagNode) (bool, bool)
type resetFn func(string)
type deleteFn func(string)

func getTillFnNodeType(nodeType wfv1.NodeType) tillFn {
// untilFn is a function that returns two variables, the first indicates
// a `found` boolean while the second indicates if reset should be called.
type untilFn func(*dagNode) (bool, bool)

func getUntilFnNodeType(nodeType wfv1.NodeType) untilFn {
return func(n *dagNode) (bool, bool) {
return n.n.Type == nodeType, true
}
}

func consumeTill(n *dagNode, should tillFn, resetFunc resetFn) (*dagNode, error) {
func resetUntil(n *dagNode, should untilFn, resetFunc resetFn) (*dagNode, error) {
curr := n
for {
if curr == nil {
Expand All @@ -951,85 +954,86 @@ func consumeTill(n *dagNode, should tillFn, resetFunc resetFn) (*dagNode, error)

if foundNode, shouldReset := should(curr); foundNode {
if shouldReset {
resetFunc(curr.n.ID, true)
resetFunc(curr.n.ID)
}
return curr, nil
}
curr = curr.parent
}
}

func getTillBoundaryFn(boundaryID string) tillFn {
func getTillBoundaryFn(boundaryID string) untilFn {
return func(n *dagNode) (bool, bool) {
return n.n.ID == boundaryID, n.n.BoundaryID != ""
}
}

func consumeBoundaries(n *dagNode, resetFunc resetFn) (*dagNode, error) {
func resetBoundaries(n *dagNode, resetFunc resetFn) (*dagNode, error) {
curr := n
for {
if curr == nil {
return curr, nil
}
if curr.parent != nil && curr.parent.n.Type == wfv1.NodeTypeStepGroup {
resetFunc(curr.parent.n.ID, true)
resetFunc(curr.parent.n.ID)
}
seekingBoundaryID := curr.n.BoundaryID
if seekingBoundaryID == "" {
return curr.parent, nil
}
var err error
curr, err = consumeTill(curr, getTillBoundaryFn(seekingBoundaryID), resetFunc)
curr, err = resetUntil(curr, getTillBoundaryFn(seekingBoundaryID), resetFunc)
if err != nil {
return nil, err
}
}
}

func consumeStepGroup(n *dagNode, resetFunc resetFn) (*dagNode, error) {
return consumeTill(n, getTillFnNodeType(wfv1.NodeTypeStepGroup), resetFunc)
func resetStepGroup(n *dagNode, resetFunc resetFn) (*dagNode, error) {
return resetUntil(n, getUntilFnNodeType(wfv1.NodeTypeStepGroup), resetFunc)
}

func consumeSteps(n *dagNode, resetFunc resetFn) (*dagNode, error) {
n, err := consumeTill(n, getTillFnNodeType(wfv1.NodeTypeSteps), resetFunc)
func resetSteps(n *dagNode, resetFunc resetFn) (*dagNode, error) {
n, err := resetUntil(n, getUntilFnNodeType(wfv1.NodeTypeSteps), resetFunc)
if err != nil {
return nil, err
}
return consumeBoundaries(n, resetFunc)
return resetBoundaries(n, resetFunc)
}

func consumeTaskGroup(n *dagNode, resetFunc resetFn) (*dagNode, error) {
return consumeTill(n, getTillFnNodeType(wfv1.NodeTypeTaskGroup), resetFunc)
func resetTaskGroup(n *dagNode, resetFunc resetFn) (*dagNode, error) {
return resetUntil(n, getUntilFnNodeType(wfv1.NodeTypeTaskGroup), resetFunc)
}

func consumeDAG(n *dagNode, resetFunc resetFn) (*dagNode, error) {
n, err := consumeTill(n, getTillFnNodeType(wfv1.NodeTypeDAG), resetFunc)
func resetDAG(n *dagNode, resetFunc resetFn) (*dagNode, error) {
n, err := resetUntil(n, getUntilFnNodeType(wfv1.NodeTypeDAG), resetFunc)
if err != nil {
return nil, err
}
return consumeBoundaries(n, resetFunc)
return resetBoundaries(n, resetFunc)
}

func consumePod(n *dagNode, resetFunc resetFn, addToDelete deleteFn) (*dagNode, error) {
// resetPod is only called in the event a Container was found. This implies that there is a parent pod.
func resetPod(n *dagNode, resetFunc resetFn, addToDelete deleteFn) (*dagNode, error) {
// this sets to reset but resets are overridden by deletes in the final FormulateRetryWorkflow logic.
curr, err := consumeTill(n, getTillFnNodeType(wfv1.NodeTypePod), resetFunc)
curr, err := resetUntil(n, getUntilFnNodeType(wfv1.NodeTypePod), resetFunc)
if err != nil {
return nil, err
}
addToDelete(curr.n.ID, true)
addToDelete(curr.n.ID)
children := getChildren(curr)
for childID := range children {
addToDelete(childID, true)
addToDelete(childID)
}
return curr, nil
}

func resetPath(allNodes []*dagNode, toNode string) (map[string]bool, map[string]bool, error) {
nodes, err := singularPath(allNodes, toNode)
func resetPath(allNodes []*dagNode, startNode string) (map[string]bool, map[string]bool, error) {
nodes, err := singularPath(allNodes, startNode)

curr := nodes[len(nodes)-1]
if len(nodes) > 0 {
// remove toNode
// remove startNode
nodes = nodes[:len(nodes)-1]
}

Expand All @@ -1046,17 +1050,13 @@ func resetPath(allNodes []*dagNode, toNode string) (map[string]bool, map[string]
return nodesToReset, nodesToDelete, nil
}

addToReset := func(nodeID string, addToNode bool) {
if nodeID == toNode && !addToNode {
return
}
// safe to reset the startNode since deletions
// override resets.
addToReset := func(nodeID string) {
nodesToReset[nodeID] = true
}

addToDelete := func(nodeID string, addToNode bool) {
if nodeID == toNode && !addToNode {
return
}
addToDelete := func(nodeID string) {
nodesToDelete[nodeID] = true
}

Expand All @@ -1081,27 +1081,27 @@ func resetPath(allNodes []*dagNode, toNode string) (map[string]bool, map[string]
case wfv1.NodeTypeContainer:
//ignore
case wfv1.NodeTypeSteps:
addToReset(curr.n.ID, false)
addToReset(curr.n.ID)
findBoundaries = true
case wfv1.NodeTypeStepGroup:
addToReset(curr.n.ID, false)
addToReset(curr.n.ID)
findBoundaries = true
case wfv1.NodeTypeDAG:
addToReset(curr.n.ID, false)
addToReset(curr.n.ID)
findBoundaries = true
case wfv1.NodeTypeTaskGroup:
addToReset(curr.n.ID, false)
addToReset(curr.n.ID)
findBoundaries = true
case wfv1.NodeTypeRetry:
addToReset(curr.n.ID, false)
addToReset(curr.n.ID)
case wfv1.NodeTypeSkipped:
// ignore -> doesn't make sense to reach this
case wfv1.NodeTypeSuspend:
// ignore
case wfv1.NodeTypeHTTP:
// ignore
case wfv1.NodeTypePlugin:
addToReset(curr.n.ID, false)
addToReset(curr.n.ID)
}

if mustFind == "" && !findBoundaries {
Expand All @@ -1110,7 +1110,7 @@ func resetPath(allNodes []*dagNode, toNode string) (map[string]bool, map[string]
}

if findBoundaries {
curr, err = consumeBoundaries(curr, addToReset)
curr, err = resetBoundaries(curr, addToReset)
if err != nil {
return nil, nil, err
}
Expand All @@ -1120,15 +1120,15 @@ func resetPath(allNodes []*dagNode, toNode string) (map[string]bool, map[string]

switch mustFind {
case wfv1.NodeTypePod:
curr, err = consumePod(curr, addToReset, addToDelete)
curr, err = resetPod(curr, addToReset, addToDelete)
case wfv1.NodeTypeSteps:
curr, err = consumeSteps(curr, addToReset)
curr, err = resetSteps(curr, addToReset)
case wfv1.NodeTypeStepGroup:
curr, err = consumeStepGroup(curr, addToReset)
curr, err = resetStepGroup(curr, addToReset)
case wfv1.NodeTypeDAG:
curr, err = consumeDAG(curr, addToReset)
curr, err = resetDAG(curr, addToReset)
case wfv1.NodeTypeTaskGroup:
curr, err = consumeTaskGroup(curr, addToReset)
curr, err = resetTaskGroup(curr, addToReset)
default:
return nil, nil, fmt.Errorf("invalid mustFind of %s supplied", mustFind)
}
Expand Down

0 comments on commit 074de52

Please sign in to comment.