Skip to content

Commit

Permalink
Merge pull request #8 from KlassnayaAfrodita/patch-2
Browse files Browse the repository at this point in the history
Add Worker Pool concurrency.go
  • Loading branch information
hokamsingh authored Oct 9, 2024
2 parents d6efd97 + 942e715 commit e650b24
Showing 1 changed file with 98 additions and 38 deletions.
136 changes: 98 additions & 38 deletions internal/core/concurrency/concurrency.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package concurrency

import (
"context"
"errors"
"sync"
)

Expand All @@ -13,29 +14,12 @@ type Task struct {
fn TaskFunc
}

// TaskFunc defines the type for the task function that returns a result and an error.
// @callback TaskFunc
// @param {context.Context} ctx - The context for the task, used for cancellation and deadlines.
// @returns {interface{}, error} The result of the task and an error, if any.
//
// Example:
//
// taskFunc := func(ctx context.Context) (interface{}, error) {
// return "task result", nil
// }
// NewTask creates a new Task.
func NewTask(fn TaskFunc) *Task {
return &Task{fn: fn}
}

// Execute runs the task function and returns the result or an error.
// Example:
//
// result, err := task.Execute(context.Background())
// if err != nil {
// log.Fatal(err)
// }
//
// fmt.Println(result)
func (t *Task) Execute(ctx context.Context) (interface{}, error) {
return t.fn(ctx)
}
Expand All @@ -48,15 +32,78 @@ const (
Sequential ExecutionMode = 1
)

// WorkerPool manages a fixed number of workers to process tasks concurrently.
type WorkerPool struct {
taskChan chan *Task
resultChan chan result
workerCount int
wg sync.WaitGroup
}

type result struct {
index int
output interface{}
err error
}

// NewWorkerPool initializes a worker pool with the specified number of workers.
func NewWorkerPool(workerCount int) *WorkerPool {
return &WorkerPool{
taskChan: make(chan *Task),
resultChan: make(chan result),
workerCount: workerCount,
}
}

// Run starts the workers in the pool.
func (wp *WorkerPool) Run(ctx context.Context) {
for i := 0; i < wp.workerCount; i++ {
wp.wg.Add(1)
go func() {
defer wp.wg.Done()
for task := range wp.taskChan {
select {
case <-ctx.Done():
return
default:
output, err := task.Execute(ctx)
wp.resultChan <- result{output: output, err: err}
}
}
}()
}
}

// Stop waits for all workers to finish.
func (wp *WorkerPool) Stop() {
close(wp.taskChan)
wp.wg.Wait()
close(wp.resultChan)
}

// Submit adds a task to the task channel.
func (wp *WorkerPool) Submit(task *Task, index int) {
wp.taskChan <- task
}

// Results returns the result channel to collect task outputs and errors.
func (wp *WorkerPool) Results() <-chan result {
return wp.resultChan
}

// TaskManager manages and executes tasks concurrently or sequentially.
type TaskManager struct {
tasks []*Task
mode ExecutionMode
tasks []*Task
mode ExecutionMode
workerCount int
}

// NewTaskManager creates a new TaskManager with the specified execution mode.
func NewTaskManager(mode ExecutionMode) *TaskManager {
return &TaskManager{mode: mode}
// NewTaskManager creates a new TaskManager with the specified execution mode and optional worker count.
func NewTaskManager(mode ExecutionMode, workerCount int) *TaskManager {
if workerCount <= 0 {
workerCount = 10
}
return &TaskManager{mode: mode, workerCount: workerCount}
}

// AddTask adds a task to the manager.
Expand All @@ -72,28 +119,41 @@ func (tm *TaskManager) Run(ctx context.Context) ([]interface{}, error) {
return tm.runSequential(ctx)
}

// runParallel executes all tasks concurrently and collects the results.
// runParallel executes all tasks concurrently using a worker pool and collects the results.
func (tm *TaskManager) runParallel(ctx context.Context) ([]interface{}, error) {
var wg sync.WaitGroup
pool := NewWorkerPool(tm.workerCount)
results := make([]interface{}, len(tm.tasks))
errChan := make(chan error, len(tm.tasks))
errChan := make(chan error, 1) // Buffer size 1 for first error

// Start worker pool
pool.Run(ctx)

// Submit tasks to the worker pool
for i, task := range tm.tasks {
wg.Add(1)
go func(i int, t *Task) {
defer wg.Done()
result, err := t.Execute(ctx)
if err != nil {
errChan <- err
return
}
results[i] = result
go func(index int, task *Task) {
pool.Submit(task, index)
}(i, task)
}

wg.Wait()
// Collect results
go func() {
for res := range pool.Results() {
if res.err != nil {
select {
case errChan <- res.err: // pass only first error
default:
}
} else {
results[res.index] = res.output
}
}
}()

// Stop the worker pool and wait for results
pool.Stop()
close(errChan)

// Check for errors
if len(errChan) > 0 {
return nil, <-errChan
}
Expand Down Expand Up @@ -121,8 +181,8 @@ type TaskBuilder struct {
}

// NewTaskBuilder creates a new TaskBuilder with the specified execution mode.
func NewTaskBuilder(mode ExecutionMode) *TaskBuilder {
return &TaskBuilder{tm: NewTaskManager(mode)}
func NewTaskBuilder(mode ExecutionMode, workerCount int) *TaskBuilder {
return &TaskBuilder{tm: NewTaskManager(mode, workerCount)}
}

// Add adds a new TaskFunc to the builder.
Expand Down

0 comments on commit e650b24

Please sign in to comment.