From 86b7bf13a54f6f3fb26da44b79cac41d16fc2773 Mon Sep 17 00:00:00 2001 From: Samsondeen Dare Date: Thu, 23 Jan 2025 09:46:02 +0100 Subject: [PATCH] Use context instead of channel to signal cancellation from the caller to child nodes in the graph --- internal/dag/dag.go | 9 ++++-- internal/dag/dag_test.go | 5 ++-- internal/dag/walk.go | 44 ++++++++++++++++++++---------- internal/dag/walk_test.go | 27 +++++++++--------- internal/terraform/eval_context.go | 2 +- internal/terraform/graph.go | 20 ++++++++++---- 6 files changed, 68 insertions(+), 39 deletions(-) diff --git a/internal/dag/dag.go b/internal/dag/dag.go index dfb9ef95708e..3e71c5f00683 100644 --- a/internal/dag/dag.go +++ b/internal/dag/dag.go @@ -4,6 +4,7 @@ package dag import ( + "context" "errors" "fmt" "sort" @@ -272,9 +273,11 @@ func (g *AcyclicGraph) Cycles() [][]Vertex { // Walk walks the graph, calling your callback as each node is visited. // This will walk nodes in parallel if it can. The resulting diagnostics // contains problems from all graphs visited, in no particular order. -func (g *AcyclicGraph) Walk(cb WalkFunc) tfdiags.Diagnostics { - w := &Walker{Callback: cb, Reverse: true} - w.Update(g) +// The context here will be inherited by all nodes in the graph, and +// the caller can use this to signal cancellation of the graph walk. +func (g *AcyclicGraph) Walk(ctx context.Context, cb WalkFunc) tfdiags.Diagnostics { + w := NewWalker(cb) + w.Update(ctx, g) return w.Wait() } diff --git a/internal/dag/dag_test.go b/internal/dag/dag_test.go index 43c8c5d2c430..972d22c1a393 100644 --- a/internal/dag/dag_test.go +++ b/internal/dag/dag_test.go @@ -4,6 +4,7 @@ package dag import ( + "context" "flag" "fmt" "os" @@ -396,7 +397,7 @@ func TestAcyclicGraphWalk(t *testing.T) { var visits []Vertex var lock sync.Mutex - err := g.Walk(func(v Vertex) tfdiags.Diagnostics { + err := g.Walk(context.Background(), func(v Vertex) tfdiags.Diagnostics { lock.Lock() defer lock.Unlock() visits = append(visits, v) @@ -431,7 +432,7 @@ func TestAcyclicGraphWalk_error(t *testing.T) { var visits []Vertex var lock sync.Mutex - err := g.Walk(func(v Vertex) tfdiags.Diagnostics { + err := g.Walk(context.Background(), func(v Vertex) tfdiags.Diagnostics { lock.Lock() defer lock.Unlock() diff --git a/internal/dag/walk.go b/internal/dag/walk.go index 07811dad97a1..271aa1a3e11e 100644 --- a/internal/dag/walk.go +++ b/internal/dag/walk.go @@ -4,6 +4,7 @@ package dag import ( + "context" "errors" "log" "sync" @@ -76,6 +77,14 @@ func (w *Walker) init() { } } +// NewWalker creates a new walker with the given callback function. +func NewWalker(cb WalkFunc) *Walker { + // Reverse is true by default, so that the default behavior is for + // the source of an edge to depend on the target. + w := &Walker{Callback: cb, Reverse: true} + return w +} + type walkerVertex struct { // These should only be set once on initialization and never written again. // They are not protected by a lock since they don't need to be since @@ -83,12 +92,17 @@ type walkerVertex struct { // DoneCh is closed when this vertex has completed execution, regardless // of success. - // - // CancelCh is closed when the vertex should cancel execution. If execution - // is already complete (DoneCh is closed), this has no effect. Otherwise, - // execution is cancelled as quickly as possible. - DoneCh chan struct{} - CancelCh chan struct{} + DoneCh chan struct{} + + // CancelCtx is the context used to signal cancellation of the vertex's execution. + // It is created during initialization and should not be modified thereafter. + // If execution is already complete (DoneCh is closed), this has no effect. + // Otherwise, execution is cancelled as quickly as possible. + CancelCtx context.Context + // cancelCtxFn is the function used to cancel the CancelCtx. It is called to + // signal that the vertex's execution should be cancelled. This function should + // only be called once and should not be modified after initialization. + cancelCtxFn context.CancelFunc // Dependency information. Any changes to any of these fields requires // holding DepsLock. @@ -145,7 +159,7 @@ func (w *Walker) Wait() tfdiags.Diagnostics { // // Multiple Updates can be called in parallel. Update can be called at any // time during a walk. -func (w *Walker) Update(g *AcyclicGraph) { +func (w *Walker) Update(ctx context.Context, g *AcyclicGraph) { w.init() v := make(Set) e := make(Set) @@ -181,10 +195,12 @@ func (w *Walker) Update(g *AcyclicGraph) { w.vertices.Add(raw) // Initialize the vertex info + ctx, cancelVertex := context.WithCancel(ctx) info := &walkerVertex{ - DoneCh: make(chan struct{}), - CancelCh: make(chan struct{}), - deps: make(map[Vertex]chan struct{}), + DoneCh: make(chan struct{}), + CancelCtx: ctx, + cancelCtxFn: cancelVertex, + deps: make(map[Vertex]chan struct{}), } // Add it to the map and kick off the walk @@ -204,7 +220,7 @@ func (w *Walker) Update(g *AcyclicGraph) { } // Cancel the vertex - close(info.CancelCh) + info.cancelCtxFn() // Delete it out of the map delete(w.vertexMap, v) @@ -336,8 +352,8 @@ func (w *Walker) walkVertex(v Vertex, info *walkerVertex) { close(depsCh) for { select { - case <-info.CancelCh: - // Cancel + case <-info.CancelCtx.Done(): + // Context cancelled. return immediately. return case depsSuccess = <-depsCh: @@ -371,7 +387,7 @@ func (w *Walker) walkVertex(v Vertex, info *walkerVertex) { // If we passed dependencies, we just want to check once more that // we're not cancelled, since this can happen just as dependencies pass. select { - case <-info.CancelCh: + case <-info.CancelCtx.Done(): // Cancelled during an update while dependencies completed. return default: diff --git a/internal/dag/walk_test.go b/internal/dag/walk_test.go index ddb3b8055b1e..5721261b2af6 100644 --- a/internal/dag/walk_test.go +++ b/internal/dag/walk_test.go @@ -4,6 +4,7 @@ package dag import ( + "context" "fmt" "reflect" "sync" @@ -23,7 +24,7 @@ func TestWalker_basic(t *testing.T) { for i := 0; i < 50; i++ { var order []interface{} w := &Walker{Callback: walkCbRecord(&order)} - w.Update(&g) + w.Update(context.Background(), &g) // Wait if err := w.Wait(); err != nil { @@ -48,8 +49,8 @@ func TestWalker_updateNilGraph(t *testing.T) { for i := 0; i < 50; i++ { var order []interface{} w := &Walker{Callback: walkCbRecord(&order)} - w.Update(&g) - w.Update(nil) + w.Update(context.Background(), &g) + w.Update(context.Background(), nil) // Wait if err := w.Wait(); err != nil { @@ -84,7 +85,7 @@ func TestWalker_error(t *testing.T) { } w := &Walker{Callback: cb} - w.Update(&g) + w.Update(context.Background(), &g) // Wait if err := w.Wait(); err == nil { @@ -120,18 +121,18 @@ func TestWalker_newVertex(t *testing.T) { // Add the initial vertices w = &Walker{Callback: cb} - w.Update(&g) + w.Update(context.Background(), &g) // if 2 has been visited, the walk is complete so far <-done2 // Update the graph g.Add(3) - w.Update(&g) + w.Update(context.Background(), &g) // Update the graph again but with the same vertex g.Add(3) - w.Update(&g) + w.Update(context.Background(), &g) // Wait if err := w.Wait(); err != nil { @@ -159,7 +160,7 @@ func TestWalker_removeVertex(t *testing.T) { cb := func(v Vertex) tfdiags.Diagnostics { if v == 1 { g.Remove(2) - w.Update(&g) + w.Update(context.Background(), &g) } return recordF(v) @@ -167,7 +168,7 @@ func TestWalker_removeVertex(t *testing.T) { // Add the initial vertices w = &Walker{Callback: cb} - w.Update(&g) + w.Update(context.Background(), &g) // Wait if err := w.Wait(); err != nil { @@ -200,14 +201,14 @@ func TestWalker_newEdge(t *testing.T) { if v == 1 { g.Add(3) g.Connect(BasicEdge(3, 2)) - w.Update(&g) + w.Update(context.Background(), &g) } return diags } // Add the initial vertices w = &Walker{Callback: cb} - w.Update(&g) + w.Update(context.Background(), &g) // Wait if err := w.Wait(); err != nil { @@ -248,7 +249,7 @@ func TestWalker_removeEdge(t *testing.T) { switch v { case 1: g.RemoveEdge(BasicEdge(3, 2)) - w.Update(&g) + w.Update(context.Background(), &g) t.Logf("removed edge from 3 to 2") case 2: @@ -275,7 +276,7 @@ func TestWalker_removeEdge(t *testing.T) { // Add the initial vertices w = &Walker{Callback: cb} - w.Update(&g) + w.Update(context.Background(), &g) // Wait if diags := w.Wait(); diags.HasErrors() { diff --git a/internal/terraform/eval_context.go b/internal/terraform/eval_context.go index bde1626ca3db..d9cefd246c26 100644 --- a/internal/terraform/eval_context.go +++ b/internal/terraform/eval_context.go @@ -32,7 +32,7 @@ type hookFunc func(func(Hook) (HookAction, error)) error // EvalContext is the interface that is given to eval nodes to execute. type EvalContext interface { - // Stopped returns a context that is canceled when evaluation is stopped via + // StopCtx returns a context that is canceled when evaluation is stopped via // Terraform.Context.Stop() StopCtx() context.Context diff --git a/internal/terraform/graph.go b/internal/terraform/graph.go index d641fb57b4d8..ebe5ee710dd5 100644 --- a/internal/terraform/graph.go +++ b/internal/terraform/graph.go @@ -4,6 +4,7 @@ package terraform import ( + "context" "fmt" "log" "strings" @@ -40,7 +41,7 @@ func (g *Graph) Walk(walker GraphWalker) tfdiags.Diagnostics { func (g *Graph) walk(walker GraphWalker) tfdiags.Diagnostics { // The callbacks for enter/exiting a graph - ctx := walker.EvalContext() + evalCtx := walker.EvalContext() // Walk the graph. walkFn := func(v dag.Vertex) (diags tfdiags.Diagnostics) { @@ -72,7 +73,7 @@ func (g *Graph) walk(walker GraphWalker) tfdiags.Diagnostics { } }() - haveOverrides := !ctx.Overrides().Empty() + haveOverrides := !evalCtx.Overrides().Empty() // If the graph node is overridable, we'll check our overrides to see // if we need to apply any overrides to the node. @@ -85,7 +86,7 @@ func (g *Graph) walk(walker GraphWalker) tfdiags.Diagnostics { // // See the output node for an example of providing the overrides // directly to the node. - if override, ok := ctx.Overrides().GetResourceOverride(overridable.ResourceInstanceAddr(), overridable.ConfigProvider()); ok { + if override, ok := evalCtx.Overrides().GetResourceOverride(overridable.ResourceInstanceAddr(), overridable.ConfigProvider()); ok { overridable.SetOverride(override) } } @@ -99,7 +100,7 @@ func (g *Graph) walk(walker GraphWalker) tfdiags.Diagnostics { // UnkeyedInstanceShim is used by legacy provider configs within a // module to return an instance of that module, since they can never // exist within an expanded instance. - if ctx.Overrides().IsOverridden(addr.Module.UnkeyedInstanceShim()) { + if evalCtx.Overrides().IsOverridden(addr.Module.UnkeyedInstanceShim()) { log.Printf("[DEBUG] skipping provider %s found within overridden module", addr) return } @@ -112,7 +113,7 @@ func (g *Graph) walk(walker GraphWalker) tfdiags.Diagnostics { // all intentionally mutually-exclusive by having the same method // name but different signatures, since a node can only belong to // one context at a time.) - vertexCtx := ctx + vertexCtx := evalCtx if pn, ok := v.(graphNodeEvalContextScope); ok { scope := pn.Path() log.Printf("[TRACE] vertex %q: belongs to %s", dag.VertexName(v), scope) @@ -199,7 +200,14 @@ func (g *Graph) walk(walker GraphWalker) tfdiags.Diagnostics { return } - return g.AcyclicGraph.Walk(walkFn) + // This context is used to pass down the current context to the + // graph nodes. Each node inherits this context and the main + // function can use it to signal cancellation of the walk. + // We don't have any requirement to cancel the walk at this + // time, but we pass it down anyway. + ctx := context.Background() + + return g.AcyclicGraph.Walk(ctx, walkFn) } // ResourceGraph derives a graph containing addresses of only the nodes in the