/***************************
 * kt_for with thread pool *
 ***************************/

/*
 * A persistent pool of worker threads that repeatedly executes
 * "for (i = 0; i < n; ++i) func(data, i, tid)" without paying the cost of
 * creating and joining threads on every call.
 *
 * Public interface (these three prototypes are what kthread.h declares):
 */
void *kt_forpool_init(int n_threads);
void kt_forpool_destroy(void *_fp);
void kt_forpool(void *_fp, void (*func)(void*,long,int), void *data, long n);

struct kt_forpool_t;

/* Per-worker state. */
typedef struct {
	struct kt_forpool_t *t; // owning pool
	long i;                 // next index this worker will claim; advanced atomically by n_threads
	int action;             // command from the master: 0 = idle, >0 = run a round, <0 = exit; guarded by mu_s
} kto_worker_t;

/* Pool state shared by the master and all workers. */
typedef struct kt_forpool_t {
	int n_threads, n_done;        // number of workers; workers that finished the current round (guarded by mu_m)
	long n;                       // iteration count of the current round
	pthread_t *tid;               // n_threads thread handles
	kto_worker_t *w;              // n_threads worker records
	void (*func)(void*,long,int); // callback of the current round: func(data, index, worker_id)
	void *data;                   // opaque pointer passed through to func
	pthread_mutex_t mu_m, mu_s;   // mu_m guards n_done; mu_s guards w[].action
	pthread_cond_t cv_m, cv_s;    // cv_m: workers -> master ("round done"); cv_s: master -> workers ("new command")
} kt_forpool_t;

/*
 * Claim one index from the worker that is furthest behind.
 *
 * The scan reads the cursors without synchronization, so the choice of victim
 * is only a heuristic; the claim itself is an atomic fetch-and-add.
 *
 * Returns the claimed index, or -1 if the chosen cursor is already at or past
 * t->n, or if there is no worker to steal from.
 */
static inline long kt_fp_steal_work(kt_forpool_t *t)
{
	int i, min_i = -1;
	long k, min = LONG_MAX;
	for (i = 0; i < t->n_threads; ++i)
		if (min > t->w[i].i) min = t->w[i].i, min_i = i;
	if (min_i < 0) return -1; // no candidate: never index w[-1]
	k = __sync_fetch_and_add(&t->w[min_i].i, t->n_threads);
	return k >= t->n? -1 : k;
}

/*
 * Worker thread body. `data` points to this worker's kto_worker_t.
 *
 * Loop: sleep until the master posts a command, consume it, and either exit
 * (negative command) or run one round and report completion.
 */
static void *kt_fp_worker(void *data)
{
	kto_worker_t *w = (kto_worker_t*)data;
	kt_forpool_t *fp = w->t;
	int tid = (int)(w - fp->w); // worker id handed to the callback
	for (;;) {
		long i;
		int action;
		// Wait for a command. The predicate is re-tested under mu_s, which the
		// master also holds while posting, so a wakeup cannot be lost and
		// spurious wakeups are harmless.
		pthread_mutex_lock(&fp->mu_s);
		while (w->action == 0) pthread_cond_wait(&fp->cv_s, &fp->mu_s);
		action = w->action;
		w->action = 0; // consume the command while still holding the lock
		pthread_mutex_unlock(&fp->mu_s);
		if (action < 0) break; // shutdown request
		for (;;) { // process jobs allocated to this worker
			i = __sync_fetch_and_add(&w->i, fp->n_threads);
			if (i >= fp->n) break;
			fp->func(fp->data, i, tid);
		}
		while ((i = kt_fp_steal_work(fp)) >= 0) // steal jobs allocated to other workers
			fp->func(fp->data, i, tid);
		// Report completion under mu_m so the master's predicate check and
		// this update cannot interleave.
		pthread_mutex_lock(&fp->mu_m);
		if (++fp->n_done == fp->n_threads) pthread_cond_signal(&fp->cv_m);
		pthread_mutex_unlock(&fp->mu_m);
	}
	return NULL;
}

/*
 * Create a pool with n_threads workers (values below 1 are treated as 1).
 *
 * Returns an opaque handle to pass to kt_forpool() and kt_forpool_destroy(),
 * or NULL if memory allocation or thread creation fails; nothing is leaked on
 * failure.
 */
