@@ -100,6 +100,12 @@ static int sched_yield (void) {
100
100
Sleep (0);
101
101
return 0;
102
102
}
103
+
104
+ /* On Windows we do not have semaphore.h, so we
105
+ rely on active polling in the thread pool.
106
+ TODO: find a better approach. */
107
+ #define GGML_THREADPOOL_ACTIVE_POLL
108
+
103
109
#else
104
110
#include <pthread.h>
105
111
#include <stdatomic.h>
@@ -114,6 +120,68 @@ typedef void * thread_ret_t;
114
120
115
121
typedef pthread_t ggml_thread_t;
116
122
123
+ // Implementing a persistent thread-pool
124
+
125
+ #ifndef GGML_THREADPOOL_ACTIVE_POLL
126
+ #include <semaphore.h>
127
+ #endif
128
+
129
+ #define GGML_THREAD_POOL_SIZE 512
130
+
131
+ struct ggml_thread_pool_context {
132
+ /** Function to run */
133
+ void * (*fn)(void *);
134
+ /** Argument to pass*/
135
+ void * arg;
136
+ /** Return value */
137
+ void * ret;
138
+
139
+ /** At 1 if the thread is running*/
140
+ volatile int executing;
141
+ /** Stop condition */
142
+ volatile int running;
143
+ #ifndef GGML_THREADPOOL_ACTIVE_POLL
144
+ /** Used to pause threads in POSIX systems */
145
+ sem_t sem;
146
+ #endif
147
+ /** Used for identifying idle threads */
148
+ volatile int flag;
149
+
150
+ /** Thread associated with this context */
151
+ ggml_thread_t thread;
152
+ /** Threads are created lazily using this flag */
153
+ short has_thread;
154
+ };
155
+
156
+ void ggml_thread_pool_context_init(struct ggml_thread_pool_context * ctx);
157
+
158
+ /** Main structure for the thread pool*/
159
+ struct ggml_thread_pool {
160
+ /** Each context is a lazily called thread */
161
+ struct ggml_thread_pool_context ctx[GGML_THREAD_POOL_SIZE];
162
+ };
163
+
164
+ /** Static instance of the GGML thread pool */
165
+ static struct ggml_thread_pool __thp;
166
+
167
+ /** This is the object representing a thread part of the threadpool*/
168
+ typedef struct ggml_thread_pool_thread_s {
169
+ /** Id of the thread (offset in threadpool array) -1 if external */
170
+ int id;
171
+ /** Handle for threads created externally */
172
+ ggml_thread_t th;
173
+ } ggml_thread_pool_thread_t;
174
+
175
+ /** This is the mainloop for threads in the threadpool */
176
+ static void * ggml_thread_pool_main(void * pctx);
177
+
178
+ /** Called once to initialize the threadpool */
179
+ void ggml_thread_pool_init(void);
180
+
181
+ /* known_index is a hint for which pool slot to try first; it helps limit contention when claiming a slot */
182
+ int ggml_thread_pool_create_thread(ggml_thread_pool_thread_t * th, void * (*fn)(void *), void * arg, int known_index);
183
+ int ggml_thread_pool_join_thread(ggml_thread_pool_thread_t th, void **retval);
184
+
117
185
#ifdef GGML_USE_CPU_HBM
118
186
#include <hbwmalloc.h>
119
187
#endif
@@ -1579,7 +1647,7 @@ struct ggml_compute_state_shared {
1579
1647
};
1580
1648
1581
1649
struct ggml_compute_state {
1582
- ggml_thread_t thrd;
1650
+ ggml_thread_pool_thread_t thrd;
1583
1651
int ith;
1584
1652
struct ggml_compute_state_shared* shared;
1585
1653
enum ggml_status ec;
@@ -3130,6 +3198,7 @@ static inline int ggml_up(int n, int m) {
3130
3198
3131
3199
////////////////////////////////////////////////////////////////////////////////
3132
3200
3201
+
3133
3202
struct ggml_context * ggml_init(struct ggml_init_params params) {
3134
3203
// make this function thread safe
3135
3204
ggml_critical_section_start();
@@ -3139,6 +3208,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3139
3208
if (is_first_call) {
3140
3209
// initialize time system (required on Windows)
3141
3210
ggml_time_init();
3211
+ ggml_thread_pool_init();
3142
3212
3143
3213
// initialize GELU, Quick GELU, SILU and EXP F32 tables
3144
3214
{
@@ -3249,6 +3319,7 @@ void ggml_free(struct ggml_context * ctx) {
3249
3319
// make this function thread safe
3250
3320
ggml_critical_section_start();
3251
3321
3322
+
3252
3323
bool found = false;
3253
3324
3254
3325
for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
@@ -19305,6 +19376,130 @@ typedef int ggml_lock_t;
19305
19376
19306
19377
#endif
19307
19378
19379
+ /* Thread Pool Implementation */
19380
+
19381
+ void ggml_thread_pool_context_init(struct ggml_thread_pool_context * ctx) {
19382
+ memset(ctx, 0, sizeof(struct ggml_thread_pool_context));
19383
+
19384
+ ctx->running = 1;
19385
+ /* Initially no threads are created */
19386
+ ctx->has_thread = 0;
19387
+ #ifndef GGML_THREADPOOL_ACTIVE_POLL
19388
+ sem_init(&ctx->sem, 0, 0);
19389
+ #endif
19390
+ }
19391
+
19392
+ static void * ggml_thread_pool_main(void * pctx) {
19393
+ struct ggml_thread_pool_context * ctx = (struct ggml_thread_pool_context*)pctx;
19394
+
19395
+ while(ctx->running) {
19396
+ /* Here we wait for start notification */
19397
+ #ifdef GGML_THREADPOOL_ACTIVE_POLL
19398
+ /* Wait for run flag */
19399
+ while(atomic_load(&ctx->executing) == 0) {
19400
+ sched_yield();
19401
+ }
19402
+ #else
19403
+ sem_wait(&ctx->sem);
19404
+ #endif
19405
+ ctx->ret = NULL;
19406
+
19407
+ if(ctx->fn) {
19408
+ /* Call the actual function */
19409
+ ctx->ret = (ctx->fn)(ctx->arg);
19410
+ }
19411
+
19412
+ /* Flag done for join */
19413
+ atomic_store(&ctx->executing, 0);
19414
+
19415
+ }
19416
+
19417
+ return NULL;
19418
+ }
19419
+
19420
+ void ggml_thread_pool_init(void) {
19421
+ int i = 0;
19422
+
19423
+ for(i = 0 ; i < GGML_THREAD_POOL_SIZE ; i++) {
19424
+ /* Thread is running */
19425
+ ggml_thread_pool_context_init(&__thp.ctx[i]);
19426
+ }
19427
+ }
19428
+
19429
+ int ggml_thread_pool_create_thread(ggml_thread_pool_thread_t * th, void * (*fn)(void *), void * arg, int known_index) {
19430
+ /* Find a free thread */
19431
+ int i = 0;
19432
+
19433
+ if(known_index < 0) {
19434
+ known_index = 0;
19435
+ }
19436
+
19437
+ assert(known_index < GGML_THREAD_POOL_SIZE);
19438
+
19439
+ for( i = known_index ; i < GGML_THREAD_POOL_SIZE + known_index; i++) {
19440
+ int zero = 0;
19441
+ int targ = i % GGML_THREAD_POOL_SIZE;
19442
+ struct ggml_thread_pool_context * ctx = &__thp.ctx[targ];
19443
+
19444
+ if( atomic_compare_exchange_weak(&ctx->flag, &zero, 1) ) {
19445
+ /* We have the thread */
19446
+ ctx->fn = fn;
19447
+ ctx->arg = arg;
19448
+
19449
+ /* Save ID*/
19450
+ th->id = i;
19451
+
19452
+ atomic_store(&ctx->executing, 1);
19453
+
19454
+ /* Is this thread already created ? */
19455
+ if(!ctx->has_thread) {
19456
+ ggml_thread_create(&ctx->thread, NULL, ggml_thread_pool_main, &__thp.ctx[i]);
19457
+ ctx->has_thread = 1;
19458
+ }
19459
+
19460
+ #ifndef GGML_THREADPOOL_ACTIVE_POLL
19461
+ /* Signal start */
19462
+ sem_post(&ctx->sem);
19463
+ #endif
19464
+ return 0;
19465
+ }
19466
+ }
19467
+
19468
+ /* if we are here we failed to get from pool create a "normal" thread and flag it withj ID -1 for handling in join */
19469
+ th->id = -1;
19470
+ return ggml_thread_create(&th->th, NULL, fn, arg);
19471
+ }
19472
+
19473
+ int ggml_thread_pool_join_thread(ggml_thread_pool_thread_t th, void **retval)
19474
+ {
19475
+ /* Normal thread case */
19476
+ if(th.id < 0) {
19477
+ return ggml_thread_join(th.th, retval);
19478
+ }
19479
+
19480
+ struct ggml_thread_pool_context * ctx = &__thp.ctx[th.id];
19481
+
19482
+ /* Thread must be taken */
19483
+ assert(atomic_load(&ctx->flag) == 1);
19484
+
19485
+ while(atomic_load(&ctx->executing)) {
19486
+ sched_yield();
19487
+ }
19488
+
19489
+ /* Done executing if we are here*/
19490
+ if(retval) {
19491
+ *retval = ctx->ret;
19492
+ }
19493
+
19494
+ ctx->fn = NULL;
19495
+ ctx->arg = NULL;
19496
+
19497
+ /* Set the thread free */
19498
+ atomic_store(&ctx->flag, 0);
19499
+
19500
+ return 0;
19501
+ }
19502
+
19308
19503
// Android's libc implementation "bionic" does not support setting affinity
19309
19504
#if defined(__gnu_linux__)
19310
19505
static void set_numa_thread_affinity(int thread_n) {
@@ -20061,13 +20256,13 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
20061
20256
if (n_threads > 1) {
20062
20257
for (int j = 1; j < n_threads; ++j) {
20063
20258
workers[j] = (struct ggml_compute_state) {
20064
- .thrd = 0 ,
20259
+ .thrd = {0} ,
20065
20260
.ith = j,
20066
20261
.shared = &state_shared,
20067
20262
.ec = GGML_STATUS_SUCCESS,
20068
20263
};
20069
20264
20070
- const int rc = ggml_thread_create (&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
20265
+ const int rc = ggml_thread_pool_create_thread (&workers[j].thrd, ggml_graph_compute_thread, &workers[j], j );
20071
20266
GGML_ASSERT(rc == 0);
20072
20267
UNUSED(rc);
20073
20268
}
@@ -20090,7 +20285,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
20090
20285
// join or kill thread pool
20091
20286
if (n_threads > 1) {
20092
20287
for (int j = 1; j < n_threads; j++) {
20093
- const int rc = ggml_thread_join (workers[j].thrd, NULL);
20288
+ const int rc = ggml_thread_pool_join_thread (workers[j].thrd, NULL);
20094
20289
GGML_ASSERT(rc == 0);
20095
20290
if (workers[j].ec != GGML_STATUS_SUCCESS)
20096
20291
compute_status = workers[j].ec;
0 commit comments