@@ -145,17 +145,18 @@ DINLINE O downcast(array_t<float, O::size> val) {
145145template <int ngpus>
146146#ifdef USE_ROCM
147147DINLINE void start_sync (const RankSignals& sg, Signal* self_sg, int rank) {
148+ uint32_t flag = self_sg->_flag [blockIdx .x ] + 1 ;
148149 if (threadIdx .x < ngpus) {
149- __atomic_store_n (&self_sg->end [blockIdx .x ][threadIdx .x ], 0 ,
150- __ATOMIC_RELAXED);
150+ __scoped_atomic_store_n (&self_sg->end [blockIdx .x ][threadIdx .x ], 0 ,
151+ __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE );
151152 // simultaneously write to the corresponding flag of all ranks.
152153 // Latency = 1 p2p write
153- __atomic_store_n (&sg.signals [threadIdx .x ]->start [blockIdx .x ][rank], 1 ,
154- __ATOMIC_RELAXED);
154+ __scoped_atomic_store_n (&sg.signals [threadIdx .x ]->start [blockIdx .x ][rank],
155+ 1 , __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM );
155156 __atomic_thread_fence (__ATOMIC_ACQ_REL);
156157 // wait until we got true from all ranks
157- while (!__atomic_load_n (&self_sg->start [blockIdx .x ][threadIdx .x ],
158- __ATOMIC_RELAXED);
158+ while (!__scoped_atomic_load_n (&self_sg->start [blockIdx .x ][threadIdx .x ],
159+ __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) );
159160 }
160161 __syncthreads ();
161162}
@@ -189,16 +190,16 @@ DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
189190 // the memory model.
190191 if (threadIdx .x < ngpus) {
191192 // reset flag for next time
192- __atomic_store_n (&self_sg->start [blockIdx .x ][threadIdx .x ], 0 ,
193- __ATOMIC_RELAXED);
193+ __scoped_atomic_store_n (&self_sg->start [blockIdx .x ][threadIdx .x ], 0 ,
194+ __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE );
194195 // simultaneously write to the corresponding flag of all ranks.
195196 // Latency = 1 p2p write
196- __atomic_store_n (&sg.signals [threadIdx .x ]->end [blockIdx .x ][rank], 1 ,
197- __ATOMIC_RELAXED);
197+ __scoped_atomic_store_n (&sg.signals [threadIdx .x ]->end [blockIdx .x ][rank], 1 ,
198+ __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM );
198199 __atomic_thread_fence (__ATOMIC_ACQ_REL);
199200 // wait until we got true from all ranks
200- while (!__atomic_load_n (&self_sg->end [blockIdx .x ][threadIdx .x ],
201- __ATOMIC_RELAXED));
201+ while (!__scoped_atomic_load_n (&self_sg->end [blockIdx .x ][threadIdx .x ],
202+ __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE ));
202203 }
203204 if constexpr (!final_sync) __syncthreads ();
204205}
0 commit comments