@@ -8954,6 +8954,19 @@ typedef pthread_t ggml_thread_t;
 #define ggml_thread_create pthread_create
 #define ggml_thread_join   pthread_join
 
+typedef pthread_mutex_t ggml_mutex_t;
+typedef pthread_cond_t  ggml_cond_t;
+
+#define ggml_mutex_init     pthread_mutex_init
+#define ggml_mutex_destroy  pthread_mutex_destroy
+#define ggml_cond_init      pthread_cond_init
+#define ggml_cond_destroy   pthread_cond_destroy
+
+#define ggml_mutex_lock     pthread_mutex_lock
+#define ggml_mutex_unlock   pthread_mutex_unlock
+#define ggml_cond_broadcast pthread_cond_broadcast
+#define ggml_cond_wait      pthread_cond_wait
+
 #else
 
 //typedef pthread_spinlock_t ggml_lock_t;
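Note that both branches of this `#if` define the same ggml_mutex_*/ggml_cond_* aliases, so all synchronization code below is written once against the alias layer rather than against pthreads directly. As a purely illustrative sketch of what that layer buys (not part of this commit), a hypothetical Win32 port could satisfy the same names with critical sections and condition variables, using thin wrappers to adapt Win32's void-returning API to the pthread-style int-returning calls the call sites expect:

    // Hypothetical Win32 mapping of the alias layer -- illustrative only,
    // not part of this commit.
    #include <windows.h>

    typedef CRITICAL_SECTION   ggml_mutex_t;
    typedef CONDITION_VARIABLE ggml_cond_t;

    static int ggml_mutex_init(ggml_mutex_t * m, void * attr) {
        (void) attr; InitializeCriticalSection(m); return 0;
    }
    static int ggml_mutex_destroy(ggml_mutex_t * m) { DeleteCriticalSection(m); return 0; }
    static int ggml_mutex_lock   (ggml_mutex_t * m) { EnterCriticalSection(m);  return 0; }
    static int ggml_mutex_unlock (ggml_mutex_t * m) { LeaveCriticalSection(m);  return 0; }

    static int ggml_cond_init(ggml_cond_t * c, void * attr) {
        (void) attr; InitializeConditionVariable(c); return 0;
    }
    static int ggml_cond_destroy  (ggml_cond_t * c) { (void) c; return 0; } // no Win32 teardown needed
    static int ggml_cond_broadcast(ggml_cond_t * c) { WakeAllConditionVariable(c); return 0; }
    static int ggml_cond_wait(ggml_cond_t * c, ggml_mutex_t * m) {
        return SleepConditionVariableCS(c, m, INFINITE) ? 0 : 1;
    }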
@@ -8977,17 +8990,31 @@ typedef pthread_t ggml_thread_t;
 #define ggml_thread_create pthread_create
 #define ggml_thread_join   pthread_join
 
+typedef pthread_mutex_t ggml_mutex_t;
+typedef pthread_cond_t  ggml_cond_t;
+
+#define ggml_mutex_init     pthread_mutex_init
+#define ggml_mutex_destroy  pthread_mutex_destroy
+#define ggml_cond_init      pthread_cond_init
+#define ggml_cond_destroy   pthread_cond_destroy
+
+#define ggml_mutex_lock     pthread_mutex_lock
+#define ggml_mutex_unlock   pthread_mutex_unlock
+#define ggml_cond_broadcast pthread_cond_broadcast
+#define ggml_cond_wait      pthread_cond_wait
+
 #endif
 
 struct ggml_compute_state_shared {
-    ggml_lock_t spin;
 
     int n_threads;
 
     // synchronization primitives
-    atomic_int  n_ready;
-    atomic_bool has_work;
-    atomic_bool stop; // stop all threads
+    int  n_ready;
+    bool has_work;
+    bool stop; // stop all threads
+    ggml_mutex_t mutex;
+    ggml_cond_t  cond;
 };
 
 struct ggml_compute_state {
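With the atomics gone, n_ready, has_work and stop become plain fields that are only meaningful while the shared mutex is held, and every write to them is paired with a ggml_cond_broadcast so that sleeping threads re-evaluate their predicates. A minimal standalone sketch of that discipline (post_work and wait_for_work are illustrative names, not from ggml):

    #include <pthread.h>
    #include <stdbool.h>

    struct shared { pthread_mutex_t mutex; pthread_cond_t cond; bool has_work; };

    // writer side: mutate shared state only under the lock, then wake all waiters
    static void post_work(struct shared * s) {
        pthread_mutex_lock(&s->mutex);
        s->has_work = true;
        pthread_cond_broadcast(&s->cond);
        pthread_mutex_unlock(&s->mutex);
    }

    // reader side: re-check the predicate in a loop to absorb spurious wakeups;
    // pthread_cond_wait atomically releases the mutex while sleeping and
    // re-acquires it before returning
    static void wait_for_work(struct shared * s) {
        pthread_mutex_lock(&s->mutex);
        while (!s->has_work) {
            pthread_cond_wait(&s->cond, &s->mutex);
        }
        pthread_mutex_unlock(&s->mutex);
    }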
@@ -9003,43 +9030,57 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
 
     const int n_threads = state->shared->n_threads;
-
     while (true) {
-        if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) {
-            atomic_store(&state->shared->has_work, false);
+        ggml_mutex_lock(&state->shared->mutex);
+        if (state->shared->n_ready++ == n_threads - 1) {
+            state->shared->has_work = false;
+            ggml_cond_broadcast(&state->shared->cond);
         } else {
-            while (atomic_load(&state->shared->has_work)) {
-                if (atomic_load(&state->shared->stop)) {
+            while (state->shared->has_work) {
+                if (state->shared->stop) {
+                    ggml_mutex_unlock(&state->shared->mutex);
+                    return 0;
+                }
+                ggml_cond_wait(&state->shared->cond, &state->shared->mutex);
+                if (state->shared->stop) {
+                    ggml_mutex_unlock(&state->shared->mutex);
                     return 0;
                 }
-                ggml_lock_lock  (&state->shared->spin);
-                ggml_lock_unlock(&state->shared->spin);
             }
         }
+        ggml_mutex_unlock(&state->shared->mutex);
 
-        atomic_fetch_sub(&state->shared->n_ready, 1);
+        ggml_mutex_lock(&state->shared->mutex);
+        state->shared->n_ready--;
+        ggml_cond_broadcast(&state->shared->cond);
+        ggml_mutex_unlock(&state->shared->mutex);
 
         // wait for work
-        while (!atomic_load(&state->shared->has_work)) {
-            if (atomic_load(&state->shared->stop)) {
-                return 0;
+        ggml_mutex_lock(&state->shared->mutex);
+        while (!state->shared->has_work && !state->shared->stop) {
+            if (state->shared->stop) {
+                ggml_mutex_unlock(&state->shared->mutex);
+                return 0;
             }
-            ggml_lock_lock  (&state->shared->spin);
-            ggml_lock_unlock(&state->shared->spin);
+            ggml_cond_wait(&state->shared->cond, &state->shared->mutex);
         }
+        ggml_mutex_unlock(&state->shared->mutex);
 
         // check if we should stop
-        if (atomic_load(&state->shared->stop)) {
+        ggml_mutex_lock(&state->shared->mutex);
+        if (state->shared->stop) {
+            ggml_mutex_unlock(&state->shared->mutex);
             break;
         }
+        ggml_mutex_unlock(&state->shared->mutex);
 
         if (state->node) {
             if (state->params.ith < state->params.nth) {
                 ggml_compute_forward(&state->params, state->node);
             }
-
             state->node = NULL;
         } else {
+            ggml_mutex_unlock(&state->shared->mutex);
             break;
         }
     }
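The first block of this worker loop is a rendezvous: the last thread to bump n_ready clears has_work and broadcasts, while earlier arrivals sleep on the condition variable instead of spinning on the old ggml_lock_t. The same shape, reduced to a runnable standalone program (all names illustrative; build with cc -pthread):

    #include <pthread.h>
    #include <stdbool.h>
    #include <stdio.h>

    enum { N_THREADS = 4 };

    static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
    static pthread_cond_t  cond  = PTHREAD_COND_INITIALIZER;
    static int  n_ready  = 0;
    static bool has_work = true;

    static void * worker(void * arg) {
        pthread_mutex_lock(&mutex);
        if (n_ready++ == N_THREADS - 1) {
            has_work = false;              // last arrival opens the gate
            pthread_cond_broadcast(&cond); // wake every sleeper at once
        } else {
            while (has_work) {             // earlier arrivals sleep, not spin
                pthread_cond_wait(&cond, &mutex);
            }
        }
        pthread_mutex_unlock(&mutex);
        printf("thread %ld passed the barrier\n", (long)(size_t) arg);
        return NULL;
    }

    int main(void) {
        pthread_t t[N_THREADS];
        for (long i = 0; i < N_THREADS; i++) {
            pthread_create(&t[i], NULL, worker, (void *)(size_t) i);
        }
        for (int i = 0; i < N_THREADS; i++) {
            pthread_join(t[i], NULL);
        }
        return 0;
    }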
@@ -9051,19 +9092,32 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
     const int n_threads = cgraph->n_threads;
 
     struct ggml_compute_state_shared state_shared = {
-        /*.spin      =*/ GGML_LOCK_INITIALIZER,
         /*.n_threads =*/ n_threads,
         /*.n_ready   =*/ 0,
         /*.has_work  =*/ false,
         /*.stop      =*/ false,
+        /*.mutex     =*/ {0},
+        /*.cond      =*/ {0},
     };
+    {
+        int rc = ggml_mutex_init(&state_shared.mutex, NULL);
+        GGML_ASSERT(rc == 0);
+        UNUSED(rc);
+    }
+    {
+        int rc = ggml_cond_init(&state_shared.cond, NULL);
+        GGML_ASSERT(rc == 0);
+        UNUSED(rc);
+    }
     struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
 
     // create thread pool
     if (n_threads > 1) {
-        ggml_lock_init(&state_shared.spin);
 
-        atomic_store(&state_shared.has_work, true);
+        ggml_mutex_lock(&state_shared.mutex);
+        state_shared.has_work = true;
+        ggml_cond_broadcast(&state_shared.cond);
+        ggml_mutex_unlock(&state_shared.mutex);
 
         for (int j = 0; j < n_threads - 1; j++) {
            workers[j] = (struct ggml_compute_state) {
@@ -9319,14 +9373,18 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
         // COMPUTE
         if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
+            ggml_mutex_lock(&state_shared.mutex);
+            if (state_shared.n_ready++ == n_threads - 1) {
+                state_shared.has_work = false;
+                ggml_cond_broadcast(&state_shared.cond);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.has_work) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
             // launch thread pool
             for (int j = 0; j < n_threads - 1; j++) {
@@ -9340,48 +9398,68 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 workers[j].node = node;
             }
 
-            atomic_fetch_sub(&state_shared.n_ready, 1);
+            ggml_mutex_lock(&state_shared.mutex);
+            state_shared.n_ready--;
+            ggml_cond_broadcast(&state_shared.cond);
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.n_ready) > 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.n_ready > 0) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            atomic_store(&state_shared.has_work, true);
+
+            ggml_mutex_lock(&state_shared.mutex);
+            state_shared.has_work = true;
+            ggml_cond_broadcast(&state_shared.cond);
+            ggml_mutex_unlock(&state_shared.mutex);
         }
 
         params.type = GGML_TASK_COMPUTE;
         ggml_compute_forward(&params, node);
 
         // wait for thread pool
         if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
+            ggml_mutex_lock(&state_shared.mutex);
+            if (state_shared.n_ready++ == n_threads - 1) {
+                state_shared.has_work = false;
+                ggml_cond_broadcast(&state_shared.cond);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.has_work) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            atomic_fetch_sub(&state_shared.n_ready, 1);
+            ggml_mutex_lock(&state_shared.mutex);
+            state_shared.n_ready--;
+            ggml_cond_broadcast(&state_shared.cond);
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.n_ready) != 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.n_ready != 0) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
         }
 
         // FINALIZE
         if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
+            ggml_mutex_lock(&state_shared.mutex);
+            if (state_shared.n_ready++ == n_threads - 1) {
+                state_shared.has_work = false;
+                ggml_cond_broadcast(&state_shared.cond);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.has_work) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
             // launch thread pool
             for (int j = 0; j < n_threads - 1; j++) {
@@ -9395,36 +9473,51 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 workers[j].node = node;
             }
 
-            atomic_fetch_sub(&state_shared.n_ready, 1);
+            ggml_mutex_lock(&state_shared.mutex);
+            state_shared.n_ready -= 1;
+            ggml_cond_broadcast(&state_shared.cond);
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.n_ready) > 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.n_ready > 0) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            atomic_store(&state_shared.has_work, true);
+            ggml_mutex_lock(&state_shared.mutex);
+            state_shared.has_work = true;
+            ggml_cond_broadcast(&state_shared.cond);
+            ggml_mutex_unlock(&state_shared.mutex);
         }
 
         params.type = GGML_TASK_FINALIZE;
         ggml_compute_forward(&params, node);
 
         // wait for thread pool
         if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
+            ggml_mutex_lock(&state_shared.mutex);
+            if (state_shared.n_ready++ == n_threads - 1) {
+                state_shared.has_work = false;
+                ggml_cond_broadcast(&state_shared.cond);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.has_work) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            atomic_fetch_sub(&state_shared.n_ready, 1);
+            ggml_mutex_lock(&state_shared.mutex);
+            state_shared.n_ready--;
+            ggml_cond_broadcast(&state_shared.cond);
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.n_ready) != 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.n_ready != 0) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
         }
 
         // performance stats (node)
@@ -9440,16 +9533,19 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
     // join thread pool
     if (n_threads > 1) {
-        atomic_store(&state_shared.stop, true);
-        atomic_store(&state_shared.has_work, true);
-
+        ggml_mutex_lock(&state_shared.mutex);
+        state_shared.stop = true;
+        state_shared.has_work = true;
+        ggml_cond_broadcast(&state_shared.cond);
+        ggml_mutex_unlock(&state_shared.mutex);
         for (int j = 0; j < n_threads - 1; j++) {
             int rc = ggml_thread_join(workers[j].thrd, NULL);
             GGML_ASSERT(rc == 0);
             UNUSED(rc);
         }
 
-        ggml_lock_destroy(&state_shared.spin);
+        ggml_cond_destroy(&state_shared.cond);
+        ggml_mutex_destroy(&state_shared.mutex);
     }
 
     // performance stats (graph)
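Init and teardown now bracket the whole computation: the mutex and condition variable are created before any worker is spawned, stop is raised under the lock and broadcast so no sleeper is missed, and the primitives are destroyed only after every worker has been joined. A condensed lifecycle sketch under those assumptions (main and the assert-based checks are illustrative, mirroring the GGML_ASSERT(rc == 0) pattern above):

    #include <pthread.h>
    #include <assert.h>

    int main(void) {
        pthread_mutex_t mutex;
        pthread_cond_t  cond;

        int rc = pthread_mutex_init(&mutex, NULL); // NULL attr = default mutex
        assert(rc == 0);
        rc = pthread_cond_init(&cond, NULL);
        assert(rc == 0);

        // ... spawn workers, dispatch graph nodes, raise `stop` under the
        //     lock + broadcast, then pthread_join every worker ...

        // destroy only after all threads that could touch them have joined
        rc = pthread_cond_destroy(&cond);
        assert(rc == 0);
        rc = pthread_mutex_destroy(&mutex);
        assert(rc == 0);
        (void) rc; // silence unused warning under NDEBUG
        return 0;
    }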