diff --git a/ggml.c b/ggml.c index 8a60bc3831c6d..53a797f2c1eb4 100644 --- a/ggml.c +++ b/ggml.c @@ -9240,14 +9240,14 @@ typedef pthread_t ggml_thread_t; #endif struct ggml_compute_state_shared { - ggml_lock_t spin; - int n_threads; // synchronization primitives - atomic_int n_ready; - atomic_bool has_work; - atomic_bool stop; // stop all threads + // The `flag` works as work counter + stop indicator. + // > 0: main thread stores initial value, every worker decreases it by 1. + // = 0: all done. + // < 0: stop now. + atomic_int flag; }; struct ggml_compute_state { @@ -9262,45 +9262,15 @@ struct ggml_compute_state { static thread_ret_t ggml_graph_compute_thread(void * data) { struct ggml_compute_state * state = (struct ggml_compute_state *) data; - const int n_threads = state->shared->n_threads; - while (true) { - if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) { - atomic_store(&state->shared->has_work, false); - } else { - while (atomic_load(&state->shared->has_work)) { - if (atomic_load(&state->shared->stop)) { - return 0; - } - ggml_lock_lock (&state->shared->spin); - ggml_lock_unlock(&state->shared->spin); - } - } - - atomic_fetch_sub(&state->shared->n_ready, 1); - - // wait for work - while (!atomic_load(&state->shared->has_work)) { - if (atomic_load(&state->shared->stop)) { - return 0; - } - ggml_lock_lock (&state->shared->spin); - ggml_lock_unlock(&state->shared->spin); - } - - // check if we should stop - if (atomic_load(&state->shared->stop)) { - break; - } - - if (state->node) { - if (state->params.ith < state->params.nth) { + int flag = atomic_load(&state->shared->flag); + if (flag < 0) return NULL; // stop + if (flag > 0) { // pending works + if (state->node) { // my work ggml_compute_forward(&state->params, state->node); + state->node = NULL; + atomic_fetch_sub(&state->shared->flag, 1); // done } - - state->node = NULL; - } else { - break; } } @@ -9311,20 +9281,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) const int n_threads = cgraph->n_threads; struct ggml_compute_state_shared state_shared = { - /*.spin =*/ GGML_LOCK_INITIALIZER, - /*.n_threads =*/ n_threads, - /*.n_ready =*/ 0, - /*.has_work =*/ false, - /*.stop =*/ false, + .n_threads = n_threads, + .flag = 0, }; struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL; // create thread pool if (n_threads > 1) { - ggml_lock_init(&state_shared.spin); - - atomic_store(&state_shared.has_work, true); - for (int j = 0; j < n_threads - 1; j++) { workers[j] = (struct ggml_compute_state) { .thrd = 0, @@ -9579,17 +9542,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) // COMPUTE if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - // launch thread pool - for (int j = 0; j < n_threads - 1; j++) { + for (int j = 0; j < node->n_tasks - 1; j++) { workers[j].params = (struct ggml_compute_params) { .type = GGML_TASK_COMPUTE, .ith = j + 1, @@ -9600,14 +9554,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) workers[j].node = node; } - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) > 0) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - atomic_store(&state_shared.has_work, true); + atomic_store(&state_shared.flag, node->n_tasks - 1); } params.type = GGML_TASK_COMPUTE; @@ -9615,54 +9562,16 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) // wait for thread pool if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) != 0) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } + while (atomic_load(&state_shared.flag) != 0) {} } // FINALIZE if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - // launch thread pool - for (int j = 0; j < n_threads - 1; j++) { - workers[j].params = (struct ggml_compute_params) { - .type = GGML_TASK_FINALIZE, - .ith = j + 1, - .nth = node->n_tasks, - .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? cgraph->work->data : NULL, - }; + for (int j = 0; j < node->n_tasks-1; j++) { + workers[j].params.type = GGML_TASK_FINALIZE; workers[j].node = node; } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) > 0) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - atomic_store(&state_shared.has_work, true); + atomic_store(&state_shared.flag, node->n_tasks-1); } params.type = GGML_TASK_FINALIZE; @@ -9670,21 +9579,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) // wait for thread pool if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) != 0) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } + while (atomic_load(&state_shared.flag) != 0) {} } // performance stats (node) @@ -9700,16 +9595,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) // join thread pool if (n_threads > 1) { - atomic_store(&state_shared.stop, true); - atomic_store(&state_shared.has_work, true); + atomic_store(&state_shared.flag, -1); for (int j = 0; j < n_threads - 1; j++) { int rc = ggml_thread_join(workers[j].thrd, NULL); GGML_ASSERT(rc == 0); UNUSED(rc); } - - ggml_lock_destroy(&state_shared.spin); } // performance stats (graph)