@@ -1210,8 +1210,7 @@ struct sql_printer : public printer {
     }
 };
 
-static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
-    llama_set_n_threads(ctx, n_threads, n_threads);
+static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch) {
 
     const llama_model * model = llama_get_model(ctx);
     const int32_t n_vocab = llama_n_vocab(model);
@@ -1233,9 +1232,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
     llama_synchronize(ctx);
 }
 
-static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
-    llama_set_n_threads(ctx, n_threads, n_threads);
-
+static void test_gen(llama_context * ctx, int n_gen, int n_past) {
     const llama_model * model = llama_get_model(ctx);
     const int32_t n_vocab = llama_n_vocab(model);
 
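Taken together, the two hunks above drop the per-call `n_threads` parameter from the benchmark helpers: threading is now configured once on the context rather than inside every `test_prompt`/`test_gen` call. A minimal sketch of the resulting call pattern (variable names here are illustrative, not from the diff):

```cpp
// Hedged sketch: configure threads once, then call the simplified helpers.
llama_set_n_threads(ctx, n_threads, n_threads);    // generation + batch thread counts
test_prompt(ctx, n_prompt, /*n_past=*/0, n_batch); // prompt processing pass
test_gen(ctx, n_gen, /*n_past=*/n_prompt);         // token generation pass
```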
@@ -1332,13 +1329,31 @@ int main(int argc, char ** argv) {
 
         llama_kv_cache_clear(ctx);
 
+        struct ggml_threadpool_params tpp;
+        tpp.n_threads = t.n_threads;
+
+        // TODO: expose these via cli opts
+        tpp.mask_specified = false;
+        tpp.strict_cpu     = false;
+        tpp.prio           = 1;
+        tpp.poll           = false;
+
+        struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp);
+        if (!threadpool) {
+            LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
+            exit(1);
+        }
+
+        llama_set_n_threads(ctx, t.n_threads, t.n_threads);
+        llama_attach_threadpool(ctx, threadpool);
+
         // warmup run
         if (t.n_prompt > 0) {
-            // test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
-            test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
+            // test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch);
+            test_prompt(ctx, t.n_prompt, 0, t.n_batch);
         }
         if (t.n_gen > 0) {
-            test_gen(ctx, 1, 0, t.n_threads);
+            test_gen(ctx, 1, 0);
         }
 
         for (int i = 0; i < params.reps; i++) {
@@ -1347,10 +1362,10 @@ int main(int argc, char ** argv) {
             uint64_t t_start = get_time_ns();
 
             if (t.n_prompt > 0) {
-                test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
+                test_prompt(ctx, t.n_prompt, 0, t.n_batch);
             }
             if (t.n_gen > 0) {
-                test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
+                test_gen(ctx, t.n_gen, t.n_prompt);
             }
 
             uint64_t t_ns = get_time_ns() - t_start;
@@ -1362,6 +1377,8 @@ int main(int argc, char ** argv) {
         llama_print_timings(ctx);
 
         llama_free(ctx);
+
+        ggml_release_threadpool(threadpool);
     }
 
     llama_free_model(lmodel);
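For reference, here is the threadpool lifecycle these hunks wire into `llama-bench`, collected into one self-contained sketch. Every call (`ggml_create_threadpool`, `llama_attach_threadpool`, `ggml_release_threadpool`) and every `ggml_threadpool_params` field comes from the diff above; the hardcoded values mirror the TODO-marked defaults and should not be read as the final CLI-exposed interface:

```cpp
// Sketch of the per-instance threadpool lifecycle introduced by this change.
struct ggml_threadpool_params tpp;
tpp.n_threads      = n_threads; // benchmark instance's thread count
tpp.mask_specified = false;     // no explicit CPU affinity mask (TODO: cli opt)
tpp.strict_cpu     = false;     // do not strictly pin threads to cores
tpp.prio           = 1;         // thread priority
tpp.poll           = false;     // sleep between graphs instead of busy-polling

struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp);
if (!threadpool) {
    LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
    exit(1);
}

llama_set_n_threads(ctx, n_threads, n_threads); // thread counts still set on the context
llama_attach_threadpool(ctx, threadpool);       // context reuses this pool for compute

// ... run warmup and timed test_prompt()/test_gen() passes ...

llama_free(ctx);                     // free the context first,
ggml_release_threadpool(threadpool); // then release the pool it was using
```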