diff --git a/internal/dagutils/order.go b/internal/dagutils/order.go new file mode 100644 index 000000000..b77f72c4b --- /dev/null +++ b/internal/dagutils/order.go @@ -0,0 +1,62 @@ +package dagutils + +import ( + "fmt" + + "github.com/hatchet-dev/hatchet/pkg/repository" +) + +func OrderWorkflowSteps(steps []repository.CreateWorkflowStepOpts) ([]repository.CreateWorkflowStepOpts, error) { + // Build a map of step id to step for quick lookup. + stepMap := make(map[string]repository.CreateWorkflowStepOpts) + for _, step := range steps { + stepMap[step.ReadableId] = step + } + + // Initialize in-degree map and adjacency list graph. + inDegree := make(map[string]int) + graph := make(map[string][]string) + for _, step := range steps { + inDegree[step.ReadableId] = 0 + } + + // Build the graph and compute in-degrees. + for _, step := range steps { + for _, parent := range step.Parents { + if _, exists := stepMap[parent]; !exists { + return nil, fmt.Errorf("unknown parent step: %s", parent) + } + graph[parent] = append(graph[parent], step.ReadableId) + inDegree[step.ReadableId]++ + } + } + + // Queue for steps with no incoming edges. + var queue []string + for id, degree := range inDegree { + if degree == 0 { + queue = append(queue, id) + } + } + + var ordered []repository.CreateWorkflowStepOpts + // Process the steps in topological order. + for len(queue) > 0 { + id := queue[0] + queue = queue[1:] + ordered = append(ordered, stepMap[id]) + for _, child := range graph[id] { + inDegree[child]-- + if inDegree[child] == 0 { + queue = append(queue, child) + } + } + } + + // If not all steps are processed, there is a cycle. + if len(ordered) != len(steps) { + return nil, fmt.Errorf("cycle detected in workflow steps") + } + + return ordered, nil +} diff --git a/internal/dagutils/order_test.go b/internal/dagutils/order_test.go new file mode 100644 index 000000000..d247bf334 --- /dev/null +++ b/internal/dagutils/order_test.go @@ -0,0 +1,89 @@ +package dagutils_test + +import ( + "testing" + + "github.com/hatchet-dev/hatchet/internal/dagutils" + "github.com/hatchet-dev/hatchet/pkg/repository" +) + +func TestOrderWorkflowSteps(t *testing.T) { + t.Run("valid ordering", func(t *testing.T) { + steps := []repository.CreateWorkflowStepOpts{ + { + ReadableId: "step1", + Action: "action1", + Parents: []string{}, + }, + { + ReadableId: "step2", + Action: "action2", + Parents: []string{"step1"}, + }, + { + ReadableId: "step3", + Action: "action3", + Parents: []string{"step2"}, + }, + } + + ordered, err := dagutils.OrderWorkflowSteps(steps) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Validate that each step appears after its parents. + orderIndex := make(map[string]int) + for i, step := range ordered { + orderIndex[step.ReadableId] = i + } + + for _, step := range steps { + for _, parent := range step.Parents { + if orderIndex[parent] > orderIndex[step.ReadableId] { + t.Errorf("step %q appears before its parent %q", step.ReadableId, parent) + } + } + } + }) + + t.Run("unknown parent", func(t *testing.T) { + steps := []repository.CreateWorkflowStepOpts{ + { + ReadableId: "step1", + Action: "action1", + Parents: []string{"nonexistent"}, + }, + } + + _, err := dagutils.OrderWorkflowSteps(steps) + if err == nil { + t.Fatal("expected error for unknown parent, got nil") + } + }) + + t.Run("cycle detection", func(t *testing.T) { + steps := []repository.CreateWorkflowStepOpts{ + { + ReadableId: "step1", + Action: "action1", + Parents: []string{"step3"}, + }, + { + ReadableId: "step2", + Action: "action2", + Parents: []string{"step1"}, + }, + { + ReadableId: "step3", + Action: "action3", + Parents: []string{"step2"}, + }, + } + + _, err := dagutils.OrderWorkflowSteps(steps) + if err == nil { + t.Fatal("expected error for cycle detection, got nil") + } + }) +} diff --git a/pkg/repository/prisma/workflow.go b/pkg/repository/prisma/workflow.go index a81af6217..ad751c116 100644 --- a/pkg/repository/prisma/workflow.go +++ b/pkg/repository/prisma/workflow.go @@ -461,12 +461,19 @@ func (r *workflowEngineRepository) CreateNewWorkflow(ctx context.Context, tenant } // ensure no cycles - for _, job := range opts.Jobs { + for i, job := range opts.Jobs { if dagutils.HasCycle(job.Steps) { return nil, &repository.JobRunHasCycleError{ JobName: job.Name, } } + + var err error + opts.Jobs[i].Steps, err = dagutils.OrderWorkflowSteps(job.Steps) + + if err != nil { + return nil, err + } } // preflight check to ensure the workflow doesn't already exist @@ -572,12 +579,19 @@ func (r *workflowEngineRepository) CreateWorkflowVersion(ctx context.Context, te } // ensure no cycles - for _, job := range opts.Jobs { + for i, job := range opts.Jobs { if dagutils.HasCycle(job.Steps) { return nil, &repository.JobRunHasCycleError{ JobName: job.Name, } } + + var err error + opts.Jobs[i].Steps, err = dagutils.OrderWorkflowSteps(job.Steps) + + if err != nil { + return nil, err + } } // preflight check to ensure the workflow already exists