@@ -9112,6 +9112,19 @@ typedef pthread_t ggml_thread_t;
 #define ggml_thread_create pthread_create
 #define ggml_thread_join   pthread_join
 
+typedef pthread_mutex_t ggml_mutex_t;
+typedef pthread_cond_t  ggml_cond_t;
+
+#define ggml_mutex_init     pthread_mutex_init
+#define ggml_mutex_destroy  pthread_mutex_destroy
+#define ggml_cond_init      pthread_cond_init
+#define ggml_cond_destroy   pthread_cond_destroy
+
+#define ggml_mutex_lock     pthread_mutex_lock
+#define ggml_mutex_unlock   pthread_mutex_unlock
+#define ggml_cond_broadcast pthread_cond_broadcast
+#define ggml_cond_wait      pthread_cond_wait
+
 #else
 
 //typedef pthread_spinlock_t ggml_lock_t;
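Note: the block above is a thin aliasing layer over pthreads rather than a new abstraction. Below is a minimal standalone sketch of how the aliases are meant to be used; the typedefs and defines are repeated so the snippet compiles on its own, and it assumes a pthreads platform (link with -lpthread).

#include <pthread.h>
#include <stdio.h>

typedef pthread_mutex_t ggml_mutex_t;

#define ggml_mutex_init    pthread_mutex_init
#define ggml_mutex_destroy pthread_mutex_destroy
#define ggml_mutex_lock    pthread_mutex_lock
#define ggml_mutex_unlock  pthread_mutex_unlock

int main(void) {
    ggml_mutex_t m;
    ggml_mutex_init(&m, NULL);  // second argument NULL: default attributes
    ggml_mutex_lock(&m);
    printf("inside the critical section\n");
    ggml_mutex_unlock(&m);
    ggml_mutex_destroy(&m);
    return 0;
}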
@@ -9135,17 +9148,31 @@ typedef pthread_t ggml_thread_t;
 #define ggml_thread_create pthread_create
 #define ggml_thread_join   pthread_join
 
+typedef pthread_mutex_t ggml_mutex_t;
+typedef pthread_cond_t  ggml_cond_t;
+
+#define ggml_mutex_init     pthread_mutex_init
+#define ggml_mutex_destroy  pthread_mutex_destroy
+#define ggml_cond_init      pthread_cond_init
+#define ggml_cond_destroy   pthread_cond_destroy
+
+#define ggml_mutex_lock     pthread_mutex_lock
+#define ggml_mutex_unlock   pthread_mutex_unlock
+#define ggml_cond_broadcast pthread_cond_broadcast
+#define ggml_cond_wait      pthread_cond_wait
+
 #endif
 
 struct ggml_compute_state_shared {
-    ggml_lock_t spin;
 
     int n_threads;
 
     // synchronization primitives
-    atomic_int  n_ready;
-    atomic_bool has_work;
-    atomic_bool stop; // stop all threads
+    int  n_ready;
+    bool has_work;
+    bool stop; // stop all threads
+    ggml_mutex_t mutex;
+    ggml_cond_t  cond;
 };
 
 struct ggml_compute_state {
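Note: dropping atomic_int/atomic_bool in favor of plain int/bool is only sound because every access to these fields now happens with mutex held. A minimal sketch of that discipline, with field names mirroring ggml_compute_state_shared (the helper function is illustrative, not part of the commit):

#include <pthread.h>
#include <stdbool.h>

struct shared {
    int  n_ready;
    bool has_work;
    bool stop;
    pthread_mutex_t mutex;
    pthread_cond_t  cond;
};

// All reads and writes of the plain fields go through the mutex, so no
// atomics are needed; the broadcast wakes any thread waiting on the state.
static void set_work(struct shared * s, bool v) {
    pthread_mutex_lock(&s->mutex);
    s->has_work = v;
    pthread_cond_broadcast(&s->cond);
    pthread_mutex_unlock(&s->mutex);
}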
@@ -9161,43 +9188,57 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
 
     const int n_threads = state->shared->n_threads;
-
     while (true) {
-        if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) {
-            atomic_store(&state->shared->has_work, false);
+        ggml_mutex_lock(&state->shared->mutex);
+        if (state->shared->n_ready++ == n_threads - 1) {
+            state->shared->has_work = false;
+            ggml_cond_broadcast(&state->shared->cond);
         } else {
-            while (atomic_load(&state->shared->has_work)) {
-                if (atomic_load(&state->shared->stop)) {
+            while (state->shared->has_work) {
+                if (state->shared->stop) {
+                    ggml_mutex_unlock(&state->shared->mutex);
+                    return 0;
+                }
+                ggml_cond_wait(&state->shared->cond, &state->shared->mutex);
+                if (state->shared->stop) {
+                    ggml_mutex_unlock(&state->shared->mutex);
                     return 0;
                 }
-                ggml_lock_lock  (&state->shared->spin);
-                ggml_lock_unlock(&state->shared->spin);
             }
         }
+        ggml_mutex_unlock(&state->shared->mutex);
 
-        atomic_fetch_sub(&state->shared->n_ready, 1);
+        ggml_mutex_lock(&state->shared->mutex);
+        state->shared->n_ready--;
+        ggml_cond_broadcast(&state->shared->cond);
+        ggml_mutex_unlock(&state->shared->mutex);
 
         // wait for work
-        while (!atomic_load(&state->shared->has_work)) {
-            if (atomic_load(&state->shared->stop)) {
-                return 0;
+        ggml_mutex_lock(&state->shared->mutex);
+        while (!state->shared->has_work && !state->shared->stop) {
+            if (state->shared->stop) {
+                ggml_mutex_unlock(&state->shared->mutex);
+                return 0;
             }
-            ggml_lock_lock  (&state->shared->spin);
-            ggml_lock_unlock(&state->shared->spin);
+            ggml_cond_wait(&state->shared->cond, &state->shared->mutex);
         }
+        ggml_mutex_unlock(&state->shared->mutex);
 
         // check if we should stop
-        if (atomic_load(&state->shared->stop)) {
+        ggml_mutex_lock(&state->shared->mutex);
+        if (state->shared->stop) {
+            ggml_mutex_unlock(&state->shared->mutex);
             break;
         }
+        ggml_mutex_unlock(&state->shared->mutex);
 
         if (state->node) {
             if (state->params.ith < state->params.nth) {
                 ggml_compute_forward(&state->params, state->node);
             }
-
             state->node = NULL;
         } else {
+            ggml_mutex_unlock(&state->shared->mutex);
             break;
         }
     }
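Note: the rewritten worker loop rests on the standard condition-variable idiom: pthread_cond_wait() atomically releases the mutex while the thread sleeps, re-acquires it before returning, and may wake spuriously, so the predicate is always re-checked in a loop. A self-contained sketch of that wait-for-work pattern with illustrative names (not the ggml code itself):

#include <pthread.h>
#include <stdbool.h>
#include <stddef.h>

static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t  cond  = PTHREAD_COND_INITIALIZER;
static bool has_work = false;
static bool stop     = false;

static void * worker(void * arg) {
    (void) arg;
    for (;;) {
        pthread_mutex_lock(&mutex);
        while (!has_work && !stop) {
            pthread_cond_wait(&cond, &mutex);  // sleeps without burning CPU
        }
        if (stop) {
            pthread_mutex_unlock(&mutex);
            return NULL;
        }
        has_work = false;  // consume the work flag under the lock
        pthread_mutex_unlock(&mutex);
        // ... do the actual work outside the critical section ...
    }
}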
@@ -9209,19 +9250,32 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
     const int n_threads = cgraph->n_threads;
 
     struct ggml_compute_state_shared state_shared = {
-        /*.spin      =*/ GGML_LOCK_INITIALIZER,
         /*.n_threads =*/ n_threads,
         /*.n_ready   =*/ 0,
         /*.has_work  =*/ false,
         /*.stop      =*/ false,
+        /*.mutex     =*/ {0},
+        /*.cond      =*/ {0},
     };
+    {
+        int rc = ggml_mutex_init(&state_shared.mutex, NULL);
+        GGML_ASSERT(rc == 0);
+        UNUSED(rc);
+    }
+    {
+        int rc = ggml_cond_init(&state_shared.cond, NULL);
+        GGML_ASSERT(rc == 0);
+        UNUSED(rc);
+    }
     struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
 
     // create thread pool
     if (n_threads > 1) {
-        ggml_lock_init(&state_shared.spin);
 
-        atomic_store(&state_shared.has_work, true);
+        ggml_mutex_lock(&state_shared.mutex);
+        state_shared.has_work = true;
+        ggml_cond_broadcast(&state_shared.cond);
+        ggml_mutex_unlock(&state_shared.mutex);
 
         for (int j = 0; j < n_threads - 1; j++) {
             workers[j] = (struct ggml_compute_state) {
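Note: unlike a static PTHREAD_MUTEX_INITIALIZER, the runtime init calls can fail with an error code, which is what the GGML_ASSERT(rc == 0) blocks above guard against. The same check without the ggml helpers, with assert() standing in for GGML_ASSERT (illustrative only):

#include <assert.h>
#include <pthread.h>

int init_sync(pthread_mutex_t * m, pthread_cond_t * c) {
    int rc = pthread_mutex_init(m, NULL);  // 0 on success, an errno value on failure
    assert(rc == 0);
    rc = pthread_cond_init(c, NULL);
    assert(rc == 0);
    return rc;
}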
@@ -9477,14 +9531,18 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
         // COMPUTE
         if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
+            ggml_mutex_lock(&state_shared.mutex);
+            if (state_shared.n_ready++ == n_threads - 1) {
+                state_shared.has_work = false;
+                ggml_cond_broadcast(&state_shared.cond);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.has_work) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
             // launch thread pool
             for (int j = 0; j < n_threads - 1; j++) {
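Note: the COMPUTE block implements a rendezvous: each participant increments n_ready, the last arrival clears has_work and broadcasts, and everyone then sleeps until the flag drops. A generic sketch of that pattern follows; the commit releases and re-takes the lock between the two steps, while this sketch merges them into one critical section for brevity, and all names are illustrative:

#include <pthread.h>
#include <stdbool.h>

struct rendezvous {
    int  n;        // number of participants
    int  n_ready;  // how many have arrived so far
    bool has_work; // cleared by the last arrival
    pthread_mutex_t mutex;
    pthread_cond_t  cond;
};

void rendezvous_arrive(struct rendezvous * r) {
    pthread_mutex_lock(&r->mutex);
    if (r->n_ready++ == r->n - 1) {
        r->has_work = false;             // last arrival releases everyone
        pthread_cond_broadcast(&r->cond);
    }
    while (r->has_work) {
        pthread_cond_wait(&r->cond, &r->mutex);
    }
    pthread_mutex_unlock(&r->mutex);
}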
@@ -9498,48 +9556,68 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 workers[j].node = node;
             }
 
-            atomic_fetch_sub(&state_shared.n_ready, 1);
+            ggml_mutex_lock(&state_shared.mutex);
+            state_shared.n_ready--;
+            ggml_cond_broadcast(&state_shared.cond);
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.n_ready) > 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.n_ready > 0) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            atomic_store(&state_shared.has_work, true);
+
+            ggml_mutex_lock(&state_shared.mutex);
+            state_shared.has_work = true;
+            ggml_cond_broadcast(&state_shared.cond);
+            ggml_mutex_unlock(&state_shared.mutex);
         }
 
         params.type = GGML_TASK_COMPUTE;
         ggml_compute_forward(&params, node);
 
         // wait for thread pool
         if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
+            ggml_mutex_lock(&state_shared.mutex);
+            if (state_shared.n_ready++ == n_threads - 1) {
+                state_shared.has_work = false;
+                ggml_cond_broadcast(&state_shared.cond);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.has_work) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            atomic_fetch_sub(&state_shared.n_ready, 1);
+            ggml_mutex_lock(&state_shared.mutex);
+            state_shared.n_ready--;
+            ggml_cond_broadcast(&state_shared.cond);
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.n_ready) != 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.n_ready != 0) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
         }
 
         // FINALIZE
         if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
+            ggml_mutex_lock(&state_shared.mutex);
+            if (state_shared.n_ready++ == n_threads - 1) {
+                state_shared.has_work = false;
+                ggml_cond_broadcast(&state_shared.cond);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.has_work) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
             // launch thread pool
             for (int j = 0; j < n_threads - 1; j++) {
@@ -9553,36 +9631,51 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 workers[j].node = node;
             }
 
-            atomic_fetch_sub(&state_shared.n_ready, 1);
+            ggml_mutex_lock(&state_shared.mutex);
+            state_shared.n_ready -= 1;
+            ggml_cond_broadcast(&state_shared.cond);
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.n_ready) > 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.n_ready > 0) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            atomic_store(&state_shared.has_work, true);
+            ggml_mutex_lock(&state_shared.mutex);
+            state_shared.has_work = true;
+            ggml_cond_broadcast(&state_shared.cond);
+            ggml_mutex_unlock(&state_shared.mutex);
         }
 
         params.type = GGML_TASK_FINALIZE;
         ggml_compute_forward(&params, node);
 
         // wait for thread pool
         if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
+            ggml_mutex_lock(&state_shared.mutex);
+            if (state_shared.n_ready++ == n_threads - 1) {
+                state_shared.has_work = false;
+                ggml_cond_broadcast(&state_shared.cond);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.has_work) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            atomic_fetch_sub(&state_shared.n_ready, 1);
+            ggml_mutex_lock(&state_shared.mutex);
+            state_shared.n_ready--;
+            ggml_cond_broadcast(&state_shared.cond);
+            ggml_mutex_unlock(&state_shared.mutex);
 
-            while (atomic_load(&state_shared.n_ready) != 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
+            ggml_mutex_lock(&state_shared.mutex);
+            while (state_shared.n_ready != 0) {
+                ggml_cond_wait(&state_shared.cond, &state_shared.mutex);
             }
+            ggml_mutex_unlock(&state_shared.mutex);
         }
 
         // performance stats (node)
@@ -9598,16 +9691,19 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
     // join thread pool
     if (n_threads > 1) {
-        atomic_store(&state_shared.stop, true);
-        atomic_store(&state_shared.has_work, true);
-
+        ggml_mutex_lock(&state_shared.mutex);
+        state_shared.stop = true;
+        state_shared.has_work = true;
+        ggml_cond_broadcast(&state_shared.cond);
+        ggml_mutex_unlock(&state_shared.mutex);
         for (int j = 0; j < n_threads - 1; j++) {
             int rc = ggml_thread_join(workers[j].thrd, NULL);
             GGML_ASSERT(rc == 0);
             UNUSED(rc);
         }
 
-        ggml_lock_destroy(&state_shared.spin);
+        ggml_cond_destroy(&state_shared.cond);
+        ggml_mutex_destroy(&state_shared.mutex);
     }
 
     // performance stats (graph)
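Note: the shutdown sequence above sets stop and has_work and issues the broadcast while still holding the mutex, so a worker cannot test the flags and then sleep forever between the store and the wake-up; the primitives are destroyed only after every worker has been joined. A sketch of the same sequence with illustrative names:

#include <pthread.h>
#include <stdbool.h>

struct pool {
    int n_threads;
    pthread_t * thrd;  // handles for the n_threads - 1 workers
    bool stop;
    bool has_work;
    pthread_mutex_t mutex;
    pthread_cond_t  cond;
};

void pool_shutdown(struct pool * p) {
    pthread_mutex_lock(&p->mutex);
    p->stop     = true;
    p->has_work = true;  // also unblocks workers parked waiting for work
    pthread_cond_broadcast(&p->cond);
    pthread_mutex_unlock(&p->mutex);

    for (int j = 0; j < p->n_threads - 1; j++) {
        pthread_join(p->thrd[j], NULL);  // wait for each worker to exit
    }

    pthread_cond_destroy(&p->cond);
    pthread_mutex_destroy(&p->mutex);
}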