Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(graph): add traverse methods container start and shutdown #5508

Merged
merged 14 commits into from
Dec 6, 2023
232 changes: 232 additions & 0 deletions internal/pkg/graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
// Package graph provides functionality for directed graphs.
package graph

import (
"context"
"sync"

"golang.org/x/sync/errgroup"
)

// vertexStatus denotes the visiting status of a vertex when running DFS in a graph.
type vertexStatus int

Expand Down Expand Up @@ -215,3 +222,228 @@ func TopologicalOrder[V comparable](digraph *Graph[V]) (*TopologicalSorter[V], e
topo.traverse(digraph)
return topo, nil
}

// LabeledGraph extends a generic Graph by associating a label (or status) with each vertex.
// It is concurrency-safe, utilizing a mutex lock for synchronized access.
type LabeledGraph[V comparable] struct {
*Graph[V]
status map[V]string
lock sync.Mutex
}

// NewLabeledGraph initializes a LabeledGraph with specified vertices and optional configurations.
// It creates a base Graph with the vertices and applies any LabeledGraphOption to configure additional properties.
func NewLabeledGraph[V comparable](vertices []V, opts ...LabeledGraphOption[V]) *LabeledGraph[V] {
g := New(vertices...)
lg := &LabeledGraph[V]{
Graph: g,
status: make(map[V]string),
}
for _, opt := range opts {
opt(lg)
}
return lg
}

// LabeledGraphOption allows you to initialize Graph with additional properties.
type LabeledGraphOption[V comparable] func(g *LabeledGraph[V])

// WithStatus sets the status of each vertex in the Graph.
func WithStatus[V comparable](status string) func(g *LabeledGraph[V]) {
return func(g *LabeledGraph[V]) {
g.status = make(map[V]string)
for vertex := range g.vertices {
g.status[vertex] = status
}
}
}

// updateStatus updates the status of a vertex.
func (lg *LabeledGraph[V]) updateStatus(vertex V, status string) {
lg.lock.Lock()
defer lg.lock.Unlock()
lg.status[vertex] = status
}

// getStatus gets the status of a vertex.
func (lg *LabeledGraph[V]) getStatus(vertex V) string {
lg.lock.Lock()
defer lg.lock.Unlock()
return lg.status[vertex]
}

// getLeaves returns the leaves of a given vertex.
func (lg *LabeledGraph[V]) leaves() []V {
lg.lock.Lock()
defer lg.lock.Unlock()
var leaves []V
for vtx := range lg.vertices {
if len(lg.vertices[vtx]) == 0 {
leaves = append(leaves, vtx)
}
}
return leaves
}

// getParents returns the parent vertices (incoming edges) of vertex.
func (lg *LabeledGraph[V]) parents(vtx V) []V {
lg.lock.Lock()
defer lg.lock.Unlock()
var parents []V
for v, neighbors := range lg.vertices {
if neighbors[vtx] {
parents = append(parents, v)
}
}
return parents
}

// getChildren returns the child vertices (outgoing edges) of vertex.
func (lg *LabeledGraph[V]) children(vtx V) []V {
lg.lock.Lock()
defer lg.lock.Unlock()
return lg.Neighbors(vtx)
}

// filterParents filters parents based on the vertex status.
func (lg *LabeledGraph[V]) filterParents(vtx V, status string) []V {
parents := lg.parents(vtx)
var filtered []V
for _, parent := range parents {
if lg.getStatus(parent) == status {
filtered = append(filtered, parent)
}
}
return filtered
}

// filterChildren filters children based on the vertex status.
func (lg *LabeledGraph[V]) filterChildren(vtx V, status string) []V {
children := lg.children(vtx)
var filtered []V
for _, child := range children {
if lg.getStatus(child) == status {
filtered = append(filtered, child)
}
}
return filtered
}

/*
UpwardTraversal performs an upward traversal on the graph starting from leaves (nodes with no children)
and moving towards root nodes (nodes with children).
It applies the specified process function to each vertex in the graph, skipping vertices with the
"adjacentVertexSkipStatus" status, and continuing traversal until reaching vertices with the "requiredVertexStatus" status.
The traversal is concurrent and may process vertices in parallel.
Returns an error if the traversal encounters any issues, or nil if successful.
*/
func (lg *LabeledGraph[V]) UpwardTraversal(ctx context.Context, processVertexFunc func(context.Context, V) error, nextVertexSkipStatus, requiredVertexStatus string) error {
traversal := &graphTraversal[V]{
mu: sync.Mutex{},
seen: make(map[V]struct{}),
findStartVertices: func(lg *LabeledGraph[V]) []V { return lg.leaves() },
findNextVertices: func(lg *LabeledGraph[V], v V) []V { return lg.parents(v) },
filterPreviousVerticesByStatus: func(g *LabeledGraph[V], v V, status string) []V { return g.filterChildren(v, status) },
requiredVertexStatus: requiredVertexStatus,
nextVertexSkipStatus: nextVertexSkipStatus,
processVertex: processVertexFunc,
}
return traversal.execute(ctx, lg)
}

