From 4b87d5ce05b905b5bba71cf7609def5ccceb553e Mon Sep 17 00:00:00 2001 From: googs1025 Date: Wed, 2 Aug 2023 23:36:36 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E8=B6=85=E6=97=B6?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E4=B8=8E=E5=9B=9E=E8=B0=83=E6=96=B9=E6=B3=95?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 69 ++++++++------- example/README.md | 6 ++ example/example1_test.go | 10 ++- example/example2_test.go | 23 +++-- example/http_example/cmd/cmd.go | 2 +- example/http_example/pkg/server/model/task.go | 6 +- example/scheduler/scheduler_test.go | 4 +- pkg/workerpool/option.go | 27 ++++++ pkg/workerpool/pool.go | 65 ++++++++++---- pkg/workerpool/task.go | 21 +++-- pkg/workerpool/worker.go | 85 ++++++++++++++++--- 11 files changed, 235 insertions(+), 83 deletions(-) create mode 100644 example/README.md create mode 100644 pkg/workerpool/option.go diff --git a/README.md b/README.md index bf5ef94..3df39f3 100644 --- a/README.md +++ b/README.md @@ -14,26 +14,30 @@ func TestTaskPool1(t *testing.T) { - // 建立一个工作池 - // input:池数量 - pool := workerpool.NewPool(5) - + pool := workerpool.NewPool(5, workerpool.WithTimeout(1), workerpool.WithErrorCallback(func(err error) { + fmt.Println("WithErrorCallback") + if err != nil { + panic(err) + } + }), workerpool.WithResultCallback(func(i interface{}) { + fmt.Println("result: ", i) + })) // 需要处理的任务 - tt := func(data interface{}) error { + tt := func(data interface{}) (interface{}, error) { taskID := data.(int) // 业务逻辑 time.Sleep(100 * time.Millisecond) klog.Info("Task ", taskID, " processed") - return nil + return nil, nil } // 准备多个个任务 for i := 1; i <= 1000; i++ { // 需要做的任务 - task := workerpool.NewTask(tt, i) + task := workerpool.NewTaskInstance(fmt.Sprintf("task-%v", i), i, tt) // 所有的任务放入全局队列中 pool.AddGlobalQueue(task) @@ -59,51 +63,56 @@ func TestTaskPool2(t *testing.T) { // 建立一个池, // input:池数量 - pool := workerpool.NewPool(5) - + //pool := workerpool.NewPool(5) + pool := workerpool.NewPool(5, workerpool.WithTimeout(1), workerpool.WithErrorCallback(func(err error) { + if err != nil { + panic(err) + } + }), workerpool.WithResultCallback(func(i interface{}) { + fmt.Println("result: ", i) + })) + // 准备100个任务 for i := 1; i <= 100; i++ { - // 需要做的任务 - task := workerpool.NewTask(func(data interface{}) error { - taskID := data.(int) - - /* - 业务逻辑 - */ - time.Sleep(100 * time.Millisecond) - klog.Info("Task ", taskID, " processed") - return nil - }, i) + // 需要做的任务 + task := workerpool.NewTaskInstance(fmt.Sprintf("task-%v", i), i, func(data interface{}) (interface{}, error) { + taskID := data.(int) + /* + 业务逻辑 + */ + time.Sleep(100 * time.Millisecond) + klog.Info("Task ", taskID, " processed") + return nil, nil + }) + // 所有的任务放入list中 pool.AddGlobalQueue(task) } - // 启动在后台等待执行 go pool.RunBackground() - for { taskID := rand.Intn(100) + 20 - - // 模拟一个退出条件 + + //// 模拟一个退出条件 if taskID%7 == 0 { klog.Info("taskID: ", taskID, "pool stop!") pool.StopBackground() break } - + time.Sleep(time.Duration(rand.Intn(5)) * time.Second) // 模拟后续加入pool - task := workerpool.NewTask(func(data interface{}) error { + task := workerpool.NewTaskInstance(fmt.Sprintf("task-%v", taskID), taskID, func(data interface{}) (interface{}, error) { taskID := data.(int) - time.Sleep(100 * time.Millisecond) + time.Sleep(3 * time.Second) klog.Info("Task ", taskID, " processed") - return nil - }, taskID) - + return nil, nil + }) + pool.AddTask(task) } diff --git a/example/README.md b/example/README.md new file mode 100644 index 0000000..2f22b1a --- /dev/null +++ b/example/README.md @@ -0,0 +1,6 @@ +### 示例: + +- example1_test.go:同步等待运行的协程池示例 +- example2_test.go:异步等待运行的协程池示例 +- scheduler包:简易型调度器 +- http_example包:模拟httpServer暴露调度接口示例 diff --git a/example/example1_test.go b/example/example1_test.go index e1e2fbc..a1cba3a 100644 --- a/example/example1_test.go +++ b/example/example1_test.go @@ -20,16 +20,18 @@ func TestTaskPool1(t *testing.T) { // 建立一个工作池 // input:池数量 - pool := workerpool.NewPool(5) + pool := workerpool.NewPool(5, workerpool.WithTimeout(1), workerpool.WithResultCallback(func(i interface{}) { + fmt.Println("result: ", i) + })) // 需要处理的任务 - tt := func(data interface{}) error { + tt := func(data interface{}) (interface{}, error) { taskID := data.(int) // 业务逻辑 time.Sleep(100 * time.Millisecond) klog.Info("Task ", taskID, " processed") - return nil + return nil, nil } // 准备多个个任务 @@ -43,4 +45,4 @@ func TestTaskPool1(t *testing.T) { } pool.Run() // 启动 -} \ No newline at end of file +} diff --git a/example/example2_test.go b/example/example2_test.go index 920559d..074b4be 100644 --- a/example/example2_test.go +++ b/example/example2_test.go @@ -22,13 +22,20 @@ func TestTaskPool2(t *testing.T) { // 建立一个池, // input:池数量 - pool := workerpool.NewPool(5) + //pool := workerpool.NewPool(5) + pool := workerpool.NewPool(5, workerpool.WithTimeout(1), workerpool.WithErrorCallback(func(err error) { + if err != nil { + panic(err) + } + }), workerpool.WithResultCallback(func(i interface{}) { + fmt.Println("result: ", i) + })) // 准备100个任务 for i := 1; i <= 100; i++ { // 需要做的任务 - task := workerpool.NewTaskInstance(fmt.Sprintf("task-%v", i), i, func(data interface{}) error { + task := workerpool.NewTaskInstance(fmt.Sprintf("task-%v", i), i, func(data interface{}) (interface{}, error) { taskID := data.(int) /* @@ -36,7 +43,7 @@ func TestTaskPool2(t *testing.T) { */ time.Sleep(100 * time.Millisecond) klog.Info("Task ", taskID, " processed") - return nil + return nil, nil }) // 所有的任务放入list中 @@ -49,7 +56,7 @@ func TestTaskPool2(t *testing.T) { for { taskID := rand.Intn(100) + 20 - // 模拟一个退出条件 + //// 模拟一个退出条件 if taskID%7 == 0 { klog.Info("taskID: ", taskID, "pool stop!") pool.StopBackground() @@ -58,15 +65,15 @@ func TestTaskPool2(t *testing.T) { time.Sleep(time.Duration(rand.Intn(5)) * time.Second) // 模拟后续加入pool - task := workerpool.NewTaskInstance(fmt.Sprintf("task-%v", taskID), taskID, func(data interface{}) error { + task := workerpool.NewTaskInstance(fmt.Sprintf("task-%v", taskID), taskID, func(data interface{}) (interface{}, error) { taskID := data.(int) - time.Sleep(100 * time.Millisecond) + time.Sleep(3 * time.Second) klog.Info("Task ", taskID, " processed") - return nil + return nil, nil }) pool.AddTask(task) } fmt.Println("finished...") -} \ No newline at end of file +} diff --git a/example/http_example/cmd/cmd.go b/example/http_example/cmd/cmd.go index c20cecc..0db6c58 100644 --- a/example/http_example/cmd/cmd.go +++ b/example/http_example/cmd/cmd.go @@ -28,4 +28,4 @@ func Execute() { fmt.Printf("cmd err: %s\n", err) os.Exit(1) } -} \ No newline at end of file +} diff --git a/example/http_example/pkg/server/model/task.go b/example/http_example/pkg/server/model/task.go index d3b4d9c..cb93716 100644 --- a/example/http_example/pkg/server/model/task.go +++ b/example/http_example/pkg/server/model/task.go @@ -31,17 +31,17 @@ func (my *MyTask) ChooseTaskType() { } } -func (my *MyTask) Execute() error { +func (my *MyTask) Execute() (interface{}, error) { my.Status = TaskRunning if err := my.f(my.Input); err != nil { my.Err = err my.Status = TaskFail - return err + return nil, err } my.Status = TaskSuccess - return nil + return nil, nil } func (my *MyTask) GetTaskName() string { diff --git a/example/scheduler/scheduler_test.go b/example/scheduler/scheduler_test.go index 1fcc72e..0c2bc48 100644 --- a/example/scheduler/scheduler_test.go +++ b/example/scheduler/scheduler_test.go @@ -12,9 +12,9 @@ func TestScheduler(t *testing.T) { s.Start() - tsk := workerpool.NewTaskInstance("task1", "aaa", func(i interface{}) error { + tsk := workerpool.NewTaskInstance("task1", "aaa", func(i interface{}) (interface{}, error) { fmt.Println(i) - return nil + return nil, nil }) s.AddTask(tsk) diff --git a/pkg/workerpool/option.go b/pkg/workerpool/option.go new file mode 100644 index 0000000..42e20b3 --- /dev/null +++ b/pkg/workerpool/option.go @@ -0,0 +1,27 @@ +package workerpool + +import "time" + +// Option 选项模式 +type Option func(pool *Pool) + +// WithTimeout 设置超时时间 +func WithTimeout(timeout time.Duration) Option { + return func(p *Pool) { + p.timeout = timeout + } +} + +// WithResultCallback 设置结果回调方法 +func WithResultCallback(callback func(interface{})) Option { + return func(p *Pool) { + p.resultCallback = callback + } +} + +// WithErrorCallback 设置错误回调方法 +func WithErrorCallback(callback func(error)) Option { + return func(p *Pool) { + p.errorCallback = callback + } +} diff --git a/pkg/workerpool/pool.go b/pkg/workerpool/pool.go index c443f20..5b5c19c 100644 --- a/pkg/workerpool/pool.go +++ b/pkg/workerpool/pool.go @@ -2,6 +2,7 @@ package workerpool import ( "k8s.io/klog/v2" + "math/rand" "sync" "time" ) @@ -10,25 +11,38 @@ import ( type Pool struct { // list 装task Tasks []Task + // Workers 列表 Workers []*worker // 工作池数量 concurrency int // collector 用来输入所有Task对象的chan collector chan Task // runBackground 后台运行时,结束时需要传入的标示 - runBackground chan bool - wg sync.WaitGroup + runBackground chan bool + // timeout 超时时间 + timeout time.Duration + // errorCallback 当任务发生错误时的回调方法 + errorCallback func(err error) + // resultCallback 当任务有结果时的回调方法 + resultCallback func(result interface{}) + wg sync.WaitGroup } // NewPool 建立一个pool -func NewPool(concurrency int) *Pool { - return &Pool{ +func NewPool(concurrency int, opts ...Option) *Pool { + p := &Pool{ Tasks: make([]Task, 0), Workers: make([]*worker, 0), concurrency: concurrency, collector: make(chan Task, 10), runBackground: make(chan bool), } + + for _, opt := range opts { + opt(p) + } + + return p } // AddGlobalQueue 加入工作池的全局队列,静态加入,用于启动工作池前的任务加入时使用, @@ -44,27 +58,44 @@ func (p *Pool) Run() { // 总共会开启p.concurrency个goroutine // 启动pool中的每个worker都传入collector chan for i := 1; i <= p.concurrency; i++ { - worker := newWorker(p.collector, i) - worker.start(&p.wg) + wr := newWorker(i, p.timeout, p.errorCallback, p.resultCallback) + p.Workers = append(p.Workers, wr) + wr.start(&p.wg) } - for len(p.Tasks) == 0 { - klog.Error("no task in global queue...") - time.Sleep(time.Millisecond) + // 如果全局队列没任务,提示一下 + if len(p.Tasks) == 0 { + klog.Info("no task in global queue...") } + go p.dispatch() + // 把放在tasks列表的的任务放入collector for i := range p.Tasks { p.collector <- p.Tasks[i] - } // 注意,这里需要close chan。 close(p.collector) + // 阻塞,等待所有的goroutine执行完毕 p.wg.Wait() } +// dispatch 由pool chan中不断分发给worker chan +// 使用随机分配的方式 +func (p *Pool) dispatch() { + for task := range p.collector { + index := rand.Intn(p.concurrency) + p.Workers[index].taskChan <- task + } + + for _, v := range p.Workers { + close(v.taskChan) + } + +} + // AddTask 把任务放入chan,当工作池启动后,动态加入使用 func (p *Pool) AddTask(task Task) { // 放入chan @@ -83,15 +114,17 @@ func (p *Pool) RunBackground() { // 启动workers 数量: p.concurrency for i := 1; i <= p.concurrency; i++ { - workers := newWorker(p.collector, i) - p.Workers = append(p.Workers, workers) + wk := newWorker(i, p.timeout, p.errorCallback, p.resultCallback) + p.Workers = append(p.Workers, wk) - go workers.startBackground() + go wk.startBackground() } + go p.dispatch() + + // 如果全局队列没任务,提示一下 if len(p.Tasks) == 0 { - klog.Error("no task in global queue...") - time.Sleep(time.Millisecond) + klog.Info("no task in global queue...") } for i := range p.Tasks { @@ -110,4 +143,4 @@ func (p *Pool) StopBackground() { p.Workers[i].stop() } p.runBackground <- true -} \ No newline at end of file +} diff --git a/pkg/workerpool/task.go b/pkg/workerpool/task.go index 453f145..d82df38 100644 --- a/pkg/workerpool/task.go +++ b/pkg/workerpool/task.go @@ -7,20 +7,20 @@ package workerpool // Task 任务接口,由工作池抽象出的具体执行单元, // 当pool启动时,会从chan中不断读取Task接口对象执行 type Task interface { - Execute() error + Execute() (interface{}, error) GetTaskName() string } // TaskInstance 一个具体任务需求 type TaskInstance struct { Name string - Err error // 返回错误 - Data interface{} // 真正的处理数据 - f func(interface{}) error // 处理函数 + Err error // 返回错误 + Data interface{} // 真正的处理数据 + f func(interface{}) (interface{}, error) // 处理函数 } // NewTaskInstance 建立任务 -func NewTaskInstance(name string, data interface{}, f func(interface{}) error) *TaskInstance { +func NewTaskInstance(name string, data interface{}, f func(interface{}) (interface{}, error)) *TaskInstance { return &TaskInstance{ Name: name, Data: data, @@ -28,12 +28,15 @@ func NewTaskInstance(name string, data interface{}, f func(interface{}) error) * } } -func (t *TaskInstance) Execute() error { - t.Err = t.f(t.Data) // 执行任务。如果任务执行错误,赋值err - return nil +func (t *TaskInstance) Execute() (interface{}, error) { + result, err := t.f(t.Data) // 执行任务。如果任务执行错误,赋值err + if err != nil { + t.Err = err + return nil, err + } + return result, nil } func (t *TaskInstance) GetTaskName() string { return t.Name } - diff --git a/pkg/workerpool/worker.go b/pkg/workerpool/worker.go index 7178a4e..a4f8be0 100644 --- a/pkg/workerpool/worker.go +++ b/pkg/workerpool/worker.go @@ -1,28 +1,82 @@ package workerpool import ( + "context" + "fmt" "k8s.io/klog/v2" "sync" + "time" ) // worker 执行任务的消费者 type worker struct { ID int // 消费者的id - // 等待处理的任务chan (每个worker都有一个自己的chan) + // taskChan 等待处理的任务chan (每个worker都有一个自己的chan) taskChan chan Task - // 停止通知 + // quit 停止通知 quit chan bool + // timeout 超时时间 + timeout time.Duration + // errorCallback 当任务发生错误时的回调方法 + errorCallback func(err error) + // resultCallback 当任务有结果时的回调方法 + resultCallback func(result interface{}) } -// newWorker 建立新的消费者 -func newWorker(channel chan Task, ID int) *worker { +// newWorker 创建worker +func newWorker(ID int, timeout time.Duration, errorCallback func(err error), resultCallback func(interface{})) *worker { return &worker{ - ID: ID, - taskChan: channel, - quit: make(chan bool), + ID: ID, + taskChan: make(chan Task, 100), + quit: make(chan bool), + timeout: timeout, + errorCallback: errorCallback, + resultCallback: resultCallback, } } +// executeTask 执行任务 +func (wr *worker) executeTask(task Task) (interface{}, error) { + var err error + var result interface{} + if wr.timeout > 0 { + result, err = wr.executeTaskWithTimeout(task) + } else { + result, err = wr.executeTaskWithoutTimeout(task) + } + return result, err +} + +// executeTaskWithTimeout 执行任务有超时的情况 +func (wr *worker) executeTaskWithTimeout(task Task) (interface{}, error) { + + ctx, cancel := context.WithTimeout(context.Background(), wr.timeout*time.Second) + defer cancel() + + var result interface{} + var err error + done := make(chan struct{}) + + // 异步执行并等待 + go func() { + result, err = task.Execute() + close(done) + }() + + // 阻塞等待超时先到还是任务先执行完成 + select { + case <-done: + return result, err + case <-ctx.Done(): + return nil, fmt.Errorf("task timed out...") + } +} + +// executeTaskWithoutTimeout 执行任务 +func (wr *worker) executeTaskWithoutTimeout(task Task) (interface{}, error) { + return task.Execute() +} + // start 执行任务,遍历taskChan,每个worker都启一个goroutine执行。 func (wr *worker) start(wg *sync.WaitGroup) { klog.Info("Starting worker: ", wr.ID) @@ -33,7 +87,8 @@ func (wr *worker) start(wg *sync.WaitGroup) { // 不断从chan中取出task执行 for task := range wr.taskChan { klog.Info("worker: ", wr.ID, ", processes task: ", task.GetTaskName()) - task.Execute() + result, err := wr.executeTask(task) + wr.handleResult(result, err) } }() } @@ -45,17 +100,27 @@ func (wr *worker) startBackground() { select { case task := <-wr.taskChan: klog.Info("worker: ", wr.ID, ", processes task: ", task.GetTaskName()) - task.Execute() + result, err := wr.executeTask(task) + wr.handleResult(result, err) case <-wr.quit: return } } +} +// handleResult 处理任务结束的方法 +func (wr *worker) handleResult(result interface{}, err error) { + if err != nil && wr.errorCallback != nil { + wr.errorCallback(err) + } else if wr.resultCallback != nil { + wr.resultCallback(result) + } } +// stop 停止worker func (wr *worker) stop() { klog.Info("Closing worker: ", wr.ID) go func() { wr.quit <- true }() -} \ No newline at end of file +}