Skip to content

Commit 87bdf2a

Browse files
authored
ggml : use atomic_flag for critical section (#7598)
* ggml : use atomic_flag for critical section
* add windows shims
1 parent 00281b7 commit 87bdf2a

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

ggml.c

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@
6060

6161
typedef volatile LONG atomic_int;
6262
typedef atomic_int atomic_bool;
63+
typedef atomic_int atomic_flag;
64+
65+
#define ATOMIC_FLAG_INIT 0
6366

6467
static void atomic_store(atomic_int * ptr, LONG val) {
6568
InterlockedExchange(ptr, val);
@@ -73,6 +76,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
7376
static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
7477
return atomic_fetch_add(ptr, -(dec));
7578
}
79+
static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
80+
return InterlockedExchange(ptr, 1);
81+
}
82+
static void atomic_flag_clear(atomic_flag * ptr) {
83+
InterlockedExchange(ptr, 0);
84+
}
7685

7786
typedef HANDLE pthread_t;
7887

@@ -2883,24 +2892,20 @@ struct ggml_state {
28832892

28842893
// global state
28852894
static struct ggml_state g_state;
2886-
static atomic_int g_state_barrier = 0;
2895+
static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
28872896

28882897
// barrier via spin lock
28892898
inline static void ggml_critical_section_start(void) {
2890-
int processing = atomic_fetch_add(&g_state_barrier, 1);
2891-
2892-
while (processing > 0) {
2893-
// wait for other threads to finish
2894-
atomic_fetch_sub(&g_state_barrier, 1);
2895-
sched_yield(); // TODO: reconsider this
2896-
processing = atomic_fetch_add(&g_state_barrier, 1);
2899+
while (atomic_flag_test_and_set(&g_state_critical)) {
2900+
// spin
2901+
sched_yield();
28972902
}
28982903
}
28992904

29002905
// TODO: make this somehow automatically executed
29012906
// some sort of "sentry" mechanism
29022907
inline static void ggml_critical_section_end(void) {
2903-
atomic_fetch_sub(&g_state_barrier, 1);
2908+
atomic_flag_clear(&g_state_critical);
29042909
}
29052910

29062911
#if defined(__gnu_linux__)

0 commit comments

Comments
 (0)