
Commit 0fe4b00

llama : allow to initialize backend with NUMA support
1 parent 8f98035 commit 0fe4b00

File tree

8 files changed: +30 −19 lines changed

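The net change: llama_init_backend() gains a bool numa parameter and now calls ggml_numa_init() itself when the flag is set, so the example programs simply forward params.numa (quantize passes false) instead of guarding a separate ggml_numa_init() call. A minimal caller under the new signature, sketched after the example programs (gpt_params and gpt_params_parse come from the examples' common.h; the body is illustrative, not part of this commit):

    #include "common.h"
    #include "llama.h"

    int main(int argc, char ** argv) {
        gpt_params params;
        if (!gpt_params_parse(argc, argv, params)) {
            return 1;
        }

        // call once at the start of the program;
        // NUMA optimizations are enabled only when params.numa is true
        llama_init_backend(params.numa);

        llama_model * model;
        llama_context * ctx;
        // ... load the model and run inference as before ...

        return 0;
    }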

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ int main(int argc, char ** argv) {
         params.prompt = gpt_random_prompt(rng);
     }
 
-    llama_init_backend();
+    llama_init_backend(params.numa);
 
     llama_model * model;
     llama_context * ctx;

examples/main/main.cpp

Lines changed: 1 addition & 5 deletions
@@ -5,7 +5,6 @@
 
 #include "common.h"
 #include "llama.h"
-#include "ggml.h"
 #include "build-info.h"
 
 #include <cassert>
@@ -106,10 +105,7 @@ int main(int argc, char ** argv) {
         params.prompt = gpt_random_prompt(rng);
     }
 
-    llama_init_backend();
-    if (params.numa) {
-        ggml_numa_init();
-    }
+    llama_init_backend(params.numa);
 
     llama_model * model;
     llama_context * ctx;

examples/perplexity/perplexity.cpp

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ int main(int argc, char ** argv) {
         params.prompt = gpt_random_prompt(rng);
     }
 
-    llama_init_backend();
+    llama_init_backend(params.numa);
 
     llama_model * model;
     llama_context * ctx;

examples/quantize/quantize.cpp

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
         usage(argv[0]);
     }
 
-    llama_init_backend();
+    llama_init_backend(false);
 
     // parse command line arguments
     const std::string fname_inp = argv[arg_idx];

examples/simple/simple.cpp

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ int main(int argc, char ** argv)
     // Init LLM :
     //---------------------------------
 
-    llama_init_backend();
+    llama_init_backend(params.numa);
 
     llama_model * model;
     llama_context * ctx;

ggml.c

Lines changed: 18 additions & 8 deletions
@@ -3879,14 +3879,12 @@ struct ggml_context_container {
 #define GGML_NUMA_MAX_NODES 8
 #define GGML_NUMA_MAX_CPUS 512
 
-struct ggml_numa_node
-{
+struct ggml_numa_node {
     uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node
     uint32_t n_cpus;
 };
 
-struct ggml_numa_nodes
-{
+struct ggml_numa_nodes {
     struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES];
     uint32_t n_nodes;
     uint32_t total_cpus; // hardware threads on system
@@ -3923,32 +3921,41 @@ inline static void ggml_critical_section_end(void) {
     atomic_fetch_sub(&g_state_barrier, 1);
 }
 
-void ggml_numa_init(void)
-{
-    if (g_state.numa.n_nodes > 0) { return; }
+void ggml_numa_init(void) {
+    if (g_state.numa.n_nodes > 0) {
+        fprintf(stderr, "ggml_numa_init: NUMA already initialized\n");
+
+        return;
+    }
+
 #ifdef __linux__
     struct stat st;
     char path[256];
     int rv;
+
     // enumerate nodes
     while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) {
         rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes);
         GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
         if (stat(path, &st) != 0) { break; }
         ++g_state.numa.n_nodes;
     }
+
     // enumerate CPUs
     while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) {
         rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus);
         GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
         if (stat(path, &st) != 0) { break; }
         ++g_state.numa.total_cpus;
     }
+
     GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus);
+
     if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1) {
         g_state.numa.n_nodes = 0;
         return;
     }
+
     for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) {
         struct ggml_numa_node * node = &g_state.numa.nodes[n];
         GGML_PRINT_DEBUG("CPUs on node %u:", n);
@@ -3963,6 +3970,7 @@ void ggml_numa_init(void)
         }
         GGML_PRINT_DEBUG("\n");
     }
+
     if (ggml_is_numa()) {
         FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r");
         if (fptr != NULL) {
@@ -3978,7 +3986,9 @@ void ggml_numa_init(void)
 #endif
 }
 
-bool ggml_is_numa(void) { return g_state.numa.n_nodes > 1; }
+bool ggml_is_numa(void) {
+    return g_state.numa.n_nodes > 1;
+}
 
 ////////////////////////////////////////////////////////////////////////////////
 
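In short, ggml_numa_init() enumerates nodes under /sys/devices/system/node and hardware threads under /sys/devices/system/cpu (Linux only), resets n_nodes to 0 and returns if nothing is found, and warns if it is called twice; ggml_is_numa() then reports true only when more than one node was detected. A caller can check the result after initialization, for example (a sketch, not part of this commit; the message text is illustrative):

    #include <cstdio>
    #include "ggml.h"
    #include "llama.h"

    int main() {
        llama_init_backend(true);   // true -> ggml_numa_init() runs

        if (ggml_is_numa()) {
            std::fprintf(stderr, "NUMA: more than one node detected\n");
        } else {
            std::fprintf(stderr, "NUMA: single node or detection unavailable\n");
        }
        return 0;
    }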

llama.cpp

Lines changed: 5 additions & 1 deletion
@@ -977,7 +977,7 @@ bool llama_mlock_supported() {
     return llama_mlock::SUPPORTED;
 }
 
-void llama_init_backend() {
+void llama_init_backend(bool numa) {
     ggml_time_init();
 
     // needed to initialize f16 tables
@@ -986,6 +986,10 @@ void llama_init_backend() {
         struct ggml_context * ctx = ggml_init(params);
         ggml_free(ctx);
     }
+
+    if (numa) {
+        ggml_numa_init();
+    }
 }
 
 int64_t llama_time_us() {

llama.h

Lines changed: 2 additions & 1 deletion
@@ -140,8 +140,9 @@ extern "C" {
 
     // TODO: not great API - very likely to change
     // Initialize the llama + ggml backend
+    // If numa is true, use NUMA optimizations
     // Call once at the start of the program
-    LLAMA_API void llama_init_backend();
+    LLAMA_API void llama_init_backend(bool numa);
 
     LLAMA_API int64_t llama_time_us();
 
