Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use context instead of channel to signal cancellation from the caller to child nodes in the graph #36391

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions internal/dag/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package dag

import (
"context"
"errors"
"fmt"
"sort"
Expand Down Expand Up @@ -272,9 +273,11 @@ func (g *AcyclicGraph) Cycles() [][]Vertex {
// Walk walks the graph, calling your callback as each node is visited.
// This will walk nodes in parallel if it can. The resulting diagnostics
// contains problems from all graphs visited, in no particular order.
func (g *AcyclicGraph) Walk(cb WalkFunc) tfdiags.Diagnostics {
w := &Walker{Callback: cb, Reverse: true}
w.Update(g)
// The context here will be inherited by all nodes in the graph, and
// the caller can use this to signal cancellation of the graph walk.
func (g *AcyclicGraph) Walk(ctx context.Context, cb WalkFunc) tfdiags.Diagnostics {
w := NewWalker(cb)
w.Update(ctx, g)
return w.Wait()
}

Expand Down
5 changes: 3 additions & 2 deletions internal/dag/dag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package dag

import (
"context"
"flag"
"fmt"
"os"
Expand Down Expand Up @@ -396,7 +397,7 @@ func TestAcyclicGraphWalk(t *testing.T) {

var visits []Vertex
var lock sync.Mutex
err := g.Walk(func(v Vertex) tfdiags.Diagnostics {
err := g.Walk(context.Background(), func(v Vertex) tfdiags.Diagnostics {
lock.Lock()
defer lock.Unlock()
visits = append(visits, v)
Expand Down Expand Up @@ -431,7 +432,7 @@ func TestAcyclicGraphWalk_error(t *testing.T) {

var visits []Vertex
var lock sync.Mutex
err := g.Walk(func(v Vertex) tfdiags.Diagnostics {
err := g.Walk(context.Background(), func(v Vertex) tfdiags.Diagnostics {
lock.Lock()
defer lock.Unlock()

Expand Down
44 changes: 30 additions & 14 deletions internal/dag/walk.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package dag

import (
"context"
"errors"
"log"
"sync"
Expand Down Expand Up @@ -76,19 +77,32 @@ func (w *Walker) init() {
}
}

// NewWalker creates a new walker with the given callback function.
func NewWalker(cb WalkFunc) *Walker {
// Reverse is true by default, so that the default behavior is for
// the source of an edge to depend on the target.
w := &Walker{Callback: cb, Reverse: true}
return w
}

type walkerVertex struct {
// These should only be set once on initialization and never written again.
// They are not protected by a lock since they don't need to be since
// they are write-once.

// DoneCh is closed when this vertex has completed execution, regardless
// of success.
//
// CancelCh is closed when the vertex should cancel execution. If execution
// is already complete (DoneCh is closed), this has no effect. Otherwise,
// execution is cancelled as quickly as possible.
DoneCh chan struct{}
CancelCh chan struct{}
DoneCh chan struct{}

// CancelCtx is the context used to signal cancellation of the vertex's execution.
// It is created during initialization and should not be modified thereafter.
// If execution is already complete (DoneCh is closed), this has no effect.
// Otherwise, execution is cancelled as quickly as possible.
CancelCtx context.Context
// cancelCtxFn is the function used to cancel the CancelCtx. It is called to
// signal that the vertex's execution should be cancelled. This function should
// only be called once and should not be modified after initialization.
cancelCtxFn context.CancelFunc

// Dependency information. Any changes to any of these fields requires
// holding DepsLock.
Expand Down Expand Up @@ -145,7 +159,7 @@ func (w *Walker) Wait() tfdiags.Diagnostics {
//
// Multiple Updates can be called in parallel. Update can be called at any
// time during a walk.
func (w *Walker) Update(g *AcyclicGraph) {
func (w *Walker) Update(ctx context.Context, g *AcyclicGraph) {
w.init()
v := make(Set)
e := make(Set)
Expand Down Expand Up @@ -181,10 +195,12 @@ func (w *Walker) Update(g *AcyclicGraph) {
w.vertices.Add(raw)

// Initialize the vertex info
ctx, cancelVertex := context.WithCancel(ctx)
info := &walkerVertex{
DoneCh: make(chan struct{}),
CancelCh: make(chan struct{}),
deps: make(map[Vertex]chan struct{}),
DoneCh: make(chan struct{}),
CancelCtx: ctx,
cancelCtxFn: cancelVertex,
deps: make(map[Vertex]chan struct{}),
}

// Add it to the map and kick off the walk
Expand All @@ -204,7 +220,7 @@ func (w *Walker) Update(g *AcyclicGraph) {
}

// Cancel the vertex
close(info.CancelCh)
info.cancelCtxFn()

// Delete it out of the map
delete(w.vertexMap, v)
Expand Down Expand Up @@ -336,8 +352,8 @@ func (w *Walker) walkVertex(v Vertex, info *walkerVertex) {
close(depsCh)
for {
select {
case <-info.CancelCh:
// Cancel
case <-info.CancelCtx.Done():
// Context cancelled. return immediately.
return

case depsSuccess = <-depsCh:
Expand Down Expand Up @@ -371,7 +387,7 @@ func (w *Walker) walkVertex(v Vertex, info *walkerVertex) {
// If we passed dependencies, we just want to check once more that
// we're not cancelled, since this can happen just as dependencies pass.
select {
case <-info.CancelCh:
case <-info.CancelCtx.Done():
// Cancelled during an update while dependencies completed.
return
default:
Expand Down
27 changes: 14 additions & 13 deletions internal/dag/walk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package dag

import (
"context"
"fmt"
"reflect"
"sync"
Expand All @@ -23,7 +24,7 @@ func TestWalker_basic(t *testing.T) {
for i := 0; i < 50; i++ {
var order []interface{}
w := &Walker{Callback: walkCbRecord(&order)}
w.Update(&g)
w.Update(context.Background(), &g)

// Wait
if err := w.Wait(); err != nil {
Expand All @@ -48,8 +49,8 @@ func TestWalker_updateNilGraph(t *testing.T) {
for i := 0; i < 50; i++ {
var order []interface{}
w := &Walker{Callback: walkCbRecord(&order)}
w.Update(&g)
w.Update(nil)
w.Update(context.Background(), &g)
w.Update(context.Background(), nil)

// Wait
if err := w.Wait(); err != nil {
Expand Down Expand Up @@ -84,7 +85,7 @@ func TestWalker_error(t *testing.T) {
}

w := &Walker{Callback: cb}
w.Update(&g)
w.Update(context.Background(), &g)

// Wait
if err := w.Wait(); err == nil {
Expand Down Expand Up @@ -120,18 +121,18 @@ func TestWalker_newVertex(t *testing.T) {

// Add the initial vertices
w = &Walker{Callback: cb}
w.Update(&g)
w.Update(context.Background(), &g)

// if 2 has been visited, the walk is complete so far
<-done2

// Update the graph
g.Add(3)
w.Update(&g)
w.Update(context.Background(), &g)

// Update the graph again but with the same vertex
g.Add(3)
w.Update(&g)
w.Update(context.Background(), &g)

// Wait
if err := w.Wait(); err != nil {
Expand Down Expand Up @@ -159,15 +160,15 @@ func TestWalker_removeVertex(t *testing.T) {
cb := func(v Vertex) tfdiags.Diagnostics {
if v == 1 {
g.Remove(2)
w.Update(&g)
w.Update(context.Background(), &g)
}

return recordF(v)
}

// Add the initial vertices
w = &Walker{Callback: cb}
w.Update(&g)
w.Update(context.Background(), &g)

// Wait
if err := w.Wait(); err != nil {
Expand Down Expand Up @@ -200,14 +201,14 @@ func TestWalker_newEdge(t *testing.T) {
if v == 1 {
g.Add(3)
g.Connect(BasicEdge(3, 2))
w.Update(&g)
w.Update(context.Background(), &g)
}
return diags
}

// Add the initial vertices
w = &Walker{Callback: cb}
w.Update(&g)
w.Update(context.Background(), &g)

// Wait
if err := w.Wait(); err != nil {
Expand Down Expand Up @@ -248,7 +249,7 @@ func TestWalker_removeEdge(t *testing.T) {
switch v {
case 1:
g.RemoveEdge(BasicEdge(3, 2))
w.Update(&g)
w.Update(context.Background(), &g)
t.Logf("removed edge from 3 to 2")

case 2:
Expand All @@ -275,7 +276,7 @@ func TestWalker_removeEdge(t *testing.T) {

// Add the initial vertices
w = &Walker{Callback: cb}
w.Update(&g)
w.Update(context.Background(), &g)

// Wait
if diags := w.Wait(); diags.HasErrors() {
Expand Down
2 changes: 1 addition & 1 deletion internal/terraform/eval_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type hookFunc func(func(Hook) (HookAction, error)) error

// EvalContext is the interface that is given to eval nodes to execute.
type EvalContext interface {
// Stopped returns a context that is canceled when evaluation is stopped via
// StopCtx returns a context that is canceled when evaluation is stopped via
// Terraform.Context.Stop()
StopCtx() context.Context

Expand Down
20 changes: 14 additions & 6 deletions internal/terraform/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package terraform

import (
"context"
"fmt"
"log"
"strings"
Expand Down Expand Up @@ -40,7 +41,7 @@ func (g *Graph) Walk(walker GraphWalker) tfdiags.Diagnostics {

func (g *Graph) walk(walker GraphWalker) tfdiags.Diagnostics {
// The callbacks for enter/exiting a graph
ctx := walker.EvalContext()
evalCtx := walker.EvalContext()

// Walk the graph.
walkFn := func(v dag.Vertex) (diags tfdiags.Diagnostics) {
Expand Down Expand Up @@ -72,7 +73,7 @@ func (g *Graph) walk(walker GraphWalker) tfdiags.Diagnostics {
}
}()

haveOverrides := !ctx.Overrides().Empty()
haveOverrides := !evalCtx.Overrides().Empty()

// If the graph node is overridable, we'll check our overrides to see
// if we need to apply any overrides to the node.
Expand All @@ -85,7 +86,7 @@ func (g *Graph) walk(walker GraphWalker) tfdiags.Diagnostics {
//
// See the output node for an example of providing the overrides
// directly to the node.
if override, ok := ctx.Overrides().GetResourceOverride(overridable.ResourceInstanceAddr(), overridable.ConfigProvider()); ok {
if override, ok := evalCtx.Overrides().GetResourceOverride(overridable.ResourceInstanceAddr(), overridable.ConfigProvider()); ok {
overridable.SetOverride(override)
}
}
Expand All @@ -99,7 +100,7 @@ func (g *Graph) walk(walker GraphWalker) tfdiags.Diagnostics {
// UnkeyedInstanceShim is used by legacy provider configs within a
// module to return an instance of that module, since they can never
// exist within an expanded instance.
if ctx.Overrides().IsOverridden(addr.Module.UnkeyedInstanceShim()) {
if evalCtx.Overrides().IsOverridden(addr.Module.UnkeyedInstanceShim()) {
log.Printf("[DEBUG] skipping provider %s found within overridden module", addr)
return
}
Expand All @@ -112,7 +113,7 @@ func (g *Graph) walk(walker GraphWalker) tfdiags.Diagnostics {
// all intentionally mutually-exclusive by having the same method
// name but different signatures, since a node can only belong to
// one context at a time.)
vertexCtx := ctx
vertexCtx := evalCtx
if pn, ok := v.(graphNodeEvalContextScope); ok {
scope := pn.Path()
log.Printf("[TRACE] vertex %q: belongs to %s", dag.VertexName(v), scope)
Expand Down Expand Up @@ -199,7 +200,14 @@ func (g *Graph) walk(walker GraphWalker) tfdiags.Diagnostics {
return
}

return g.AcyclicGraph.Walk(walkFn)
// This context is used to pass down the current context to the
// graph nodes. Each node inherits this context and the main
// function can use it to signal cancellation of the walk.
// We don't have any requirement to cancel the walk at this
// time, but we pass it down anyway.
ctx := context.Background()

return g.AcyclicGraph.Walk(ctx, walkFn)
}

// ResourceGraph derives a graph containing addresses of only the nodes in the
Expand Down
Loading