
ggml: Support OpenMP for multi-thread processing #7606

Merged 9 commits on Jun 3, 2024
10 changes: 10 additions & 0 deletions .github/workflows/build.yml
@@ -294,12 +294,22 @@ jobs:

      - name: Build
        id: cmake_build
        if: ${{ matrix.sanitizer != 'THREAD' }}
        run: |
          mkdir build
          cd build
          cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
          cmake --build . --config ${{ matrix.build_type }} -j $(nproc)

      - name: Build (no OpenMP)
        id: cmake_build_no_openmp
        if: ${{ matrix.sanitizer == 'THREAD' }}
        run: |
          mkdir build
          cd build
          cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} -DLLAMA_OPENMP=OFF
          cmake --build . --config ${{ matrix.build_type }} -j $(nproc)

      - name: Test
        id: cmake_test
        run: |
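Note on the CI split: only the ThreadSanitizer job builds with -DLLAMA_OPENMP=OFF. The diff does not state the reason, but a common one is that TSan only understands synchronization it is instrumented to see; the barriers inside an uninstrumented libgomp/libomp are invisible to it, so even race-free OpenMP code gets flagged. A minimal illustration of such a program, assuming gcc or clang with -fopenmp (this sketch is not part of the PR):

#include <omp.h>
#include <stdio.h>

int main(void) {
    int counts[64] = {0};

    // race-free: each thread writes only its own slot, and the implicit
    // barrier at the end of the parallel region orders these writes before
    // the reads below -- but that barrier lives inside the OpenMP runtime,
    // which is why a TSan build against an uninstrumented runtime may still
    // report a (false) race here
    #pragma omp parallel num_threads(4)
    {
        counts[omp_get_thread_num()] = 1;
    }

    int total = 0;
    for (int i = 0; i < 64; ++i) {
        total += counts[i];
    }
    printf("threads that ran: %d\n", total);
    return 0;
}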
12 changes: 12 additions & 0 deletions CMakeLists.txt
@@ -125,6 +125,7 @@ set(LLAMA_METAL_MACOSX_VERSION_MIN "" CACHE STRING
set(LLAMA_METAL_STD "" CACHE STRING "llama: metal standard version (-std flag)")
option(LLAMA_KOMPUTE "llama: use Kompute" OFF)
option(LLAMA_RPC "llama: use RPC" OFF)
option(LLAMA_OPENMP "llama: use OpenMP" ON)
option(LLAMA_SYCL "llama: use SYCL" OFF)
option(LLAMA_SYCL_F16 "llama: use 16 bit floats for sycl calculations" OFF)
set(LLAMA_SYCL_TARGET "INTEL" CACHE STRING "llama: sycl target device")
@@ -295,6 +296,17 @@ if (LLAMA_METAL)
    )
endif()

if (LLAMA_OPENMP)
    find_package(OpenMP)
    if (OpenMP_FOUND)
        message(STATUS "OpenMP found")
        add_compile_definitions(GGML_USE_OPENMP)
        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
    else()
        message(WARNING "OpenMP not found")
    endif()
endif()

if (LLAMA_BLAS)
    if (LLAMA_STATIC)
        set(BLA_STATIC ON)
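The CMake wiring above is deliberately soft: LLAMA_OPENMP defaults to ON, a successful find_package(OpenMP) adds the GGML_USE_OPENMP define plus the OpenMP::OpenMP_C/OpenMP::OpenMP_CXX link targets, and a missing OpenMP degrades to a warning rather than a configure error. A sketch of how C code can consume that define; effective_threads is a hypothetical helper, not from the PR, and only its MIN(requested, omp_get_max_threads()) clamp mirrors what ggml_graph_compute() does in ggml.c below:

#include <stdio.h>

#ifdef GGML_USE_OPENMP
#include <omp.h>
#endif

#define MIN(a, b) ((a) < (b) ? (a) : (b))

// hypothetical helper: with OpenMP the runtime caps the thread count
// (honoring OMP_NUM_THREADS and friends); without it the request stands
static int effective_threads(int requested) {
#ifdef GGML_USE_OPENMP
    return MIN(requested, omp_get_max_threads());
#else
    return requested;
#endif
}

int main(void) {
    printf("effective threads for a request of 8: %d\n", effective_threads(8));
    return 0;
}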
8 changes: 8 additions & 0 deletions Makefile
@@ -57,6 +57,8 @@ ifeq ($(UNAME_S),Darwin)
        LLAMA_METAL := 1
    endif

    LLAMA_NO_OPENMP := 1

    ifneq ($(UNAME_P),arm)
        SYSCTL_M := $(shell sysctl -n hw.optional.arm64 2>/dev/null)
        ifeq ($(SYSCTL_M),1)
@@ -400,6 +402,12 @@ ifndef LLAMA_NO_ACCELERATE
    endif
endif # LLAMA_NO_ACCELERATE

ifndef LLAMA_NO_OPENMP
    MK_CPPFLAGS += -DGGML_USE_OPENMP
    MK_CFLAGS   += -fopenmp
    MK_CXXFLAGS += -fopenmp
endif # LLAMA_NO_OPENMP

ifdef LLAMA_OPENBLAS
    MK_CPPFLAGS += -DGGML_USE_OPENBLAS $(shell pkg-config --cflags-only-I openblas)
    MK_CFLAGS   += $(shell pkg-config --cflags-only-other openblas)
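The Makefile takes the opposite default mechanism from CMake: OpenMP is enabled unless LLAMA_NO_OPENMP is defined (so make LLAMA_NO_OPENMP=1 opts out), and the Darwin branch sets it unconditionally. The diff does not say why, but a plausible reason is that Apple's bundled clang rejects -fopenmp unless a separate libomp is installed.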
111 changes: 73 additions & 38 deletions ggml.c
@@ -5,6 +5,7 @@
#include "ggml-quants.h"
#include "ggml.h"


#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -28,6 +29,10 @@
#include <syscall.h>
#endif

#ifdef GGML_USE_OPENMP
#include <omp.h>
#endif

#ifdef GGML_USE_METAL
#include <unistd.h>
#endif
@@ -1746,7 +1751,7 @@ struct ggml_compute_state_shared {
    int64_t perf_node_start_cycles;
    int64_t perf_node_start_time_us;

-    const int n_threads;
+    int n_threads;

    // synchronization primitives
    atomic_int n_active;  // num active threads
@@ -19661,6 +19666,59 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
    return cplan;
}

static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * workers, int n_threads) {
    enum ggml_status compute_status = GGML_STATUS_SUCCESS;

#ifdef GGML_USE_OPENMP
    if (n_threads > 1) {
        #pragma omp parallel num_threads(n_threads)
        {
            #pragma omp single
            {
                // update the number of threads from the actual number of threads that we got from OpenMP
                n_threads = omp_get_num_threads();
                workers[0].shared->n_threads = n_threads;
                workers[0].shared->n_active  = n_threads;
            }
            ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
        }
    } else {
        ggml_graph_compute_thread(&workers[0]);
    }
#else
    // create thread pool
    if (n_threads > 1) {
        for (int j = 1; j < n_threads; ++j) {
            const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
            GGML_ASSERT(rc == 0);
            UNUSED(rc);
        }
    }

    // this is a work thread too
    ggml_graph_compute_thread(&workers[0]);

    // join or kill thread pool
    if (n_threads > 1) {
        for (int j = 1; j < n_threads; j++) {
            const int rc = ggml_thread_join(workers[j].thrd, NULL);
            GGML_ASSERT(rc == 0);
            UNUSED(rc);
        }
    }
#endif
    // don't leave affinity set on the main thread
    clear_numa_thread_affinity();

    for (int j = 0; j < n_threads; j++) {
        if (workers[j].ec != GGML_STATUS_SUCCESS) {
            compute_status = workers[j].ec;
            break;
        }
    }
    return compute_status;
}

enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
    {
        GGML_ASSERT(cplan);

@@ -19671,7 +19729,11 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
        }
    }

-    const int n_threads = cplan->n_threads;
+    int n_threads = cplan->n_threads;
+
+#if defined(GGML_USE_OPENMP)
+    n_threads = MIN(n_threads, omp_get_max_threads());
+#endif

    struct ggml_compute_state_shared state_shared = {
        /*.cgraph =*/ cgraph,

@@ -19687,47 +19749,20 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
        /*.current_chunk; =*/ 0,
    };
    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);

-    // create thread pool
-    if (n_threads > 1) {
-        for (int j = 1; j < n_threads; ++j) {
-            workers[j] = (struct ggml_compute_state) {
-                .thrd   = 0,
-                .ith    = j,
-                .shared = &state_shared,
-                .ec     = GGML_STATUS_SUCCESS,
-            };
-
-            const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
-            GGML_ASSERT(rc == 0);
-            UNUSED(rc);
-        }
-    }
-
-    workers[0].ith = 0;
-    workers[0].shared = &state_shared;
-    workers[0].ec = GGML_STATUS_SUCCESS;

    const int64_t perf_start_cycles  = ggml_perf_cycles();
    const int64_t perf_start_time_us = ggml_perf_time_us();

-    // this is a work thread too
-    ggml_graph_compute_thread(&workers[0]);
-    enum ggml_status compute_status = workers[0].ec;
-
-    // don't leave affinity set on the main thread
-    clear_numa_thread_affinity();
-
-    // join or kill thread pool
-    if (n_threads > 1) {
-        for (int j = 1; j < n_threads; j++) {
-            const int rc = ggml_thread_join(workers[j].thrd, NULL);
-            GGML_ASSERT(rc == 0);
-            if (workers[j].ec != GGML_STATUS_SUCCESS)
-                compute_status = workers[j].ec;
-        }
+    for (int j = 0; j < n_threads; ++j) {
+        workers[j] = (struct ggml_compute_state) {
+            .thrd   = 0,
+            .ith    = j,
+            .shared = &state_shared,
+            .ec     = GGML_STATUS_SUCCESS,
+        };
    }
+
+    enum ggml_status compute_status = ggml_graph_compute_parallel(workers, n_threads);

    // performance stats (graph)
    {
        int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles;
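The core of the PR is ggml_graph_compute_parallel(): the old pthread create/join pool survives unchanged behind #else, while the OpenMP path replaces it with a single fork-join region. The #pragma omp single block is the subtle part: OpenMP is free to grant fewer threads than num_threads(n_threads) requests, so exactly one thread records the granted count into the shared state (n_threads and n_active) before any thread starts graph work; ggml's own barrier counts active threads and would otherwise spin waiting for workers that never materialize. This is also why the const qualifier comes off both n_threads declarations. A standalone sketch of the pattern, assuming only a C compiler with -fopenmp (not code from the PR):

#include <omp.h>
#include <stdio.h>

int main(void) {
    const int requested = 8;
    int       granted   = 1;

    #pragma omp parallel num_threads(requested)
    {
        // exactly one thread runs this block; the implicit barrier at the
        // end of 'single' guarantees every thread sees 'granted' before
        // doing its share of the work
        #pragma omp single
        {
            granted = omp_get_num_threads();
        }

        printf("thread %d of %d doing its share\n", omp_get_thread_num(), granted);
    }

    printf("requested %d, granted %d\n", requested, granted);
    return 0;
}

The clamp added earlier in ggml_graph_compute(), n_threads = MIN(n_threads, omp_get_max_threads()), respects the thread count configured for the OpenMP runtime (e.g. OMP_NUM_THREADS) and keeps the alloca'd workers array from being sized larger than needed.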