Skip to content

Commit

Permalink
fix: order DAG steps before inserting
Browse files Browse the repository at this point in the history
  • Loading branch information
abelanger5 committed Feb 13, 2025
1 parent ee623af commit e434ece
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 2 deletions.
62 changes: 62 additions & 0 deletions internal/dagutils/order.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package dagutils

import (
"fmt"

"github.com/hatchet-dev/hatchet/pkg/repository"
)

func OrderWorkflowSteps(steps []repository.CreateWorkflowStepOpts) ([]repository.CreateWorkflowStepOpts, error) {
// Build a map of step id to step for quick lookup.
stepMap := make(map[string]repository.CreateWorkflowStepOpts)
for _, step := range steps {
stepMap[step.ReadableId] = step
}

// Initialize in-degree map and adjacency list graph.
inDegree := make(map[string]int)
graph := make(map[string][]string)
for _, step := range steps {
inDegree[step.ReadableId] = 0
}

// Build the graph and compute in-degrees.
for _, step := range steps {
for _, parent := range step.Parents {
if _, exists := stepMap[parent]; !exists {
return nil, fmt.Errorf("unknown parent step: %s", parent)
}
graph[parent] = append(graph[parent], step.ReadableId)
inDegree[step.ReadableId]++
}
}

// Queue for steps with no incoming edges.
var queue []string
for id, degree := range inDegree {
if degree == 0 {
queue = append(queue, id)
}
}

var ordered []repository.CreateWorkflowStepOpts
// Process the steps in topological order.
for len(queue) > 0 {
id := queue[0]
queue = queue[1:]
ordered = append(ordered, stepMap[id])
for _, child := range graph[id] {
inDegree[child]--
if inDegree[child] == 0 {
queue = append(queue, child)
}
}
}

// If not all steps are processed, there is a cycle.
if len(ordered) != len(steps) {
return nil, fmt.Errorf("cycle detected in workflow steps")
}

return ordered, nil
}
89 changes: 89 additions & 0 deletions internal/dagutils/order_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package dagutils_test

import (
"testing"

"github.com/hatchet-dev/hatchet/internal/dagutils"
"github.com/hatchet-dev/hatchet/pkg/repository"
)

func TestOrderWorkflowSteps(t *testing.T) {
t.Run("valid ordering", func(t *testing.T) {
steps := []repository.CreateWorkflowStepOpts{
{
ReadableId: "step1",
Action: "action1",
Parents: []string{},
},
{
ReadableId: "step2",
Action: "action2",
Parents: []string{"step1"},
},
{
ReadableId: "step3",
Action: "action3",
Parents: []string{"step2"},
},
}

ordered, err := dagutils.OrderWorkflowSteps(steps)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Validate that each step appears after its parents.
orderIndex := make(map[string]int)
for i, step := range ordered {
orderIndex[step.ReadableId] = i
}

for _, step := range steps {
for _, parent := range step.Parents {
if orderIndex[parent] > orderIndex[step.ReadableId] {
t.Errorf("step %q appears before its parent %q", step.ReadableId, parent)
}
}
}
})

t.Run("unknown parent", func(t *testing.T) {
steps := []repository.CreateWorkflowStepOpts{
{
ReadableId: "step1",
Action: "action1",
Parents: []string{"nonexistent"},
},
}

_, err := dagutils.OrderWorkflowSteps(steps)
if err == nil {
t.Fatal("expected error for unknown parent, got nil")
}
})

t.Run("cycle detection", func(t *testing.T) {
steps := []repository.CreateWorkflowStepOpts{
{
ReadableId: "step1",
Action: "action1",
Parents: []string{"step3"},
},
{
ReadableId: "step2",
Action: "action2",
Parents: []string{"step1"},
},
{
ReadableId: "step3",
Action: "action3",
Parents: []string{"step2"},
},
}

_, err := dagutils.OrderWorkflowSteps(steps)
if err == nil {
t.Fatal("expected error for cycle detection, got nil")
}
})
}
18 changes: 16 additions & 2 deletions pkg/repository/prisma/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,12 +461,19 @@ func (r *workflowEngineRepository) CreateNewWorkflow(ctx context.Context, tenant
}

// ensure no cycles
for _, job := range opts.Jobs {
for i, job := range opts.Jobs {
if dagutils.HasCycle(job.Steps) {
return nil, &repository.JobRunHasCycleError{
JobName: job.Name,
}
}

var err error
opts.Jobs[i].Steps, err = dagutils.OrderWorkflowSteps(job.Steps)

if err != nil {
return nil, err
}
}

// preflight check to ensure the workflow doesn't already exist
Expand Down Expand Up @@ -572,12 +579,19 @@ func (r *workflowEngineRepository) CreateWorkflowVersion(ctx context.Context, te
}

// ensure no cycles
for _, job := range opts.Jobs {
for i, job := range opts.Jobs {
if dagutils.HasCycle(job.Steps) {
return nil, &repository.JobRunHasCycleError{
JobName: job.Name,
}
}

var err error
opts.Jobs[i].Steps, err = dagutils.OrderWorkflowSteps(job.Steps)

if err != nil {
return nil, err
}
}

// preflight check to ensure the workflow already exists
Expand Down

0 comments on commit e434ece

Please sign in to comment.