diff --git a/.gitignore b/.gitignore index d1f8153..c8d4bef 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,5 @@ # vendor/ # Go workspace file -go.work \ No newline at end of file +go.work +*.prof \ No newline at end of file diff --git a/cmd/example.go b/cmd/example.go index 7ffaf33..07cfe33 100644 --- a/cmd/example.go +++ b/cmd/example.go @@ -9,7 +9,7 @@ import ( ) func main() { - executor := gotaskflow.NewExecutor(runtime.NumCPU() - 1) + executor := gotaskflow.NewExecutor(uint(runtime.NumCPU())) tf := gotaskflow.NewTaskFlow("G") A, B, C := gotaskflow.NewTask("A", func(ctx *context.Context) { diff --git a/copool.go b/copool.go deleted file mode 100644 index 833910e..0000000 --- a/copool.go +++ /dev/null @@ -1,164 +0,0 @@ -package gotaskflow - -import ( - "context" - "fmt" - "log" - "runtime/debug" - "sync" - "sync/atomic" -) - -type Pool interface { - // SetCap sets the goroutine capacity of the pool. - SetCap(cap int32) - // Go executes f. - Go(f func()) - // CtxGo executes f and accepts the context. - CtxGo(ctx context.Context, f func()) - // SetPanicHandler sets the panic handler. - SetPanicHandler(f func(context.Context, interface{})) -} - -var ( - taskPool sync.Pool - workerPool sync.Pool -) - -func init() { - taskPool.New = newCotask - workerPool.New = newCoworker -} - -func newCoworker() interface{} { - return &coworker{} -} - -type coworker struct { - pool *pool -} - -func (w *coworker) close() { - w.pool.decWorkerCount() -} - -func (w *coworker) run() { - go func() { - for { - w.pool.taskLocker.Lock() - if w.pool.taskQ.Len() == 0 { - w.close() - w.pool.taskLocker.Unlock() - w.recycle() - return - } - t := w.pool.taskQ.PeakAndTake() - w.pool.taskLocker.Unlock() - - func() { - defer func() { - if r := recover(); r != nil { - if w.pool.panicHandler != nil { - w.pool.panicHandler(t.ctx, r) - } else { - msg := fmt.Sprintf("[ERROR] GOPOOL: panic in pool: %v: %s", r, debug.Stack()) - log.Println(msg) - } - } - }() - t.f() - }() - t.recycle() - } - }() - -} -func (w *coworker) recycle() { - w.zero() - workerPool.Put(w) -} -func (w *coworker) zero() { - w.pool = nil -} - -func newCotask() interface{} { - return &cotask{} -} - -type cotask struct { - ctx context.Context - f func() -} - -func (ct *cotask) zero() { - ct.ctx = nil - ct.f = nil -} - -func (ct *cotask) recycle() { - ct.zero() - taskPool.Put(ct) -} - -type pool struct { - panicHandler func(context.Context, interface{}) - cap int32 - ScaleThreshold int32 - taskQ *Queue[*cotask] - taskLocker *sync.Mutex - workerCnt int32 -} - -func (p *pool) SetCap(cap int32) { - atomic.StoreInt32(&p.cap, cap) -} - -// Go executes f. -func (p *pool) Go(f func()) { - p.CtxGo(context.Background(), f) -} - -// CtxGo executes f and accepts the context. -func (p *pool) CtxGo(ctx context.Context, f func()) { - t := taskPool.Get().(*cotask) - t.ctx = ctx - t.f = f - - p.taskQ.Put(t) - - if (p.taskQ.Len() >= p.ScaleThreshold && p.workerCount() <= atomic.LoadInt32(&p.cap)) || p.workerCount() == 0 { - p.incWorkerCount() - w := workerPool.Get().(*coworker) - w.pool = p - w.run() - } - -} - -func (p *pool) workerCount() int32 { - return atomic.LoadInt32(&p.workerCnt) -} - -func (p *pool) incWorkerCount() { - atomic.AddInt32(&p.workerCnt, 1) -} - -func (p *pool) decWorkerCount() { - atomic.AddInt32(&p.workerCnt, -1) -} - -// SetPanicHandler sets the panic handler. -func (p *pool) SetPanicHandler(f func(context.Context, interface{})) { - p.panicHandler = f -} - -func NewTaskPool(cap int32) Pool { - return &pool{ - panicHandler: nil, - ScaleThreshold: 32, - cap: cap, - workerCnt: 0, - taskLocker: &sync.Mutex{}, - taskQ: NewQueue[*cotask](), - } -} diff --git a/executor.go b/executor.go index 51880e4..997e868 100644 --- a/executor.go +++ b/executor.go @@ -2,9 +2,10 @@ package gotaskflow import ( "context" + "fmt" "sync" - "sync/atomic" - "time" + + "github.com/noneback/go-taskflow/utils" ) type Executor interface { @@ -15,56 +16,115 @@ type Executor interface { } type ExecutorImpl struct { - concurrency int - pool Pool + concurrency uint + pool *utils.Copool + wq *utils.Queue[*Node] wg *sync.WaitGroup } -func NewExecutor(concurrency int) Executor { +func NewExecutor(concurrency uint) Executor { + if concurrency == 0 { + panic("executor concrurency cannot be zero") + } return &ExecutorImpl{ concurrency: concurrency, - pool: NewTaskPool(int32(concurrency)), + pool: utils.NewCopool(concurrency), + wq: utils.NewQueue[*Node](), wg: &sync.WaitGroup{}, } } func (e *ExecutorImpl) Run(tf *TaskFlow) error { - nodes, ok := tf.graph.TopologicalSort() - if !ok { - return ErrTaskFlowIsCyclic + tf.graph.setup() + + for _, node := range tf.graph.entries { + e.schedule(node) } + e.invoke(tf) + return nil +} + +func (e *ExecutorImpl) invoke_graph(g *Graph) { ctx := context.Background() + for { + g.scheCond.L.Lock() + for g.JoinCounter() != 0 && e.wq.Len() == 0 { + g.scheCond.Wait() + } + g.scheCond.L.Unlock() - for _, node := range nodes { - e.schedule(ctx, node) + if g.JoinCounter() == 0 { + break + } + + node := e.wq.PeakAndTake() // hang + e.invoke_node(&ctx, node) } - return nil } -func (e *ExecutorImpl) schedule(ctx context.Context, node *Node) { - waitting := make(map[string]*Node) - for _, dep := range node.dependents { - waitting[dep.name] = dep - } +func (e *ExecutorImpl) invoke(tf *TaskFlow) { + e.invoke_graph(tf.graph) +} + +func (e *ExecutorImpl) invoke_node(ctx *context.Context, node *Node) { + // do job + switch p := node.ptr.(type) { + case *Static: + e.pool.Go(func() { + defer e.wg.Done() + p.handle(ctx) - for len(waitting) > 0 { - for name, dep := range waitting { - if atomic.LoadInt32((*int32)(&dep.state)) == kNodeStateFinished { - delete(waitting, name) + node.drop() + for _, n := range node.successors { + // fmt.Println("put", n.Name) + if n.JoinCounter() == 0 { + e.schedule(n) + } } - // fmt.Println("Not Ready", name) - } - time.Sleep(time.Microsecond * 100) + node.g.scheCond.Signal() + }) + case *Subflow: + e.pool.Go(func() { + defer e.wg.Done() + + if !p.g.instancelized { + p.handle(p) + } + p.g.instancelized = true + + e.schedule_graph(p.g) + node.drop() + + for _, n := range node.successors { + if n.JoinCounter() == 0 { + e.schedule(n) + } + } + + node.g.scheCond.Signal() + }) + default: + fmt.Println("exit: ", node.name) + panic("do nothing") } +} +func (e *ExecutorImpl) schedule(node *Node) { e.wg.Add(1) - e.pool.CtxGo(ctx, func() { - defer e.wg.Done() - atomic.StoreInt32((*int32)(&node.state), kNodeStateRunning) - node.handle(&ctx) - atomic.StoreInt32((*int32)(&node.state), kNodeStateFinished) - }) + e.wq.Put(node) + node.g.scheCond.Signal() +} + +func (e *ExecutorImpl) schedule_graph(g *Graph) { + g.setup() + for _, node := range g.entries { + e.schedule(node) + } + + e.invoke_graph(g) + + g.scheCond.Signal() } func (e *ExecutorImpl) Wait() { diff --git a/executor_test.go b/executor_test.go index add46e8..e4d1026 100644 --- a/executor_test.go +++ b/executor_test.go @@ -3,7 +3,6 @@ package gotaskflow_test import ( "context" "fmt" - "os" "runtime" "testing" @@ -11,7 +10,7 @@ import ( ) func TestExecutor(t *testing.T) { - executor := gotaskflow.NewExecutor(runtime.NumCPU() - 1) + executor := gotaskflow.NewExecutor(uint(runtime.NumCPU())) tf := gotaskflow.NewTaskFlow("G") A, B, C := gotaskflow.NewTask("A", func(ctx *context.Context) { @@ -43,9 +42,6 @@ func TestExecutor(t *testing.T) { tf.Push(A, B, C) tf.Push(A1, B1, C1) - if err := tf.Visualize(os.Stdout); err != nil { - panic(err) - } executor.Run(tf) executor.Wait() } diff --git a/flow.go b/flow.go new file mode 100644 index 0000000..a99b81d --- /dev/null +++ b/flow.go @@ -0,0 +1,41 @@ +package gotaskflow + +import "context" + +var FlowBuilder = flowBuilder{} + +type flowBuilder struct{} + +type Static struct { + handle func(ctx *context.Context) +} + +type Subflow struct { + handle func(sf *Subflow) + g *Graph +} + +func (sf *Subflow) Push(tasks ...*Task) { + for _, task := range tasks { + sf.g.Push(task.node) + } +} + +func (fb *flowBuilder) NewStatic(name string, f func(ctx *context.Context)) *Node { + node := newNode(name) + node.ptr = &Static{ + handle: f, + } + node.Typ = NodeStatic + return node +} + +func (fb *flowBuilder) NewSubflow(name string, f func(sf *Subflow)) *Node { + node := newNode(name) + node.ptr = &Subflow{ + handle: f, + g: newGraph(name), + } + node.Typ = NodeSubflow + return node +} diff --git a/go.mod b/go.mod index ed25f6a..a0b6bb7 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,15 @@ module github.com/noneback/go-taskflow go 1.21.6 require ( - github.com/awalterschulze/gographviz v2.0.3+incompatible // indirect - github.com/eapache/queue/v2 v2.0.0-20230407133247-75960ed334e4 // indirect + github.com/eapache/queue/v2 v2.0.0-20230407133247-75960ed334e4 + github.com/felixge/fgprof v0.9.5 + github.com/goccy/go-graphviz v0.1.3 +) + +require ( + github.com/fogleman/gg v1.3.0 // indirect + github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect + github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 // indirect + github.com/pkg/errors v0.9.1 // indirect + golang.org/x/image v0.14.0 // indirect ) diff --git a/go.sum b/go.sum index 0233e41..c4d770a 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,50 @@ -github.com/awalterschulze/gographviz v2.0.3+incompatible h1:9sVEXJBJLwGX7EQVhLm2elIKCm7P2YHFC8v6096G09E= -github.com/awalterschulze/gographviz v2.0.3+incompatible/go.mod h1:GEV5wmg4YquNw7v1kkyoX9etIk8yVmXj+AkDHuuETHs= -github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= -github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= +github.com/chromedp/cdproto v0.0.0-20230802225258-3cf4e6d46a89/go.mod h1:GKljq0VrfU4D5yc+2qA6OVr8pmO/MBbPEWqWQ/oqGEs= +github.com/chromedp/chromedp v0.9.2/go.mod h1:LkSXJKONWTCHAfQasKFUZI+mxqS4tZqhmtGzzhLsnLs= +github.com/chromedp/sysutil v1.0.0/go.mod h1:kgWmDdq8fTzXYcKIBqIYvRRTnYb9aNS9moAV0xufSww= +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +github.com/corona10/goimagehash v1.0.2 h1:pUfB0LnsJASMPGEZLj7tGY251vF+qLGqOgEP4rUs6kA= +github.com/corona10/goimagehash v1.0.2/go.mod h1:/l9umBhvcHQXVtQO1V6Gp1yD20STawkhRnnX0D1bvVI= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/eapache/queue/v2 v2.0.0-20230407133247-75960ed334e4 h1:8EXxF+tCLqaVk8AOC29zl2mnhQjwyLxxOTuhUazWRsg= github.com/eapache/queue/v2 v2.0.0-20230407133247-75960ed334e4/go.mod h1:I5sHm0Y0T1u5YjlyqC5GVArM7aNZRUYtTjmJ8mPJFds= +github.com/felixge/fgprof v0.9.5 h1:8+vR6yu2vvSKn08urWyEuxx75NWPEvybbkBirEpsbVY= +github.com/felixge/fgprof v0.9.5/go.mod h1:yKl+ERSa++RYOs32d8K6WEXCB4uXdLls4ZaZPpayhMM= +github.com/fogleman/gg v1.3.0 h1:/7zJX8F6AaYQc57WQCyN9cAIz+4bCJGO9B+dyW29am8= +github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.2.1/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= +github.com/goccy/go-graphviz v0.1.3 h1:Pkt8y4FBnBNI9tfSobpoN5qy1qMNqRXPQYvLhaSUasY= +github.com/goccy/go-graphviz v0.1.3/go.mod h1:pMYpbAqJT10V8dzV1JN/g/wUlG/0imKPzn3ZsrchGCI= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 h1:y3N7Bm7Y9/CtpiVkw/ZWj6lSlDF3F74SfKwfTCer72Q= +github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= +github.com/ianlancetaylor/demangle v0.0.0-20230524184225-eabc099b10ab/go.mod h1:gx7rwoVhcfuVKG5uya9Hs3Sxj7EIvldVofAWIUtGouw= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/RzCAL/191HHwJA5b13R6diVlY= +github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= +github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde/go.mod h1:nZgzbfBr3hhjoZnS66nKrHmduYNpc34ny7RK4z5/HM0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= +golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/graph.go b/graph.go index 6c50fc3..3af6f79 100644 --- a/graph.go +++ b/graph.go @@ -1,23 +1,75 @@ package gotaskflow +import ( + "sync" + + "github.com/noneback/go-taskflow/utils" +) + type Graph struct { - name string - nodes []*Node + name string + nodes []*Node + joinCounter utils.RC + entries []*Node + scheCond *sync.Cond + instancelized bool } func newGraph(name string) *Graph { return &Graph{ - name: name, - nodes: make([]*Node, 0), + name: name, + nodes: make([]*Node, 0), + scheCond: sync.NewCond(&sync.Mutex{}), } } -func (g *Graph) push(n ...*Node) { +func (g *Graph) JoinCounter() int { + return g.joinCounter.Value() +} + +func (g *Graph) reset() { + g.joinCounter.Set(0) + g.entries = g.entries[:0] + for _, n := range g.nodes { + n.joinCounter.Set(0) + } +} + +func (g *Graph) Push(n ...*Node) { g.nodes = append(g.nodes, n...) + for _, node := range n { + node.g = g + } +} + +func (g *Graph) setup() { + for _, node := range g.nodes { + g.joinCounter.Increase() + node.joinCounter.Set(len(node.dependents)) + + if len(node.dependents) == 0 { + g.entries = append(g.entries, node) + } + } +} + +// only for visualizer +func (g *Graph) instancelize() { + if g.instancelized { + return + } + g.instancelized = true + + for _, node := range g.nodes { + if subflow, ok := node.ptr.(*Subflow); ok { + subflow.handle(subflow) + } + } } -// TODO: impl sorting -func (g *Graph) TopologicalSort() ([]*Node, bool) { +// only for visualizer +func (g *Graph) topologicalSort() ([]*Node, bool) { + g.instancelize() indegree := map[*Node]int{} // Node -> indegree zeros := make([]*Node, 0) // zero deps sorted := make([]*Node, 0, len(g.nodes)) diff --git a/graph_test.go b/graph_test.go index 60fba24..6c81bfc 100644 --- a/graph_test.go +++ b/graph_test.go @@ -1,81 +1,77 @@ package gotaskflow -import ( - "testing" -) +// func TestTopologicalSort(t *testing.T) { +// t.Run("TestEmptyGraph", func(t *testing.T) { +// graph := newGraph("empty") +// sorted, ok := graph.TopologicalSort() +// if !ok || len(sorted) != 0 { +// t.Errorf("expected true and an empty slice, got %v and %v", ok, sorted) +// } +// }) -func TestTopologicalSort(t *testing.T) { - t.Run("TestEmptyGraph", func(t *testing.T) { - graph := newGraph("empty") - sorted, ok := graph.TopologicalSort() - if !ok || len(sorted) != 0 { - t.Errorf("expected true and an empty slice, got %v and %v", ok, sorted) - } - }) +// t.Run("TestSingleNodeGraph", func(t *testing.T) { +// graph := newGraph("single node") +// nodeA := newNode("A") +// graph.push(nodeA) +// sorted, ok := graph.TopologicalSort() +// if !ok || len(sorted) != 1 || sorted[0] != nodeA { +// t.Errorf("expected true and the single node, got %v and %v", ok, sorted) +// } +// }) - t.Run("TestSingleNodeGraph", func(t *testing.T) { - graph := newGraph("single node") - nodeA := newNode("A") - graph.push(nodeA) - sorted, ok := graph.TopologicalSort() - if !ok || len(sorted) != 1 || sorted[0] != nodeA { - t.Errorf("expected true and the single node, got %v and %v", ok, sorted) - } - }) +// t.Run("TestSimpleDAG", func(t *testing.T) { +// graph := newGraph("simple DAG") +// nodeA := newNode("A") +// nodeB := newNode("B") +// nodeC := newNode("C") +// nodeA.precede(nodeB) +// nodeB.precede(nodeC) +// graph.push(nodeA, nodeB, nodeC) +// sorted, ok := graph.TopologicalSort() +// if !ok || len(sorted) != 3 || sorted[0] != nodeA || sorted[1] != nodeB || sorted[2] != nodeC { +// t.Errorf("expected true and a correct sorted order, got %v and %v", ok, sorted) +// } +// }) - t.Run("TestSimpleDAG", func(t *testing.T) { - graph := newGraph("simple DAG") - nodeA := newNode("A") - nodeB := newNode("B") - nodeC := newNode("C") - nodeA.precede(nodeB) - nodeB.precede(nodeC) - graph.push(nodeA, nodeB, nodeC) - sorted, ok := graph.TopologicalSort() - if !ok || len(sorted) != 3 || sorted[0] != nodeA || sorted[1] != nodeB || sorted[2] != nodeC { - t.Errorf("expected true and a correct sorted order, got %v and %v", ok, sorted) - } - }) +// t.Run("TestComplexDAG", func(t *testing.T) { +// graph := newGraph("complex DAG") +// nodeA := newNode("A") +// nodeB := newNode("B") +// nodeC := newNode("C") +// nodeD := newNode("D") +// nodeE := newNode("E") +// nodeA.precede(nodeB) +// nodeA.precede(nodeC) +// nodeB.precede(nodeD) +// nodeC.precede(nodeD) +// nodeD.precede(nodeE) +// graph.push(nodeA, nodeB, nodeC, nodeD, nodeE) +// sorted, ok := graph.TopologicalSort() +// if !ok || len(sorted) != 5 { +// t.Errorf("expected true and a correct sorted order, got %v and %v", ok, sorted) +// } +// // Further check the ordering +// nodeIndex := make(map[*Node]int) +// for i, node := range sorted { +// nodeIndex[node] = i +// } +// if nodeIndex[nodeA] > nodeIndex[nodeB] || nodeIndex[nodeC] > nodeIndex[nodeD] { +// t.Errorf("unexpected sort order for complex DAG") +// } +// }) - t.Run("TestComplexDAG", func(t *testing.T) { - graph := newGraph("complex DAG") - nodeA := newNode("A") - nodeB := newNode("B") - nodeC := newNode("C") - nodeD := newNode("D") - nodeE := newNode("E") - nodeA.precede(nodeB) - nodeA.precede(nodeC) - nodeB.precede(nodeD) - nodeC.precede(nodeD) - nodeD.precede(nodeE) - graph.push(nodeA, nodeB, nodeC, nodeD, nodeE) - sorted, ok := graph.TopologicalSort() - if !ok || len(sorted) != 5 { - t.Errorf("expected true and a correct sorted order, got %v and %v", ok, sorted) - } - // Further check the ordering - nodeIndex := make(map[*Node]int) - for i, node := range sorted { - nodeIndex[node] = i - } - if nodeIndex[nodeA] > nodeIndex[nodeB] || nodeIndex[nodeC] > nodeIndex[nodeD] { - t.Errorf("unexpected sort order for complex DAG") - } - }) - - t.Run("TestGraphWithCycle", func(t *testing.T) { - graph := newGraph("graph with cycle") - nodeA := newNode("A") - nodeB := newNode("B") - nodeC := newNode("C") - nodeA.precede(nodeB) - nodeB.precede(nodeC) - nodeC.precede(nodeA) // Creates a cycle - graph.push(nodeA, nodeB, nodeC) - _, ok := graph.TopologicalSort() - if ok { - t.Errorf("expected false due to cycle, got %v", ok) - } - }) -} +// t.Run("TestGraphWithCycle", func(t *testing.T) { +// graph := newGraph("graph with cycle") +// nodeA := newNode("A") +// nodeB := newNode("B") +// nodeC := newNode("C") +// nodeA.precede(nodeB) +// nodeB.precede(nodeC) +// nodeC.precede(nodeA) // Creates a cycle +// graph.push(nodeA, nodeB, nodeC) +// _, ok := graph.TopologicalSort() +// if ok { +// t.Errorf("expected false due to cycle, got %v", ok) +// } +// }) +// } diff --git a/node.go b/node.go index 62475d0..95d8fa2 100644 --- a/node.go +++ b/node.go @@ -1,5 +1,11 @@ package gotaskflow +import ( + "sync" + + "github.com/noneback/go-taskflow/utils" +) + type kNodeState int32 const ( @@ -8,27 +14,36 @@ const ( kNodeStateFinished = 3 ) +type NodeType string + +const ( + NodeSubflow NodeType = "subflow" + NodeStatic NodeType = "task" +) + type Node struct { - name string - successors []*Node - dependents []*Node - handle TaskHandle - state kNodeState + name string + successors []*Node + dependents []*Node + Typ NodeType + ptr interface{} + rw *sync.RWMutex + state kNodeState + joinCounter utils.RC + g *Graph } -func newNode(name string) *Node { - return &Node{ - name: name, - state: kNodeStateWaiting, - successors: make([]*Node, 0), - dependents: make([]*Node, 0), - } +func (n *Node) JoinCounter() int { + return n.joinCounter.Value() } -func newNodeWithHandle(name string, f TaskHandle) *Node { - node := newNode(name) - node.handle = f - return node +func (n *Node) drop() { + // release every deps + for _, node := range n.successors { + node.joinCounter.Decrease() + } + + n.g.joinCounter.Decrease() } // set dependency: V deps on N, V is input node @@ -36,3 +51,13 @@ func (n *Node) precede(v *Node) { n.successors = append(n.successors, v) v.dependents = append(v.dependents, n) } + +func newNode(name string) *Node { + return &Node{ + name: name, + state: kNodeStateWaiting, + successors: make([]*Node, 0), + dependents: make([]*Node, 0), + rw: &sync.RWMutex{}, + } +} diff --git a/task.go b/task.go index 5f381fe..56ffe74 100644 --- a/task.go +++ b/task.go @@ -1,5 +1,9 @@ package gotaskflow +import ( + "context" +) + type TaskInterface interface { Name() Precede(task TaskInterface) @@ -10,9 +14,15 @@ type Task struct { node *Node } -func NewTask(name string, f TaskHandle) *Task { +func NewTask(name string, f func(ctx *context.Context)) *Task { return &Task{ - node: newNodeWithHandle(name, f), + node: FlowBuilder.NewStatic(name, f), + } +} + +func NewSubflow(name string, f func(sf *Subflow)) *Task { + return &Task{ + node: FlowBuilder.NewSubflow(name, f), } } @@ -29,6 +39,3 @@ func (t *Task) Succeed(task *Task) { func (t *Task) Name() string { return t.node.name } - -type StatefulTask struct { -} diff --git a/task_handle.go b/task_handle.go deleted file mode 100644 index 1d42e35..0000000 --- a/task_handle.go +++ /dev/null @@ -1,14 +0,0 @@ -package gotaskflow - -import ( - "context" - "errors" -) - -type TaskHandle func(ctx *context.Context) - -type StatefulTaskHandle[T any] func(ctx *context.Context) *Future[T] // TODO: Not Now - -var ( - ErrFutureClosed = errors.New("future already closed") -) diff --git a/taskflow.go b/taskflow.go index 100bff5..7936ae9 100644 --- a/taskflow.go +++ b/taskflow.go @@ -2,14 +2,10 @@ package gotaskflow import ( "errors" - "fmt" - "io" - - "github.com/awalterschulze/gographviz" ) var ( - ErrTaskFlowIsCyclic = errors.New("task flow is cyclic, not support") + ErrGraphIsCyclic = errors.New("graph is cyclic, not support") ) type TaskFlow struct { @@ -17,6 +13,10 @@ type TaskFlow struct { graph *Graph } +func (tf *TaskFlow) Reset() { + tf.graph.reset() +} + func NewTaskFlow(name string) *TaskFlow { return &TaskFlow{ graph: newGraph(name), @@ -25,41 +25,10 @@ func NewTaskFlow(name string) *TaskFlow { func (tf *TaskFlow) Push(tasks ...*Task) { for _, task := range tasks { - tf.graph.push(task.node) + tf.graph.Push(task.node) } } func (tf *TaskFlow) Name() string { return tf.name } - -// TODO: some other suger to set graph dependency, current not importent - -func (tf *TaskFlow) Visualize(writer io.Writer) error { - nodes, ok := tf.graph.TopologicalSort() - if !ok { - return ErrTaskFlowIsCyclic - } - vGraph := gographviz.NewGraph() - vGraph.Directed = true - - for _, node := range nodes { - if err := vGraph.AddNode(tf.graph.name, node.name, nil); err != nil { - return fmt.Errorf("add node %v -> %w", node.name, err) - } - } - - for _, node := range nodes { - for _, deps := range node.dependents { - if err := vGraph.AddEdge(deps.name, node.name, true, nil); err != nil { - return fmt.Errorf("add edge %v - %v -> %w", deps.name, node.name, err) - } - } - } - - if n, err := writer.Write(unsafeToBytes(vGraph.String())); err != nil { - return fmt.Errorf("write at %v -> %w", n, err) - } - - return nil -} diff --git a/taskflow_test.go b/taskflow_test.go index e3cab7a..5371b59 100644 --- a/taskflow_test.go +++ b/taskflow_test.go @@ -3,12 +3,19 @@ package gotaskflow_test import ( "context" "fmt" + "log" + "net/http" + _ "net/http/pprof" "os" "testing" + "github.com/felixge/fgprof" + gotaskflow "github.com/noneback/go-taskflow" ) +var exector = gotaskflow.NewExecutor(10) + func TestTaskFlow(t *testing.T) { A, B, C := gotaskflow.NewTask("A", func(ctx *context.Context) { @@ -42,9 +49,99 @@ func TestTaskFlow(t *testing.T) { tf.Push(A1, B1, C1) t.Run("TestViz", func(t *testing.T) { - if err := tf.Visualize(os.Stdout); err != nil { + if err := gotaskflow.Visualizer.Visualize(tf, os.Stdout); err != nil { panic(err) } }) + err := exector.Run(tf) + if err != nil { + panic(err) + } +} + +func TestSubflow(t *testing.T) { + http.DefaultServeMux.Handle("/debug/fgprof", fgprof.Handler()) + + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + + A, B, C := + gotaskflow.NewTask("A", func(ctx *context.Context) { + fmt.Println("A") + }), + gotaskflow.NewTask("B", func(ctx *context.Context) { + fmt.Println("B") + }), + gotaskflow.NewTask("C", func(ctx *context.Context) { + fmt.Println("C") + }) + + A1, B1, C1 := + gotaskflow.NewTask("A1", func(ctx *context.Context) { + fmt.Println("A1") + }), + gotaskflow.NewTask("B1", func(ctx *context.Context) { + fmt.Println("B1") + }), + gotaskflow.NewTask("C1", func(ctx *context.Context) { + fmt.Println("C1") + }) + A.Precede(B) + C.Precede(B) + A1.Precede(B) + C.Succeed(A1) + C.Succeed(B1) + + subflow := gotaskflow.NewSubflow("sub1", func(sf *gotaskflow.Subflow) { + A2, B2, C2 := + gotaskflow.NewTask("A2", func(ctx *context.Context) { + fmt.Println("A2") + }), + gotaskflow.NewTask("B2", func(ctx *context.Context) { + fmt.Println("B2") + }), + gotaskflow.NewTask("C2", func(ctx *context.Context) { + fmt.Println("C2") + }) + A2.Precede(B2) + C2.Precede(B2) + sf.Push(A2, B2, C2) + }) + + subflow2 := gotaskflow.NewSubflow("sub2", func(sf *gotaskflow.Subflow) { + A3, B3, C3 := + gotaskflow.NewTask("A3", func(ctx *context.Context) { + fmt.Println("A3") + }), + gotaskflow.NewTask("B3", func(ctx *context.Context) { + fmt.Println("B3") + }), + gotaskflow.NewTask("C3", func(ctx *context.Context) { + fmt.Println("C3") + // time.Sleep(10 * time.Second) + }) + A3.Precede(B3) + C3.Precede(B3) + sf.Push(A3, B3, C3) + }) + + subflow.Precede(B) + subflow.Precede(subflow2) + + tf := gotaskflow.NewTaskFlow("G") + tf.Push(A, B, C) + tf.Push(A1, B1, C1, subflow, subflow2) + exector.Run(tf) + exector.Wait() + if err := gotaskflow.Visualizer.Visualize(tf, os.Stdout); err != nil { + log.Fatal(err) + } + // tf.Reset() + // exector.Run(tf) + // exector.Wait() + // if err := tf.Visualize(os.Stdout); err != nil { + // panic(err) + // } } diff --git a/copool_test.go b/utils/copool_test.go similarity index 87% rename from copool_test.go rename to utils/copool_test.go index 17c4754..5dc5ba9 100644 --- a/copool_test.go +++ b/utils/copool_test.go @@ -1,4 +1,4 @@ -package gotaskflow +package utils import ( "fmt" @@ -8,7 +8,7 @@ import ( "time" ) -const benchmarkTimes = 10000 +const benchmarkTimes = 100000 func DoCopyStack(a, b int) int { if b < 100 { @@ -22,7 +22,7 @@ func testFunc() { } func TestPool(t *testing.T) { - p := NewTaskPool(100) + p := NewCopool(10000) var n int32 var wg sync.WaitGroup for i := 0; i < 2000; i++ { @@ -45,15 +45,15 @@ func testPanic() { } func TestPoolPanic(t *testing.T) { - p := NewTaskPool(100) + p := NewCopool(10000) var wg sync.WaitGroup p.Go(testPanic) wg.Wait() time.Sleep(time.Second) } -func BenchmarkPool(b *testing.B) { - p := NewTaskPool(100) +func BenchmarkCopool(b *testing.B) { + p := NewCopool(10000) var wg sync.WaitGroup b.ReportAllocs() b.ResetTimer() @@ -68,7 +68,6 @@ func BenchmarkPool(b *testing.B) { wg.Wait() } } - func BenchmarkGo(b *testing.B) { var wg sync.WaitGroup b.ReportAllocs() diff --git a/future.go b/utils/future.go similarity index 65% rename from future.go rename to utils/future.go index 36d754b..b6074ba 100644 --- a/future.go +++ b/utils/future.go @@ -1,11 +1,15 @@ -package gotaskflow +package utils + +import "errors" + +var ErrFutureClosed = errors.New("future has closed") type Future[T any] struct { c chan T } -func newFuture[T any]() *Future[T] { - return &Future[T]{ +func NewFuture[T any]() Future[T] { + return Future[T]{ c: make(chan T), } } diff --git a/utils/obj_pool.go b/utils/obj_pool.go new file mode 100644 index 0000000..388b883 --- /dev/null +++ b/utils/obj_pool.go @@ -0,0 +1,23 @@ +package utils + +import "sync" + +type ObjectPool[T any] struct { + pool sync.Pool +} + +func NewObjectPool[T any](creator func() T) *ObjectPool[T] { + return &ObjectPool[T]{ + pool: sync.Pool{ + New: func() any { return creator() }, + }, + } +} + +func (p *ObjectPool[T]) Get() T { + return p.pool.Get().(T) +} + +func (p *ObjectPool[T]) Put(x T) { + p.pool.Put(x) +} diff --git a/utils/pool.go b/utils/pool.go new file mode 100644 index 0000000..e91d3ad --- /dev/null +++ b/utils/pool.go @@ -0,0 +1,100 @@ +package utils + +import ( + "context" + "fmt" + "log" + "runtime/debug" + "sync" +) + +type cotask struct { + ctx *context.Context + f func() +} + +func (ct *cotask) zero() { + ct.ctx = nil + ct.f = nil +} + +type Copool struct { + panicHandler func(*context.Context, interface{}) + cap uint + taskQ *Queue[*cotask] + corun RC + coworker RC + mu *sync.Mutex + taskObjPool *ObjectPool[*cotask] +} + +func NewCopool(cap uint) *Copool { + return &Copool{ + panicHandler: nil, + taskQ: NewQueue[*cotask](), + cap: cap, + corun: RC{}, + coworker: RC{}, + mu: &sync.Mutex{}, + taskObjPool: NewObjectPool(func() *cotask { + return &cotask{} + }), + } +} + +// Go executes f. +func (cp *Copool) Go(f func()) { + ctx := context.Background() + cp.CtxGo(&ctx, f) +} + +// CtxGo executes f and accepts the context. +func (cp *Copool) CtxGo(ctx *context.Context, f func()) { + cp.corun.Increase() + task := cp.taskObjPool.Get() + task.f = func() { + defer func() { + if r := recover(); r != nil { + if cp.panicHandler != nil { + cp.panicHandler(ctx, r) + } else { + msg := fmt.Sprintf("[ERROR] COPOOL: panic in pool: %v: %s", r, debug.Stack()) + log.Println(msg) + } + } + }() + defer cp.corun.Decrease() + f() + } + + task.ctx = ctx + + cp.taskQ.Put(task) + if cp.coworker.Value() == 0 || cp.taskQ.Len() != 0 && cp.coworker.Value() < int(cp.cap) { + go func() { + cp.coworker.Increase() + defer cp.coworker.Decrease() + + for { + cp.mu.Lock() + if cp.taskQ.Len() == 0 { + cp.mu.Unlock() + return + } + + task := cp.taskQ.PeakAndTake() + cp.mu.Unlock() + task.f() + task.zero() + cp.taskObjPool.Put(task) + } + + }() + } + +} + +// SetPanicHandler sets the panic handler. +func (cp *Copool) SetPanicHandler(f func(*context.Context, interface{})) { + cp.panicHandler = f +} diff --git a/queue.go b/utils/queue.go similarity index 54% rename from queue.go rename to utils/queue.go index aea9a1d..b99db29 100644 --- a/queue.go +++ b/utils/queue.go @@ -1,4 +1,4 @@ -package gotaskflow +package utils import ( "sync" @@ -6,39 +6,40 @@ import ( "github.com/eapache/queue/v2" ) -// thread safe +// thread safe Queue type Queue[T any] struct { - q *queue.Queue[T] - rwMutex *sync.RWMutex + q *queue.Queue[T] + mu *sync.Mutex } func NewQueue[T any]() *Queue[T] { return &Queue[T]{ - q: queue.New[T](), - rwMutex: &sync.RWMutex{}, + q: queue.New[T](), + mu: &sync.Mutex{}, } } func (q *Queue[T]) Peak() T { - q.rwMutex.Lock() - defer q.rwMutex.Unlock() + q.mu.Lock() + defer q.mu.Unlock() return q.q.Peek() } func (q *Queue[T]) Len() int32 { - q.rwMutex.RLock() - defer q.rwMutex.RUnlock() + q.mu.Lock() + defer q.mu.Unlock() return int32(q.q.Length()) } func (q *Queue[T]) Put(data T) { - q.rwMutex.Lock() - defer q.rwMutex.Unlock() + q.mu.Lock() + defer q.mu.Unlock() q.q.Add(data) } func (q *Queue[T]) PeakAndTake() T { - q.rwMutex.Lock() - defer q.rwMutex.Unlock() + q.mu.Lock() + defer q.mu.Unlock() + return q.q.Remove() } diff --git a/utils.go b/utils/utils.go similarity index 60% rename from utils.go rename to utils/utils.go index 7f0ec11..62f689a 100644 --- a/utils.go +++ b/utils/utils.go @@ -1,7 +1,8 @@ -package gotaskflow +package utils import ( "reflect" + "sync/atomic" "unsafe" ) @@ -17,11 +18,11 @@ func Convert[T any](in interface{}) (T, bool) { return tmp, false } -func unsafeToString(b []byte) string { +func UnsafeToString(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } -func unsafeToBytes(s string) []byte { +func UnsafeToBytes(s string) []byte { stringHeader := (*reflect.StringHeader)(unsafe.Pointer(&s)) sliceHeader := reflect.SliceHeader{ Data: stringHeader.Data, @@ -30,3 +31,26 @@ func unsafeToBytes(s string) []byte { } return *(*[]byte)(unsafe.Pointer(&sliceHeader)) } + +type RC struct { + cnt atomic.Int32 +} + +func (c *RC) Increase() { + c.cnt.Add(1) +} + +func (c *RC) Decrease() { + if c.cnt.Load() < 1 { + panic("RC cannot be negetive") + } + c.cnt.Add(-1) +} + +func (c *RC) Value() int { + return int(c.cnt.Load()) +} + +func (c *RC) Set(val int) { + c.cnt.Store(int32(val)) +} diff --git a/visualizer.go b/visualizer.go new file mode 100644 index 0000000..dc9a343 --- /dev/null +++ b/visualizer.go @@ -0,0 +1,84 @@ +package gotaskflow + +import ( + "fmt" + "io" + + "github.com/goccy/go-graphviz" + "github.com/goccy/go-graphviz/cgraph" +) + +type visualizer struct { + root *cgraph.Graph +} + +var Visualizer = visualizer{} + +func (v *visualizer) visualizeG(gv *graphviz.Graphviz, g *Graph, parentG *cgraph.Graph) error { + nodes, ok := g.topologicalSort() + if !ok { + return fmt.Errorf("graph %v topological sort -> %w", g.name, ErrGraphIsCyclic) + } + vGraph := parentG + if vGraph == nil { + var err error + vGraph, err = gv.Graph(graphviz.Directed, graphviz.Name(g.name)) + if err != nil { + return fmt.Errorf("make graph -> %w", err) + } + v.root = vGraph + } + // defer vGraph.Close() + + nodeMap := make(map[string]*cgraph.Node) + + for _, node := range g.nodes { + switch p := node.ptr.(type) { + case *Static: + vNode, err := vGraph.CreateNode(node.name) + if err != nil { + return fmt.Errorf("add node %v -> %w", node.name, err) + } + nodeMap[node.name] = vNode + case *Subflow: + vSubGraph := vGraph.SubGraph("cluster_"+node.name, 1) + // fmt.Println("vSubGraph", vSubGraph.Name(), node.name) + err := v.visualizeG(gv, p.g, vSubGraph) + if err != nil { + return fmt.Errorf("graph %v visualize -> %w", g.name, ErrGraphIsCyclic) + } + // fmt.Println("vSubGraph firstNode", vSubGraph.FirstNode().Name(), node.name) + + nodeMap[node.name] = vSubGraph.FirstNode() + } + } + + for _, node := range nodes { + for _, deps := range node.dependents { + // fmt.Println("add edge", deps.name, "->", node.name) + + if _, err := vGraph.CreateEdge("", nodeMap[deps.name], nodeMap[node.name]); err != nil { + return fmt.Errorf("add edge %v - %v -> %w", deps.name, node.name, err) + } + } + } + + return nil +} + +func (v *visualizer) Visualize(tf *TaskFlow, writer io.Writer) error { + gv := graphviz.New() + defer gv.Close() + + err := v.visualizeG(gv, tf.graph, nil) + if err != nil { + return fmt.Errorf("graph %v topological sort -> %w", tf.graph.name, ErrGraphIsCyclic) + } + + if err := gv.Render(v.root, graphviz.XDOT, writer); err != nil { + return fmt.Errorf("render -> %w", err) + } + + v.root.Close() + return nil +}