Skip to content

Commit 481bf03

Browse files
committed
Per @iotamudelta suggestion until the deadlocks issue is better understood
Revert "Make CAR ROCm 6.1 compatible. (#137)" This reverts commit 4d2dda6.
1 parent 4d2dda6 commit 481bf03

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

csrc/custom_all_reduce.cuh

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,17 +145,18 @@ DINLINE O downcast(array_t<float, O::size> val) {
145145
template <int ngpus>
146146
#ifdef USE_ROCM
147147
DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
148+
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
148149
if (threadIdx.x < ngpus) {
149-
__atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
150-
__ATOMIC_RELAXED);
150+
__scoped_atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
151+
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
151152
// simultaneously write to the corresponding flag of all ranks.
152153
// Latency = 1 p2p write
153-
__atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], 1,
154-
__ATOMIC_RELAXED);
154+
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
155+
1, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
155156
__atomic_thread_fence(__ATOMIC_ACQ_REL);
156157
// wait until we got true from all ranks
157-
while (!__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
158-
__ATOMIC_RELAXED));
158+
while (!__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
159+
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
159160
}
160161
__syncthreads();
161162
}
@@ -189,16 +190,16 @@ DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
189190
// the memory model.
190191
if (threadIdx.x < ngpus) {
191192
// reset flag for next time
192-
__atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
193-
__ATOMIC_RELAXED);
193+
__scoped_atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
194+
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
194195
// simultaneously write to the corresponding flag of all ranks.
195196
// Latency = 1 p2p write
196-
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
197-
__ATOMIC_RELAXED);
197+
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
198+
__ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
198199
__atomic_thread_fence(__ATOMIC_ACQ_REL);
199200
// wait until we got true from all ranks
200-
while (!__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
201-
__ATOMIC_RELAXED));
201+
while (!__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
202+
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
202203
}
203204
if constexpr (!final_sync) __syncthreads();
204205
}

csrc/custom_all_reduce_test.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ int main(int argc, char** argv) {
330330
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
331331
// }
332332
// }
333-
#ifdef USE_ROCM
333+
#ifdef USE _ROCM
334334
for (int sz = 512; sz <= (8 << 22); sz *= 2) {
335335
run<half>(myRank, nRanks, comm, 512, 18, sz + 8 * 47, performance_test);
336336
}

0 commit comments

Comments
 (0)