void *kt_forpool_init(int n_threads)
{
	kt_forpool_t *fp;
	int i;
	if (n_threads < 1) n_threads = 1; // a pool without workers could never complete a round
	fp = (kt_forpool_t*)calloc(1, sizeof(kt_forpool_t));
	if (fp == NULL) return NULL;
	fp->n_threads = n_threads;
	fp->tid = (pthread_t*)calloc(fp->n_threads, sizeof(pthread_t));
	fp->w = (kto_worker_t*)calloc(fp->n_threads, sizeof(kto_worker_t));
	if (fp->tid == NULL || fp->w == NULL) {
		free(fp->w); free(fp->tid); free(fp);
		return NULL;
	}
	for (i = 0; i < fp->n_threads; ++i) fp->w[i].t = fp;
	pthread_mutex_init(&fp->mu_m, 0); pthread_cond_init(&fp->cv_m, 0);
	pthread_mutex_init(&fp->mu_s, 0); pthread_cond_init(&fp->cv_s, 0);
	for (i = 0; i < n_threads; ++i) {
		if (pthread_create(&fp->tid[i], 0, kt_fp_worker, &fp->w[i]) != 0) {
			fp->n_threads = i; // only the first i workers exist; stop and join exactly those
			kt_forpool_destroy(fp);
			return NULL;
		}
	}
	return fp;
}

/*
 * Stop all workers, join them and release the pool. Passing NULL is a no-op.
 * Must not be called while a kt_forpool() call on the same pool is running.
 */
void kt_forpool_destroy(void *_fp)
{
	kt_forpool_t *fp = (kt_forpool_t*)_fp;
	int i;
	if (fp == NULL) return;
	// Post the exit command under mu_s so no worker can miss it.
	pthread_mutex_lock(&fp->mu_s);
	for (i = 0; i < fp->n_threads; ++i) fp->w[i].action = -1;
	pthread_cond_broadcast(&fp->cv_s);
	pthread_mutex_unlock(&fp->mu_s);
	// Joining is the only synchronization needed for shutdown.
	for (i = 0; i < fp->n_threads; ++i) pthread_join(fp->tid[i], 0);
	pthread_cond_destroy(&fp->cv_s); pthread_mutex_destroy(&fp->mu_s);
	pthread_cond_destroy(&fp->cv_m); pthread_mutex_destroy(&fp->mu_m);
	free(fp->w); free(fp->tid); free(fp);
}

/*
 * Run func(data, i, tid) for every i in [0, n) on the pool and return once
 * all calls have finished. tid is the id of the executing worker, in
 * [0, number of workers).
 *
 * If _fp is NULL the loop runs serially in the calling thread with tid 0.
 * Only one kt_forpool() call may be active on a given pool at a time.
 */
void kt_forpool(void *_fp, void (*func)(void*,long,int), void *data, long n)
{
	kt_forpool_t *fp = (kt_forpool_t*)_fp;
	int i;
	if (fp == NULL) { // no pool available: degrade to a plain loop
		long k;
		for (k = 0; k < n; ++k) func(data, k, 0);
		return;
	}

	pthread_mutex_lock(&fp->mu_m);
	fp->n_done = 0;
	pthread_mutex_unlock(&fp->mu_m);

	// Publish the job and wake the workers. Everything is written while
	// holding mu_s, so a worker that sees action > 0 also sees the job.
	pthread_mutex_lock(&fp->mu_s);
	fp->n = n, fp->func = func, fp->data = data;
	for (i = 0; i < fp->n_threads; ++i) fp->w[i].i = i, fp->w[i].action = 1;
	pthread_cond_broadcast(&fp->cv_s);
	pthread_mutex_unlock(&fp->mu_s);

	// Wait on a real predicate: correct even if every worker finished before
	// we got here, and immune to spurious wakeups.
	pthread_mutex_lock(&fp->mu_m);
	while (fp->n_done < fp->n_threads) pthread_cond_wait(&fp->cv_m, &fp->mu_m);
	pthread_mutex_unlock(&fp->mu_m);
}