
Commit c2a7bfa

Revert "Optimize custom all reduce (#130)"

Per @iotamudelta's suggestion, until the deadlock issue is better understood.
This reverts commit 636ff01.

1 parent 481bf03, commit c2a7bfa

File tree

3 files changed: +77, -117 lines


csrc/custom_all_reduce.cuh (73 additions, 98 deletions)
@@ -43,12 +43,7 @@ struct __align__(16) RankData { const void* ptrs[8]; };
 struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
 #endif
 
-struct __align__(16) RankSignals {
-#ifndef USE_ROCM
-  volatile
-#endif
-  Signal* signals[8];
-};
+struct __align__(16) RankSignals { volatile Signal* signals[8]; };
 
 // like std::array, but aligned
 template <typename T, int sz>
@@ -141,29 +136,25 @@ DINLINE O downcast(array_t<float, O::size> val) {
 // This function is meant to be used as the first synchronization in the all
 // reduce kernel. Thus, it doesn't need to make any visibility guarantees for
 // prior memory accesses. Note: volatile writes will not be reordered against
-// other volatile writes (CUDA-only).
+// other volatile writes.
 template <int ngpus>
+DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
+                        int rank) {
 #ifdef USE_ROCM
-DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
   uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
-    __scoped_atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
-                            __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
-                            1, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
-    __atomic_thread_fence(__ATOMIC_ACQ_REL);
+    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
+                     __ATOMIC_RELAXED);
     // wait until we got true from all ranks
-    while (!__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
-                                   __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
+    while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
+                           __ATOMIC_RELAXED) < flag);
   }
   __syncthreads();
-}
+  // use one thread to update flag
+  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
 #else
-DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
-                        int rank) {
-  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
     // reset flag for next time
     self_sg->end[blockIdx.x][threadIdx.x] = 0;
@@ -174,38 +165,36 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->start[blockIdx.x][threadIdx.x]);
   }
   __syncthreads();
-}
 #endif
+}
 
 // This function is meant to be used as the second or the final synchronization
 // barrier in the all reduce kernel. If it's the final synchronization barrier,
 // we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
+DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
+                      int rank) {
 #ifdef USE_ROCM
-DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
   // testing. Might be the case that hardware provides stronger guarantee than
   // the memory model.
+  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
-    // reset flag for next time
-    __scoped_atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
-                            __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
-                            __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
-    __atomic_thread_fence(__ATOMIC_ACQ_REL);
+    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
+                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
     // wait until we got true from all ranks
-    while (!__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
-                                   __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
+    while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
+                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
+           flag);
   }
-  if constexpr (!final_sync) __syncthreads();
-}
+  __syncthreads();
+  // use one thread to update flag
+  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
 #else
-DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
-                      int rank) {
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
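Note on the restored ROCm path above: each barrier round bumps a per-block _flag counter, every rank writes that flag value into its slot on every peer, spins until its own slots reach the flag, and one thread then records the flag for the next round, so no reset store is needed between rounds. The sketch below is a minimal host-side analogue of that pattern using the same __atomic_store_n/__atomic_load_n builtins (here with a conservative release/acquire pairing). Two std::threads stand in for two ranks; the signals table is a simplified stand-in, not the header's actual Signal layout.

// Minimal host-side analogue of the monotonically increasing flag barrier.
// Two std::threads stand in for two GPU ranks; `signals` stands in for each
// rank's Signal slots (hypothetical layout). Build with: g++ -O2 -pthread
#include <cstdint>
#include <cstdio>
#include <thread>

constexpr int kRanks = 2;
static uint32_t signals[kRanks][kRanks] = {};  // signals[owner][writer]
static uint32_t flag_counter[kRanks] = {};     // per-rank _flag analogue

void barrier(int rank) {
  uint32_t flag = flag_counter[rank] + 1;
  // Write the new flag into our column of every peer's table
  // (the "1 p2p write" per peer in the kernel).
  for (int peer = 0; peer < kRanks; ++peer)
    __atomic_store_n(&signals[peer][rank], flag, __ATOMIC_RELEASE);
  // Spin until every peer has advanced our local slots to at least `flag`.
  for (int peer = 0; peer < kRanks; ++peer)
    while (__atomic_load_n(&signals[rank][peer], __ATOMIC_ACQUIRE) < flag) {
    }
  // Record the flag for the next round; no reset is needed because the
  // counter only grows.
  flag_counter[rank] = flag;
}

int main() {
  auto work = [](int rank) {
    for (int round = 0; round < 3; ++round) {
      barrier(rank);
      std::printf("rank %d passed barrier round %d\n", rank, round);
    }
  };
  std::thread t0(work, 0), t1(work, 1);
  t0.join();
  t1.join();
  return 0;
}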
@@ -222,8 +211,8 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->end[blockIdx.x][threadIdx.x]);
   }
   if constexpr (!final_sync) __syncthreads();
-}
 #endif
+}
 
 template <typename P, int ngpus, typename A>
 DINLINE P packed_reduce(const P* ptrs[], int idx) {
@@ -238,11 +227,8 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
-#ifndef USE_ROCM
-                               volatile
-#endif
-                               Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+                               volatile Signal* self_sg, T* __restrict__ result,
+                               int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
@@ -258,22 +244,15 @@ __global__ void __launch_bounds__(512, 1)
 }
 
 template <typename P>
-DINLINE P* get_tmp_buf(
-#ifndef USE_ROCM
-    volatile
-#endif
-    Signal* sg) {
+DINLINE P* get_tmp_buf(volatile Signal* sg) {
   return (P*)(((Signal*)sg) + 1);
 }
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
-#ifndef USE_ROCM
-                               volatile
-#endif
-                               Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+                               volatile Signal* self_sg, T* __restrict__ result,
+                               int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
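get_tmp_buf above relies on each rank's shared allocation being laid out as a Signal header immediately followed by scratch space, so ((Signal*)sg) + 1 lands on the first byte past the header. Below is a small host-side layout sketch of that pointer arithmetic; the Signal struct is a stand-in with made-up field sizes, not the header's real definition.

// Layout assumed by get_tmp_buf(): [ Signal header | scratch buffer ... ].
// The Signal below is a stand-in, not the real header struct.
#include <cstdint>
#include <cstdio>
#include <cstdlib>

struct Signal {
  uint32_t start[36][8];  // stand-in for the per-block start flags
  uint32_t end[36][8];    // stand-in for the per-block end flags
};

template <typename P>
P* get_tmp_buf(Signal* sg) {
  // First byte past the header, same pointer arithmetic as the kernel helper.
  return reinterpret_cast<P*>(sg + 1);
}

int main() {
  // One allocation holds the header plus 1 MiB of scratch space.
  void* buf = std::malloc(sizeof(Signal) + (1 << 20));
  auto* sg = static_cast<Signal*>(buf);
  auto* tmp = get_tmp_buf<float>(sg);
  std::printf("header bytes = %zu, scratch starts at offset %zu\n",
              sizeof(Signal),
              static_cast<size_t>(reinterpret_cast<char*>(tmp) -
                                  reinterpret_cast<char*>(sg)));
  std::free(buf);
  return 0;
}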
@@ -476,41 +455,37 @@ class CustomAllreduce {
    */
   template <typename T>
   void allreduce(cudaStream_t stream, T* input, T* output, int size,
-#ifdef USE_ROCM
-                 int threads = 512, int block_limit = 18){
-#else
                  int threads = 512, int block_limit = 36) {
-#endif
-  auto d = packed_t<T>::P::size;
-  if (size % d != 0)
-    throw std::runtime_error(
-        "custom allreduce currently requires input length to be multiple "
-        "of " +
-        std::to_string(d));
-  if (block_limit > kMaxBlocks)
-    throw std::runtime_error("max supported block limit is " +
-                             std::to_string(kMaxBlocks) + ". Got " +
-                             std::to_string(block_limit));
-
-  RankData* ptrs;
-  cudaStreamCaptureStatus status;
-  CUDACHECK(cudaStreamIsCapturing(stream, &status));
-  if (status == cudaStreamCaptureStatusActive) {
-    ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
-    graph_unreg_buffers_.push_back(input);
-  } else {
-    auto it = buffers_.find(input);
-    if (it == buffers_.end())
+    auto d = packed_t<T>::P::size;
+    if (size % d != 0)
      throw std::runtime_error(
-          "buffer address " +
-          std::to_string(reinterpret_cast<uint64_t>(input)) +
-          " is not registered!");
-    ptrs = it->second;
-  }
+          "custom allreduce currently requires input length to be multiple "
+          "of " +
+          std::to_string(d));
+    if (block_limit > kMaxBlocks)
+      throw std::runtime_error("max supported block limit is " +
+                               std::to_string(kMaxBlocks) + ". Got " +
+                               std::to_string(block_limit));
+
+    RankData* ptrs;
+    cudaStreamCaptureStatus status;
+    CUDACHECK(cudaStreamIsCapturing(stream, &status));
+    if (status == cudaStreamCaptureStatusActive) {
+      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
+      graph_unreg_buffers_.push_back(input);
+    } else {
+      auto it = buffers_.find(input);
+      if (it == buffers_.end())
+        throw std::runtime_error(
+            "buffer address " +
+            std::to_string(reinterpret_cast<uint64_t>(input)) +
+            " is not registered!");
+      ptrs = it->second;
+    }
 
-  size /= d;
-  auto bytes = size * sizeof(typename packed_t<T>::P);
-  int blocks = std::min(block_limit, (size + threads - 1) / threads);
+    size /= d;
+    auto bytes = size * sizeof(typename packed_t<T>::P);
+    int blocks = std::min(block_limit, (size + threads - 1) / threads);
 #define KL(ngpus, name) \
   name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                  rank_, size);
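For reference, the restored dispatch math above reduces over 16-byte packed vectors: the element count is divided by the pack width d, and the grid is capped at block_limit blocks of `threads` threads. A standalone illustration of that arithmetic follows, assuming d = 8 for half precision (16 bytes per pack / 2 bytes per element) and a hypothetical 4M-element input; the numbers are examples, not values taken from this commit.

// Rough illustration of the launch-geometry math in allreduce():
//   size /= d;                        packed element count
//   blocks = min(block_limit, ceil(size / threads))
#include <algorithm>
#include <cstdio>

int main() {
  const long d = 8;             // assumed pack width for half (16 B / 2 B)
  const long threads = 512;     // default from the header
  const long block_limit = 36;  // restored default (the reverted ROCm path used 18)
  const long elems = 4 * 1024 * 1024;  // hypothetical 4M-element input

  const long packed = elems / d;   // 524288 packed elements
  const long bytes = packed * 16;  // 8 MiB of packed data
  const long blocks =
      std::min(block_limit, (packed + threads - 1) / threads);  // 36
  std::printf("packed = %ld, bytes = %ld, blocks = %ld, threads = %ld\n",
              packed, bytes, blocks, threads);
  return 0;
}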
@@ -529,27 +504,27 @@ class CustomAllreduce {
     break; \
   }
 
-  switch (world_size_) {
-    REDUCE_CASE(2)
-    REDUCE_CASE(4)
-    REDUCE_CASE(6)
-    REDUCE_CASE(8)
-    default:
-      throw std::runtime_error(
-          "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
-          "gpus = " +
-          std::to_string(world_size_));
-  }
+    switch (world_size_) {
+      REDUCE_CASE(2)
+      REDUCE_CASE(4)
+      REDUCE_CASE(6)
+      REDUCE_CASE(8)
+      default:
+        throw std::runtime_error(
+            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
+            "gpus = " +
+            std::to_string(world_size_));
+    }
 #undef REDUCE_CASE
 #undef KL
-}
+  }
 
-~CustomAllreduce() {
-  for (auto [_, ptr] : ipc_handles_) {
-    CUDACHECK(cudaIpcCloseMemHandle(ptr));
+  ~CustomAllreduce() {
+    for (auto [_, ptr] : ipc_handles_) {
+      CUDACHECK(cudaIpcCloseMemHandle(ptr));
+    }
   }
-}
-}; // namespace vllm
+};
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation:

csrc/custom_all_reduce_test.cu (0 additions, 7 deletions)
@@ -330,17 +330,10 @@ int main(int argc, char** argv) {
   //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
   //   }
   // }
-#ifdef USE_ROCM
-  for (int sz = 512; sz <= (8 << 22); sz *= 2) {
-    run<half>(myRank, nRanks, comm, 512, 18, sz + 8 * 47, performance_test);
-  }
-#else
   for (int sz = 512; sz <= (8 << 20); sz *= 2) {
     run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
   }
-#endif
 
   cudaProfilerStop();
-  MPICHECK(MPI_Finalize());
   return EXIT_SUCCESS;
 }

vllm/distributed/parallel_state.py (4 additions, 12 deletions)
@@ -199,18 +199,10 @@ def initialize_model_parallel(
     if _ENABLE_CUSTOM_ALL_REDUCE:
         from vllm.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce)
-
-        # max size defaults to 8 MiB, increase to 16 MiB on ROCm
-        # due to later crossover
-        if is_hip():
-            _TP_CA_COMMUNICATOR = CustomAllreduce(group=_TP_CPU_GROUP,
-                                                  device=_LOCAL_RANK,
-                                                  max_size=2 * 8192 * 1024)
-        else:
-            _TP_CA_COMMUNICATOR = CustomAllreduce(
-                group=_TP_CPU_GROUP,
-                device=_LOCAL_RANK,
-            )
+        _TP_CA_COMMUNICATOR = CustomAllreduce(
+            group=_TP_CPU_GROUP,
+            device=_LOCAL_RANK,
+        )
 
     # Build the pipeline model-parallel groups.
     global _PP_DEVICE_GROUP, _PP_CPU_GROUP