/*
DownwardTraversal performs a downward traversal on the graph starting from root nodes (nodes with no parents)
and moving towards leaf nodes (nodes with parents). It applies the specified process function to each
vertex in the graph, skipping vertices with the "adjacentVertexSkipStatus" status, and continuing traversal
until reaching vertices with the "requiredVertexStatus" status.
The traversal is concurrent and may process vertices in parallel.
Returns an error if the traversal encounters any issues.
*/
func (lg *LabeledGraph[V]) DownwardTraversal(ctx context.Context, processVertexFunc func(context.Context, V) error, adjacentVertexSkipStatus, requiredVertexStatus string) error {
traversal := &graphTraversal[V]{
mu: sync.Mutex{},
seen: make(map[V]struct{}),
findStartVertices: func(lg *LabeledGraph[V]) []V { return lg.Roots() },
findNextVertices: func(lg *LabeledGraph[V], v V) []V { return lg.children(v) },
filterPreviousVerticesByStatus: func(lg *LabeledGraph[V], v V, status string) []V { return lg.filterParents(v, status) },
requiredVertexStatus: requiredVertexStatus,
nextVertexSkipStatus: adjacentVertexSkipStatus,
processVertex: processVertexFunc,
}
return traversal.execute(ctx, lg)
}

type graphTraversal[V comparable] struct {
mu sync.Mutex
seen map[V]struct{}
findStartVertices func(*LabeledGraph[V]) []V
findNextVertices func(*LabeledGraph[V], V) []V
filterPreviousVerticesByStatus func(*LabeledGraph[V], V, string) []V
requiredVertexStatus string
nextVertexSkipStatus string
processVertex func(context.Context, V) error
}

func (t *graphTraversal[V]) execute(ctx context.Context, lg *LabeledGraph[V]) error {

ctx, cancel := context.WithCancel(ctx)
defer cancel()

vertexCount := len(lg.vertices)
if vertexCount == 0 {
return nil
}
eg, ctx := errgroup.WithContext(ctx)
vertexCh := make(chan V, vertexCount)
defer close(vertexCh)

processVertices := func(ctx context.Context, graph *LabeledGraph[V], eg *errgroup.Group, vertices []V, vertexCh chan V) {
for _, vertex := range vertices {
vertex := vertex
// Delay processing this vertex if any of its dependent vertices are yet to be processed.
if len(t.filterPreviousVerticesByStatus(graph, vertex, t.nextVertexSkipStatus)) != 0 {
continue
}
if !t.markAsSeen(vertex) {
// Skip this vertex if it's already been processed by another routine.
continue
}
eg.Go(func() error {
if err := t.processVertex(ctx, vertex); err != nil {
return err
}
// Assign new status to the vertex upon successful processing.
graph.updateStatus(vertex, t.requiredVertexStatus)
vertexCh <- vertex
return nil
})
}
}

eg.Go(func() error {
iamhopaul123 marked this conversation as resolved.
Show resolved Hide resolved
for {
select {
case <-ctx.Done():
return ctx.Err()
case vertex := <-vertexCh:
vertexCount--
if vertexCount == 0 {
return nil
}
processVertices(ctx, lg, eg, t.findNextVertices(lg, vertex), vertexCh)
}
}
})
processVertices(ctx, lg, eg, t.findStartVertices(lg), vertexCh)
return eg.Wait()
}

func (t *graphTraversal[V]) markAsSeen(vertex V) bool {
t.mu.Lock()
defer t.mu.Unlock()
if _, seen := t.seen[vertex]; seen {
return false
}
t.seen[vertex] = struct{}{}
return true
}
69 changes: 69 additions & 0 deletions internal/pkg/graph/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package graph

import (
"context"
"strings"
"testing"

Expand Down Expand Up @@ -369,3 +370,71 @@ func TestTopologicalOrder(t *testing.T) {
})
}
}

func buildGraphWithSingleParent() *LabeledGraph[string] {
vertices := []string{"A", "B", "C", "D"}
graph := NewLabeledGraph[string](vertices, WithStatus[string]("started"))
graph.Add(Edge[string]{From: "D", To: "C"}) // D -> C
graph.Add(Edge[string]{From: "C", To: "B"}) // C -> B
graph.Add(Edge[string]{From: "B", To: "A"}) // B -> A
return graph
}

func TestTraverseInDependencyOrder(t *testing.T) {
t.Run("graph with single root vertex", func(t *testing.T) {
graph := buildGraphWithSingleParent()
var visited []string
processFn := func(ctx context.Context, v string) error {
visited = append(visited, v)
return nil
}
err := graph.UpwardTraversal(context.Background(), processFn, "started", "stopped")
require.NoError(t, err)
expected := []string{"A", "B", "C", "D"}
require.Equal(t, expected, visited)
})
t.Run("graph with multiple parents and boundary nodes", func(t *testing.T) {
vertices := []string{"A", "B", "C", "D"}
graph := NewLabeledGraph[string](vertices, WithStatus[string]("started"))
graph.Add(Edge[string]{From: "A", To: "C"})
graph.Add(Edge[string]{From: "A", To: "D"})
graph.Add(Edge[string]{From: "B", To: "D"})
vtxChan := make(chan string, 4)
seen := make(map[string]int)
done := make(chan struct{})
go func() {
for _, vtx := range vertices {
seen[vtx]++
}
close(done)
}()

err := graph.DownwardTraversal(context.Background(), func(ctx context.Context, vtx string) error {
vtxChan <- vtx
return nil
}, "started", "stopped")
require.NoError(t, err, "Error during iteration")
close(vtxChan)
<-done

require.Len(t, seen, 4)
for vtx, count := range seen {
require.Equal(t, 1, count, "%s", vtx)
}
})
}

func TestTraverseInReverseDependencyOrder(t *testing.T) {
t.Run("Graph with single root vertex", func(t *testing.T) {
graph := buildGraphWithSingleParent()
var visited []string
processFn := func(ctx context.Context, v string) error {
visited = append(visited, v)
return nil
}
err := graph.DownwardTraversal(context.Background(), processFn, "started", "stopped")
require.NoError(t, err)
expected := []string{"D", "C", "B", "A"}
require.Equal(t, expected, visited)
})
}
Loading