diff --git a/internal/core/concurrency/concurrency.go b/internal/core/concurrency/concurrency.go index 325857a..cdb67c2 100644 --- a/internal/core/concurrency/concurrency.go +++ b/internal/core/concurrency/concurrency.go @@ -37,6 +37,7 @@ type WorkerPool struct { resultChan chan result workerCount int wg sync.WaitGroup + once sync.Once // Used to ensure resultChan is closed only once } type result struct { @@ -55,7 +56,7 @@ func NewWorkerPool(workerCount int) *WorkerPool { } // Run starts the workers in the pool. -func (wp *WorkerPool) Run(ctx context.Context) { +func (wp *WorkerPool) Run(ctx context.Context, taskIndexes map[*Task]int) { for i := 0; i < wp.workerCount; i++ { wp.wg.Add(1) go func() { @@ -66,22 +67,24 @@ func (wp *WorkerPool) Run(ctx context.Context) { return default: output, err := task.Execute(ctx) - wp.resultChan <- result{output: output, err: err} + wp.resultChan <- result{index: taskIndexes[task], output: output, err: err} } } }() } } -// Stop waits for all workers to finish. +// Stop closes the task channel and waits for all workers to finish. func (wp *WorkerPool) Stop() { - close(wp.taskChan) - wp.wg.Wait() - close(wp.resultChan) + close(wp.taskChan) // No more tasks can be submitted + wp.wg.Wait() // Wait for all workers to finish + wp.once.Do(func() { + close(wp.resultChan) // Close result channel only once after workers are done + }) } // Submit adds a task to the task channel. -func (wp *WorkerPool) Submit(task *Task, index int) { +func (wp *WorkerPool) Submit(task *Task) { wp.taskChan <- task } @@ -122,35 +125,37 @@ func (tm *TaskManager) Run(ctx context.Context) ([]interface{}, error) { func (tm *TaskManager) runParallel(ctx context.Context) ([]interface{}, error) { pool := NewWorkerPool(tm.workerCount) results := make([]interface{}, len(tm.tasks)) - errChan := make(chan error, 1) // Buffer size 1 for first error + errChan := make(chan error, 1) // Buffer size 1 for first error + var mu sync.Mutex // Protects access to the results slice + doneChan := make(chan struct{}) // To signal when results collection is done // Start worker pool - pool.Run(ctx) + taskIndexes := make(map[*Task]int) + for i, task := range tm.tasks { + taskIndexes[task] = i + } + pool.Run(ctx, taskIndexes) // Submit tasks to the worker pool - for i, task := range tm.tasks { - go func(index int, task *Task) { - select { - case <-ctx.Done(): - errChan <- ctx.Err() - default: - pool.Submit(task, index) - } - }(i, task) + for _, task := range tm.tasks { + pool.Submit(task) } // Collect results go func() { for res := range pool.Results() { + mu.Lock() if res.err != nil { select { - case errChan <- res.err: // pass only first error + case errChan <- res.err: default: } } else { results[res.index] = res.output } + mu.Unlock() } + close(doneChan) // Close doneChan when results collection is complete }() // Stop the worker pool and wait for results @@ -158,10 +163,12 @@ func (tm *TaskManager) runParallel(ctx context.Context) ([]interface{}, error) { close(errChan) // Check for errors - if len(errChan) > 0 { - return nil, <-errChan + select { + case err := <-errChan: + return nil, err + default: + return results, nil } - return results, nil } // runSequential executes all tasks one by one and collects the results.