ggml: refactor compute thread: merge three spin variables into one #816

Closed · wants to merge 3 commits
ggml.c: 150 changes (21 additions & 129 deletions)
@@ -9240,14 +9240,14 @@ typedef pthread_t ggml_thread_t;
#endif

struct ggml_compute_state_shared {
ggml_lock_t spin;

int n_threads;

// synchronization primitives
atomic_int n_ready;
atomic_bool has_work;
atomic_bool stop; // stop all threads
// The `flag` serves as both a work counter and a stop indicator.
// > 0: main thread stores initial value, every worker decreases it by 1.
// = 0: all done.
// < 0: stop now.
atomic_int flag;
};

struct ggml_compute_state {
@@ -9262,45 +9262,15 @@ struct ggml_compute_state {
static thread_ret_t ggml_graph_compute_thread(void * data) {
struct ggml_compute_state * state = (struct ggml_compute_state *) data;

const int n_threads = state->shared->n_threads;

while (true) {
if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) {
atomic_store(&state->shared->has_work, false);
} else {
while (atomic_load(&state->shared->has_work)) {
if (atomic_load(&state->shared->stop)) {
return 0;
}
ggml_lock_lock (&state->shared->spin);
ggml_lock_unlock(&state->shared->spin);
}
}

atomic_fetch_sub(&state->shared->n_ready, 1);

// wait for work
while (!atomic_load(&state->shared->has_work)) {
if (atomic_load(&state->shared->stop)) {
return 0;
}
ggml_lock_lock (&state->shared->spin);
ggml_lock_unlock(&state->shared->spin);
}

// check if we should stop
if (atomic_load(&state->shared->stop)) {
break;
}

if (state->node) {
if (state->params.ith < state->params.nth) {
int flag = atomic_load(&state->shared->flag);
if (flag < 0) return NULL; // stop
if (flag > 0) { // pending work
if (state->node) { // my work
ggml_compute_forward(&state->params, state->node);
state->node = NULL;
atomic_fetch_sub(&state->shared->flag, 1); // done
}

state->node = NULL;
} else {
break;
}
}

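Taken together, the new protocol is small enough to demonstrate in isolation. Below is a minimal, self-contained C11 sketch of the same counter-plus-sentinel pattern. It is illustrative only, not code from this PR: `worker_main`, `has_task`, and `N_WORKERS` are invented names, and the atomic `has_task` stands in for the `state->node` pointer that the real worker checks.

```c
#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>

#define N_WORKERS 3

struct shared {
    atomic_int flag; // > 0: tasks pending, 0: all done, < 0: stop
};

struct worker_state {
    struct shared * shared;
    int             id;
    atomic_int      has_task; // stand-in for the PR's `state->node`
};

static void * worker_main(void * arg) {
    struct worker_state * st = arg;
    while (1) {
        int flag = atomic_load(&st->shared->flag);
        if (flag < 0) return NULL; // stop sentinel
        if (flag > 0 && atomic_load(&st->has_task)) {
            printf("worker %d: task done\n", st->id); // "compute forward"
            atomic_store(&st->has_task, 0);
            atomic_fetch_sub(&st->shared->flag, 1);   // report completion
        }
    }
}

int main(void) {
    struct shared sh = { .flag = 0 };
    struct worker_state ws[N_WORKERS];
    pthread_t thrd[N_WORKERS];

    for (int i = 0; i < N_WORKERS; i++) {
        ws[i] = (struct worker_state) { .shared = &sh, .id = i, .has_task = 0 };
        pthread_create(&thrd[i], NULL, worker_main, &ws[i]);
    }

    // launch one round: hand out the work BEFORE raising the flag
    for (int i = 0; i < N_WORKERS; i++) {
        atomic_store(&ws[i].has_task, 1);
    }
    atomic_store(&sh.flag, N_WORKERS);

    while (atomic_load(&sh.flag) != 0) {} // busy-wait until all done

    atomic_store(&sh.flag, -1);           // ask the workers to exit
    for (int i = 0; i < N_WORKERS; i++) {
        pthread_join(thrd[i], NULL);
    }
    return 0;
}
```

Note the ordering in `main`: the work is handed out before the flag is raised, mirroring how `ggml_graph_compute` fills `workers[j].node` before storing `node->n_tasks - 1` into `flag`.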
@@ -9311,20 +9281,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
const int n_threads = cgraph->n_threads;

struct ggml_compute_state_shared state_shared = {
/*.spin =*/ GGML_LOCK_INITIALIZER,
/*.n_threads =*/ n_threads,
/*.n_ready =*/ 0,
/*.has_work =*/ false,
/*.stop =*/ false,
.n_threads = n_threads,
.flag = 0,
};
struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;

// create thread pool
if (n_threads > 1) {
ggml_lock_init(&state_shared.spin);

atomic_store(&state_shared.has_work, true);

for (int j = 0; j < n_threads - 1; j++) {
workers[j] = (struct ggml_compute_state) {
.thrd = 0,
@@ -9579,17 +9542,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)

// COMPUTE
if (node->n_tasks > 1) {
if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
atomic_store(&state_shared.has_work, false);
}

while (atomic_load(&state_shared.has_work)) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}

// launch thread pool
for (int j = 0; j < n_threads - 1; j++) {
for (int j = 0; j < node->n_tasks - 1; j++) {
workers[j].params = (struct ggml_compute_params) {
.type = GGML_TASK_COMPUTE,
.ith = j + 1,
@@ -9600,91 +9554,32 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
workers[j].node = node;
}

atomic_fetch_sub(&state_shared.n_ready, 1);

while (atomic_load(&state_shared.n_ready) > 0) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}

atomic_store(&state_shared.has_work, true);
atomic_store(&state_shared.flag, node->n_tasks - 1);
}

params.type = GGML_TASK_COMPUTE;
ggml_compute_forward(&params, node);

// wait for thread pool
if (node->n_tasks > 1) {
if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
atomic_store(&state_shared.has_work, false);
}

while (atomic_load(&state_shared.has_work)) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}

atomic_fetch_sub(&state_shared.n_ready, 1);

while (atomic_load(&state_shared.n_ready) != 0) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}
while (atomic_load(&state_shared.flag) != 0) {}
}

// FINALIZE
if (node->n_tasks > 1) {
if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
atomic_store(&state_shared.has_work, false);
}

while (atomic_load(&state_shared.has_work)) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}

// launch thread pool
for (int j = 0; j < n_threads - 1; j++) {
workers[j].params = (struct ggml_compute_params) {
.type = GGML_TASK_FINALIZE,
.ith = j + 1,
.nth = node->n_tasks,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
};
for (int j = 0; j < node->n_tasks-1; j++) {
workers[j].params.type = GGML_TASK_FINALIZE;
workers[j].node = node;
}

atomic_fetch_sub(&state_shared.n_ready, 1);

while (atomic_load(&state_shared.n_ready) > 0) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}

atomic_store(&state_shared.has_work, true);
atomic_store(&state_shared.flag, node->n_tasks-1);
}

params.type = GGML_TASK_FINALIZE;
ggml_compute_forward(&params, node);

// wait for thread pool
if (node->n_tasks > 1) {
if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
atomic_store(&state_shared.has_work, false);
}

while (atomic_load(&state_shared.has_work)) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}

atomic_fetch_sub(&state_shared.n_ready, 1);

while (atomic_load(&state_shared.n_ready) != 0) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}
while (atomic_load(&state_shared.flag) != 0) {}
}

// performance stats (node)
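
A detail worth spelling out about the launch paths above: `workers[j].params` and `workers[j].node` are plain, non-atomic fields. Their values become safely visible to the workers because the store to `flag` comes after they are written and is sequentially consistent, while a worker reads them only after observing `flag > 0`. Here is a minimal sketch of that publish-then-signal idiom (assumed reasoning, not code from the PR; the published pointer is made `_Atomic` for strict C11 conformance, whereas the PR orders plain fields through the `flag` operations):

```c
#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>

static _Atomic(int *) node = NULL; // stand-in for workers[j].node
static atomic_int     flag = 0;

static void * worker(void * arg) {
    (void) arg;
    while (atomic_load(&flag) <= 0) {} // wait until the flag is raised
    // The store to `node` is sequenced before the store to `flag`, and both
    // are seq_cst, so observing flag > 0 guarantees `node` is visible here.
    int * work = atomic_load(&node);
    printf("worker got payload: %d\n", *work);
    atomic_fetch_sub(&flag, 1);        // report completion
    return NULL;
}

int main(void) {
    pthread_t t;
    int payload = 42;
    pthread_create(&t, NULL, worker, NULL);

    atomic_store(&node, &payload); // 1) publish the work item ...
    atomic_store(&flag, 1);        // 2) ... then raise the flag

    while (atomic_load(&flag) != 0) {} // wait, as the PR does after launch
    pthread_join(t, NULL);
    return 0;
}
```

This is also why the refactor can drop `has_work` and `n_ready` entirely: the single flag both publishes a round of work and counts it back in.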
@@ -9700,16 +9595,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)

// join thread pool
if (n_threads > 1) {
atomic_store(&state_shared.stop, true);
atomic_store(&state_shared.has_work, true);
atomic_store(&state_shared.flag, -1);

for (int j = 0; j < n_threads - 1; j++) {
int rc = ggml_thread_join(workers[j].thrd, NULL);
GGML_ASSERT(rc == 0);
UNUSED(rc);
}

ggml_lock_destroy(&state_shared.spin);
}

// performance stats (graph)