diff --git a/executor.go b/executor.go index 9772e6c..07d84e0 100644 --- a/executor.go +++ b/executor.go @@ -70,7 +70,8 @@ func (e *innerExecutorImpl) sche_successors(node *innerNode) { candidate := make([]*innerNode, 0, len(node.successors)) for _, n := range node.successors { - if n.JoinCounter() == 0 { + if n.JoinCounter() == 0 || n.Typ == nodeCondition { + // deps all done or condition node candidate = append(candidate, n) } } @@ -78,14 +79,14 @@ func (e *innerExecutorImpl) sche_successors(node *innerNode) { slices.SortFunc(candidate, func(i, j *innerNode) int { return cmp.Compare(i.priority, j.priority) }) - + node.setup() e.schedule(candidate...) } -func (e *innerExecutorImpl) invodeStatic(node *innerNode, parentSpan *span, p *Static) func() { +func (e *innerExecutorImpl) invokeStatic(node *innerNode, parentSpan *span, p *Static) func() { return func() { span := span{extra: attr{ - typ: NodeStatic, + typ: nodeStatic, name: node.name, }, begin: time.Now(), parent: parentSpan} @@ -99,9 +100,10 @@ func (e *innerExecutorImpl) invodeStatic(node *innerNode, parentSpan *span, p *S e.profiler.AddSpan(&span) // remove canceled node span } - e.wg.Done() node.drop() e.sche_successors(node) + node.g.joinCounter.Decrease() + e.wg.Done() node.g.scheCond.Signal() }() @@ -114,7 +116,7 @@ func (e *innerExecutorImpl) invodeStatic(node *innerNode, parentSpan *span, p *S func (e *innerExecutorImpl) invokeSubflow(node *innerNode, parentSpan *span, p *Subflow) func() { return func() { span := span{extra: attr{ - typ: NodeSubflow, + typ: nodeSubflow, name: node.name, }, begin: time.Now(), parent: parentSpan} defer func() { @@ -127,11 +129,12 @@ func (e *innerExecutorImpl) invokeSubflow(node *innerNode, parentSpan *span, p * } else { e.profiler.AddSpan(&span) // remove canceled node span } - e.wg.Done() + e.scheduleGraph(p.g, &span) node.drop() - e.sche_successors(node) + node.g.joinCounter.Decrease() + e.wg.Done() node.g.scheCond.Signal() }() @@ -147,7 +150,7 @@ func (e *innerExecutorImpl) invokeSubflow(node *innerNode, parentSpan *span, p * func (e *innerExecutorImpl) invokeCondition(node *innerNode, parentSpan *span, p *Condition) func() { return func() { span := span{extra: attr{ - typ: NodeCondition, + typ: nodeCondition, name: node.name, }, begin: time.Now(), parent: parentSpan} @@ -160,9 +163,10 @@ func (e *innerExecutorImpl) invokeCondition(node *innerNode, parentSpan *span, p } else { e.profiler.AddSpan(&span) // remove canceled node span } - e.wg.Done() node.drop() e.sche_successors(node) + node.g.joinCounter.Decrease() + e.wg.Done() node.g.scheCond.Signal() }() @@ -177,7 +181,7 @@ func (e *innerExecutorImpl) invokeCondition(node *innerNode, parentSpan *span, p if idx == choice { continue } - v.state.Store(kNodeStateCanceled) // cancel other nodes + v.state.Store(kNodeStateIgnored) // cancel other nodes } // do choice and cancel others node.state.Store(kNodeStateFinished) @@ -186,9 +190,10 @@ func (e *innerExecutorImpl) invokeCondition(node *innerNode, parentSpan *span, p func (e *innerExecutorImpl) invokeNode(node *innerNode, parentSpan *span) { // do job + fmt.Println("[invoke] ", node.name) switch p := node.ptr.(type) { case *Static: - e.pool.Go(e.invodeStatic(node, parentSpan, p)) + e.pool.Go(e.invokeStatic(node, parentSpan, p)) case *Subflow: e.pool.Go(e.invokeSubflow(node, parentSpan, p)) case *Condition: @@ -200,6 +205,7 @@ func (e *innerExecutorImpl) invokeNode(node *innerNode, parentSpan *span) { func (e *innerExecutorImpl) schedule(nodes ...*innerNode) { for _, node := range nodes { + fmt.Println("[schedule] ", node.name) if node.g.canceled.Load() { node.g.scheCond.Signal() fmt.Printf("node %v is not scheduled, as graph %v is canceled\n", node.name, node.g.name) @@ -213,7 +219,17 @@ func (e *innerExecutorImpl) schedule(nodes ...*innerNode) { v.state.Store(kNodeStateCanceled) } - return + continue + } + + if node.state.Load() == kNodeStateIgnored { + node.g.scheCond.Signal() + fmt.Printf("node %v is ignored\n", node.name) + for _, v := range node.successors { + v.state.Store(kNodeStateIdle) + } + + continue } node.g.joinCounter.Increase() diff --git a/flow.go b/flow.go index 899fb49..84e9670 100644 --- a/flow.go +++ b/flow.go @@ -51,7 +51,7 @@ func (fb *flowBuilder) NewStatic(name string, f func()) *innerNode { node.ptr = &Static{ handle: f, } - node.Typ = NodeStatic + node.Typ = nodeStatic return node } @@ -61,7 +61,7 @@ func (fb *flowBuilder) NewSubflow(name string, f func(sf *Subflow)) *innerNode { handle: f, g: newGraph(name), } - node.Typ = NodeSubflow + node.Typ = nodeSubflow return node } @@ -71,6 +71,6 @@ func (fb *flowBuilder) NewCondition(name string, f func() uint) *innerNode { handle: f, mapper: make(map[uint]*innerNode), } - node.Typ = NodeCondition + node.Typ = nodeCondition return node } diff --git a/graph.go b/graph.go index 05c22a2..41d9373 100644 --- a/graph.go +++ b/graph.go @@ -50,7 +50,7 @@ func (g *eGraph) setup() { g.reset() for _, node := range g.nodes { - node.joinCounter.Set(len(node.dependents)) + node.setup() if len(node.dependents) == 0 { g.entries = append(g.entries, node) diff --git a/node.go b/node.go index 486d309..ed737c6 100644 --- a/node.go +++ b/node.go @@ -14,21 +14,22 @@ const ( kNodeStateFinished = int32(3) kNodeStateFailed = int32(4) kNodeStateCanceled = int32(5) + kNodeStateIgnored = int32(6) ) -type NodeType string +type nodeType string const ( - NodeSubflow NodeType = "subflow" // subflow - NodeStatic NodeType = "static" // static - NodeCondition NodeType = "condition" // static + nodeSubflow nodeType = "subflow" // subflow + nodeStatic nodeType = "static" // static + nodeCondition nodeType = "condition" // static ) type innerNode struct { name string successors []*innerNode dependents []*innerNode - Typ NodeType + Typ nodeType ptr interface{} rw *sync.RWMutex state atomic.Int32 @@ -41,13 +42,23 @@ func (n *innerNode) JoinCounter() int { return n.joinCounter.Value() } +func (n *innerNode) setup() { + n.state.Store(kNodeStateIdle) + for _, dep := range n.dependents { + if dep.Typ == nodeCondition { + continue + } + + n.joinCounter.Increase() + } +} func (n *innerNode) drop() { // release every deps for _, node := range n.successors { - node.joinCounter.Decrease() + if n.Typ != nodeCondition { + node.joinCounter.Decrease() + } } - - n.g.joinCounter.Decrease() } // set dependency: V deps on N, V is input node diff --git a/profiler.go b/profiler.go index c1b9a31..d039b40 100644 --- a/profiler.go +++ b/profiler.go @@ -28,7 +28,7 @@ func (t *profiler) AddSpan(s *span) { } type attr struct { - typ NodeType + typ nodeType success bool // 0 for success, 1 for abnormal name string } @@ -46,7 +46,7 @@ func (s *span) String() string { func (t *profiler) draw(w io.Writer) error { for _, s := range t.spans { path := "" - if s.extra.typ == NodeStatic { + if s.extra.typ == nodeStatic { path = s.String() cur := s diff --git a/profiler_test.go b/profiler_test.go index c888a19..6744cf0 100644 --- a/profiler_test.go +++ b/profiler_test.go @@ -12,7 +12,7 @@ func TestProfilerAddSpan(t *testing.T) { profiler := newProfiler() span := &span{ extra: attr{ - typ: NodeStatic, + typ: nodeStatic, success: true, name: "test-span", }, @@ -34,7 +34,7 @@ func TestSpanString(t *testing.T) { now := time.Now() span := &span{ extra: attr{ - typ: NodeStatic, + typ: nodeStatic, success: true, name: "test-span", }, @@ -55,7 +55,7 @@ func TestProfilerDraw(t *testing.T) { now := time.Now() parentSpan := &span{ extra: attr{ - typ: NodeStatic, + typ: nodeStatic, success: true, name: "parent", }, @@ -65,7 +65,7 @@ func TestProfilerDraw(t *testing.T) { childSpan := &span{ extra: attr{ - typ: NodeStatic, + typ: nodeStatic, success: true, name: "child", }, diff --git a/taskflow_test.go b/taskflow_test.go index b9def04..74eb9c4 100644 --- a/taskflow_test.go +++ b/taskflow_test.go @@ -12,7 +12,7 @@ import ( "github.com/noneback/go-taskflow/utils" ) -var exector = gotaskflow.NewExecutor(10) +var executor = gotaskflow.NewExecutor(10) func TestTaskFlow(t *testing.T) { A, B, C := @@ -52,9 +52,9 @@ func TestTaskFlow(t *testing.T) { } }) - exector.Run(tf).Wait() + executor.Run(tf).Wait() fmt.Print("########### second times") - exector.Run(tf).Wait() + executor.Run(tf).Wait() } func TestSubflow(t *testing.T) { @@ -124,12 +124,12 @@ func TestSubflow(t *testing.T) { tf := gotaskflow.NewTaskFlow("G") tf.Push(A, B, C) tf.Push(A1, B1, C1, subflow, subflow2) - exector.Run(tf) - exector.Wait() + executor.Run(tf) + executor.Wait() if err := gotaskflow.Visualize(tf, os.Stdout); err != nil { log.Fatal(err) } - exector.Profile(os.Stdout) + executor.Profile(os.Stdout) // exector.Wait() // if err := tf.Visualize(os.Stdout); err != nil { @@ -155,7 +155,7 @@ func TestTaskflowPanic(t *testing.T) { tf := gotaskflow.NewTaskFlow("G") tf.Push(A, B, C) - exector.Run(tf).Wait() + executor.Run(tf).Wait() } func TestSubflowPanic(t *testing.T) { @@ -196,106 +196,186 @@ func TestSubflowPanic(t *testing.T) { tf := gotaskflow.NewTaskFlow("G") tf.Push(A, B, C) tf.Push(subflow) - exector.Run(tf) - exector.Wait() + executor.Run(tf) + executor.Wait() if err := gotaskflow.Visualize(tf, os.Stdout); err != nil { fmt.Errorf("%v", err) } - exector.Profile(os.Stdout) + executor.Profile(os.Stdout) } func TestTaskflowCondition(t *testing.T) { - A, B, C := - gotaskflow.NewTask("A", func() { - fmt.Println("A") - }), - gotaskflow.NewTask("B", func() { - fmt.Println("B") - }), - gotaskflow.NewTask("C", func() { - fmt.Println("C") + t.Run("normal", func(t *testing.T) { + A, B, C := + gotaskflow.NewTask("A", func() { + fmt.Println("A") + }), + gotaskflow.NewTask("B", func() { + fmt.Println("B") + }), + gotaskflow.NewTask("C", func() { + fmt.Println("C") + }) + A.Precede(B) + C.Precede(B) + tf := gotaskflow.NewTaskFlow("G") + tf.Push(A, B, C) + fail, success := gotaskflow.NewTask("failed", func() { + fmt.Println("Failed") + t.Fail() + }), gotaskflow.NewTask("success", func() { + fmt.Println("success") }) - A.Precede(B) - C.Precede(B) - tf := gotaskflow.NewTaskFlow("G") - tf.Push(A, B, C) - fail, success := gotaskflow.NewTask("failed", func() { - fmt.Println("Failed") - t.Fail() - }), gotaskflow.NewTask("success", func() { - fmt.Println("success") + + cond := gotaskflow.NewCondition("cond", func() uint { return 0 }) + B.Precede(cond) + cond.Precede(success, fail) + + suc := gotaskflow.NewSubflow("sub1", func(sf *gotaskflow.Subflow) { + A2, B2, C2 := + gotaskflow.NewTask("A2", func() { + fmt.Println("A2") + }), + gotaskflow.NewTask("B2", func() { + fmt.Println("B2") + }), + gotaskflow.NewTask("C2", func() { + fmt.Println("C2") + }) + sf.Push(A2, B2, C2) + A2.Precede(B2) + C2.Precede(B2) + }) + fs := gotaskflow.NewTask("fail_single", func() { + fmt.Println("it should be canceled") + }) + fail.Precede(fs, suc) + // success.Precede(suc) + tf.Push(cond, success, fail, fs, suc) + executor.Run(tf).Wait() + + if err := gotaskflow.Visualize(tf, os.Stdout); err != nil { + fmt.Errorf("%v", err) + } + executor.Profile(os.Stdout) }) - cond := gotaskflow.NewCondition("cond", func() uint { return 0 }) - B.Precede(cond) - cond.Precede(success, fail) + t.Run("start with condion node", func(t *testing.T) { + i := 0 + tf := gotaskflow.NewTaskFlow("G") - suc := gotaskflow.NewSubflow("sub1", func(sf *gotaskflow.Subflow) { - A2, B2, C2 := - gotaskflow.NewTask("A2", func() { - fmt.Println("A2") + cond := gotaskflow.NewCondition("cond", func() uint { + if i == 0 { + return 0 + } else { + return 1 + } + }) + + zero, one := gotaskflow.NewTask("zero", func() { + fmt.Println("zero") + }), gotaskflow.NewTask("one", func() { + fmt.Println("one") + }) + cond.Precede(zero, one) + + tf.Push(zero, one, cond) + executor.Run(tf).Wait() + + if err := gotaskflow.Visualize(tf, os.Stdout); err != nil { + log.Fatal(err) + } + executor.Profile(os.Stdout) + + }) + +} + +func TestTaskflowLoop(t *testing.T) { + // t.SkipNow() + t.Run("normal", func(t *testing.T) { + i := 0 + tf := gotaskflow.NewTaskFlow("G") + init, cond, body, back, done := + gotaskflow.NewTask("init", func() { + i = 0 + fmt.Println("i=0") }), - gotaskflow.NewTask("B2", func() { - fmt.Println("B2") + gotaskflow.NewCondition("while i < 5", func() uint { + if i < 5 { + return 0 + } else { + return 1 + } }), - gotaskflow.NewTask("C2", func() { - fmt.Println("C2") + gotaskflow.NewTask("i++", func() { + i += 1 + fmt.Println("i++ =", i) + }), + gotaskflow.NewCondition("back", func() uint { + fmt.Println("back") + return 0 + }), + gotaskflow.NewTask("done", func() { + fmt.Println("done") }) - sf.Push(A2, B2, C2) - A2.Precede(B2) - C2.Precede(B2) - }) - fs := gotaskflow.NewTask("fail_single", func() { - fmt.Println("it should be canceled") + + tf.Push(init, cond, body, back, done) + + init.Precede(cond) + cond.Precede(body, done) + body.Precede(back) + back.Precede(cond) + + executor.Run(tf).Wait() + if i < 5 { + t.Fail() + } + + if err := gotaskflow.Visualize(tf, os.Stdout); err != nil { + // log.Fatal(err) + } + executor.Profile(os.Stdout) }) - fail.Precede(fs, suc) - // success.Precede(suc) - tf.Push(cond, success, fail, fs, suc) - exector.Run(tf).Wait() - if err := gotaskflow.Visualize(tf, os.Stdout); err != nil { - fmt.Errorf("%v", err) - } - exector.Profile(os.Stdout) -} + t.Run("simple loop", func(t *testing.T) { + i := 0 + tf := gotaskflow.NewTaskFlow("G") + init := gotaskflow.NewTask("init", func() { + i = 0 + }) + cond := gotaskflow.NewCondition("cond", func() uint { + i++ + fmt.Println("i++ =", i) + if i > 2 { + return 0 + } else { + return 1 + } + }) -func TestTaskflowLoop(t *testing.T) { - A, B, C := - gotaskflow.NewTask("A", func() { - fmt.Println("A") - }), - gotaskflow.NewTask("B", func() { - fmt.Println("B") - }), - gotaskflow.NewTask("C", func() { - fmt.Println("C") + done := gotaskflow.NewTask("done", func() { + fmt.Println("done") }) - A.Precede(B) - C.Precede(B) - tf := gotaskflow.NewTaskFlow("G") - tf.Push(A, B, C) - zero := gotaskflow.NewTask("zero", func() { - fmt.Println("zero") - }) - counter := uint(0) - cond := gotaskflow.NewCondition("cond", func() uint { - counter += 1 - return counter % 3 - }) - B.Precede(cond) - cond.Precede(cond, cond, zero) - tf.Push(cond, zero) - exector.Run(tf).Wait() + init.Precede(cond) + cond.Precede(done, cond) - if err := gotaskflow.Visualize(tf, os.Stdout); err != nil { - fmt.Errorf("%v", err) - } - exector.Profile(os.Stdout) + tf.Push(done, cond, init) + executor.Run(tf).Wait() + if i <= 2 { + t.Fail() + } + + if err := gotaskflow.Visualize(tf, os.Stdout); err != nil { + // log.Fatal(err) + } + executor.Profile(os.Stdout) + }) } func TestTaskflowPriority(t *testing.T) { - exector := gotaskflow.NewExecutor(uint(2)) + executor := gotaskflow.NewExecutor(uint(2)) q := utils.NewQueue[byte]() tf := gotaskflow.NewTaskFlow("G") B, C := @@ -325,7 +405,7 @@ func TestTaskflowPriority(t *testing.T) { }).Priority(gotaskflow.LOW) tf.Push(B, C, suc) - exector.Run(tf).Wait() + executor.Run(tf).Wait() for _, val := range []byte{'C', 'B', 'b', 'c', 'a'} { real := q.PeakAndTake() diff --git a/utils/utils_test.go b/utils/utils_test.go index 6c5fe0c..af26601 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -5,12 +5,9 @@ import ( "testing" ) -// UnsafeToString 测试 func TestUnsafeToString(t *testing.T) { original := "Hello, World!" b := []byte(original) - - // 将字节切片转换为字符串 s := UnsafeToString(b) if s != original { @@ -18,10 +15,8 @@ func TestUnsafeToString(t *testing.T) { } } -// UnsafeToBytes 测试 func TestUnsafeToBytes(t *testing.T) { original := "Hello, World!" - // 将字符串转换为字节切片 b := UnsafeToBytes(original) if string(b) != original { @@ -29,37 +24,31 @@ func TestUnsafeToBytes(t *testing.T) { } } -// RC 结构体测试 func TestRC(t *testing.T) { rc := NewRC() - // 测试初始值 if rc.Value() != 0 { t.Errorf("Expected count to be 0, got %d", rc.Value()) } - // 测试增加计数 rc.Increase() if rc.Value() != 1 { t.Errorf("Expected count to be 1, got %d", rc.Value()) } - // 测试减少计数 rc.Decrease() if rc.Value() != 0 { t.Errorf("Expected count to be 0, got %d", rc.Value()) } - // 测试不能减少到负值 defer func() { if r := recover(); r == nil { t.Errorf("Expected panic when decreasing below zero, but did not") } }() - rc.Decrease() // 这里应该触发 panic + rc.Decrease() } -// 测试 Set 和负值 func TestSet(t *testing.T) { rc := NewRC() rc.Set(5) @@ -68,7 +57,6 @@ func TestSet(t *testing.T) { t.Errorf("Expected count to be 5, got %d", rc.Value()) } - // 测试负值情况 rc.Set(-1) if rc.Value() != -1 { @@ -79,7 +67,6 @@ func TestSet(t *testing.T) { func TestPanic(t *testing.T) { f := func() { defer func() { - // 使用 recover 捕获 panic if r := recover(); r != nil { fmt.Println("Recovered in causePanic:", r) } @@ -87,11 +74,6 @@ func TestPanic(t *testing.T) { }() fmt.Println("result") - // panic("Atest") } f() } - -// result -// Recovered in causePanic: Atest -// 1