
Commit 636ff01

Optimize custom all reduce (#130)
* First version
* Revert error. While there, add missing finalize.
* Use the correct defaults for ROCm. Increase sampling area to capture crossover.
* Scope end_sync as well.
* Guard only volatile keyword for ifndef USE_ROCM
* Document crossover
1 parent 674da1d commit 636ff01
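Note for readers of the diff below: the recurring "#ifndef USE_ROCM / volatile / #endif" triplet only keeps the volatile qualifier on the signal pointers for CUDA builds; ROCm builds drop it and rely on the new __scoped_atomic_* accesses for ordering. Purely as a reading aid, the same guard could be written once as a macro, as in the hypothetical sketch below (the patch repeats the guard inline at each site, and CA_MAYBE_VOLATILE is not a name that appears in the code):

// Hypothetical convenience macro, not part of this patch.
#ifndef USE_ROCM
  // CUDA build: keep the volatile qualifier on signal pointers.
  #define CA_MAYBE_VOLATILE volatile
#else
  // ROCm build: drop volatile; scoped atomics provide the ordering instead.
  #define CA_MAYBE_VOLATILE
#endif

// Usage sketch, assuming the Signal type defined in custom_all_reduce.cuh:
// struct RankSignals { CA_MAYBE_VOLATILE Signal* signals[8]; };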

File tree

3 files changed: +117 -77 lines changed

csrc/custom_all_reduce.cuh
csrc/custom_all_reduce_test.cu
vllm/distributed/parallel_state.py

csrc/custom_all_reduce.cuh

Lines changed: 98 additions & 73 deletions
@@ -43,7 +43,12 @@ struct __align__(16) RankData { const void* ptrs[8]; };
 struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
 #endif

-struct __align__(16) RankSignals { volatile Signal* signals[8]; };
+struct __align__(16) RankSignals {
+#ifndef USE_ROCM
+  volatile
+#endif
+  Signal* signals[8];
+};

 // like std::array, but aligned
 template <typename T, int sz>
@@ -136,25 +141,29 @@ DINLINE O downcast(array_t<float, O::size> val) {
 // This function is meant to be used as the first synchronization in the all
 // reduce kernel. Thus, it doesn't need to make any visibility guarantees for
 // prior memory accesses. Note: volatile writes will not be reordered against
-// other volatile writes.
+// other volatile writes (CUDA-only).
 template <int ngpus>
-DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
-                        int rank) {
 #ifdef USE_ROCM
+DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
   uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
+    __scoped_atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
+                            __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
-                     __ATOMIC_RELAXED);
+    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
+                            1, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
+    __atomic_thread_fence(__ATOMIC_ACQ_REL);
     // wait until we got true from all ranks
-    while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
-                           __ATOMIC_RELAXED) < flag);
+    while (!__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
+                                   __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
   }
   __syncthreads();
-  // use one thread to update flag
-  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
+}
 #else
+DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
+                        int rank) {
+  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
     // reset flag for next time
     self_sg->end[blockIdx.x][threadIdx.x] = 0;
@@ -165,36 +174,38 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->start[blockIdx.x][threadIdx.x]);
   }
   __syncthreads();
-#endif
 }
+#endif

 // This function is meant to be used as the second or the final synchronization
 // barrier in the all reduce kernel. If it's the final synchronization barrier,
 // we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
-DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
-                      int rank) {
 #ifdef USE_ROCM
+DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
   // testing. Might be the case that hardware provides stronger guarantee than
   // the memory model.
-  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
+    // reset flag for next time
+    __scoped_atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
+                            __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
-                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
+    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
+                            __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
+    __atomic_thread_fence(__ATOMIC_ACQ_REL);
     // wait until we got true from all ranks
-    while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
-                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
-           flag);
+    while (!__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
+                                   __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
   }
-  __syncthreads();
-  // use one thread to update flag
-  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
+  if constexpr (!final_sync) __syncthreads();
+}
 #else
+DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
+                      int rank) {
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
@@ -211,8 +222,8 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->end[blockIdx.x][threadIdx.x]);
   }
   if constexpr (!final_sync) __syncthreads();
-#endif
 }
+#endif

 template <typename P, int ngpus, typename A>
 DINLINE P packed_reduce(const P* ptrs[], int idx) {
@@ -227,8 +238,11 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
-                               volatile Signal* self_sg, T* __restrict__ result,
-                               int rank, int size) {
+#ifndef USE_ROCM
+                               volatile
+#endif
+                               Signal* self_sg,
+                               T* __restrict__ result, int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
@@ -244,15 +258,22 @@ __global__ void __launch_bounds__(512, 1)
 }

 template <typename P>
-DINLINE P* get_tmp_buf(volatile Signal* sg) {
+DINLINE P* get_tmp_buf(
+#ifndef USE_ROCM
+    volatile
+#endif
+    Signal* sg) {
   return (P*)(((Signal*)sg) + 1);
 }

 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
-                               volatile Signal* self_sg, T* __restrict__ result,
-                               int rank, int size) {
+#ifndef USE_ROCM
+                               volatile
+#endif
+                               Signal* self_sg,
+                               T* __restrict__ result, int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
@@ -455,37 +476,41 @@ class CustomAllreduce {
    */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
+#ifdef USE_ROCM
+                 int threads = 512, int block_limit = 18){
+#else
                  int threads = 512, int block_limit = 36) {
-    auto d = packed_t<T>::P::size;
-    if (size % d != 0)
+#endif
+    auto d = packed_t<T>::P::size;
+    if (size % d != 0)
+      throw std::runtime_error(
+          "custom allreduce currently requires input length to be multiple "
+          "of " +
+          std::to_string(d));
+    if (block_limit > kMaxBlocks)
+      throw std::runtime_error("max supported block limit is " +
+                               std::to_string(kMaxBlocks) + ". Got " +
+                               std::to_string(block_limit));
+
+    RankData* ptrs;
+    cudaStreamCaptureStatus status;
+    CUDACHECK(cudaStreamIsCapturing(stream, &status));
+    if (status == cudaStreamCaptureStatusActive) {
+      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
+      graph_unreg_buffers_.push_back(input);
+    } else {
+      auto it = buffers_.find(input);
+      if (it == buffers_.end())
        throw std::runtime_error(
-            "custom allreduce currently requires input length to be multiple "
-            "of " +
-            std::to_string(d));
-    if (block_limit > kMaxBlocks)
-      throw std::runtime_error("max supported block limit is " +
-                               std::to_string(kMaxBlocks) + ". Got " +
-                               std::to_string(block_limit));
-
-    RankData* ptrs;
-    cudaStreamCaptureStatus status;
-    CUDACHECK(cudaStreamIsCapturing(stream, &status));
-    if (status == cudaStreamCaptureStatusActive) {
-      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
-      graph_unreg_buffers_.push_back(input);
-    } else {
-      auto it = buffers_.find(input);
-      if (it == buffers_.end())
-        throw std::runtime_error(
-            "buffer address " +
-            std::to_string(reinterpret_cast<uint64_t>(input)) +
-            " is not registered!");
-      ptrs = it->second;
-    }
+            "buffer address " +
+            std::to_string(reinterpret_cast<uint64_t>(input)) +
+            " is not registered!");
+      ptrs = it->second;
+    }

-    size /= d;
-    auto bytes = size * sizeof(typename packed_t<T>::P);
-    int blocks = std::min(block_limit, (size + threads - 1) / threads);
+    size /= d;
+    auto bytes = size * sizeof(typename packed_t<T>::P);
+    int blocks = std::min(block_limit, (size + threads - 1) / threads);
 #define KL(ngpus, name) \
   name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                  rank_, size);
@@ -504,27 +529,27 @@ class CustomAllreduce {
      break; \
    }

-    switch (world_size_) {
-      REDUCE_CASE(2)
-      REDUCE_CASE(4)
-      REDUCE_CASE(6)
-      REDUCE_CASE(8)
-      default:
-        throw std::runtime_error(
-            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
-            "gpus = " +
-            std::to_string(world_size_));
-    }
+    switch (world_size_) {
+      REDUCE_CASE(2)
+      REDUCE_CASE(4)
+      REDUCE_CASE(6)
+      REDUCE_CASE(8)
+      default:
+        throw std::runtime_error(
+            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
+            "gpus = " +
+            std::to_string(world_size_));
+    }
 #undef REDUCE_CASE
 #undef KL
-  }
+  }

-  ~CustomAllreduce() {
-    for (auto [_, ptr] : ipc_handles_) {
-      CUDACHECK(cudaIpcCloseMemHandle(ptr));
-    }
+  ~CustomAllreduce() {
+    for (auto [_, ptr] : ipc_handles_) {
+      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
-  };
+  }
+}; // namespace vllm
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation:
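A note on the synchronization hunks above: on ROCm the flag protocol now goes through Clang's scoped atomic builtins instead of volatile accesses. Each of the first ngpus threads posts a flag into the corresponding peer's signal block at system scope (one p2p write that must become visible on the other GPU), issues an acq_rel fence, then spins at device scope on its own copy until every peer has posted, after which the block synchronizes. The sketch below is a distilled illustration of that pattern with simplified types, not the code in the header; it assumes a Clang-based HIP toolchain that provides the __scoped_atomic_* builtins, and MiniSignal is a stand-in for the real Signal struct.

#include <hip/hip_runtime.h>
#include <cstdint>

constexpr int kMaxRanks = 8;

// Stand-in for the per-rank signal block (the real Signal struct is larger).
struct MiniSignal {
  uint32_t start[kMaxRanks];  // one slot per peer rank
};

template <int ngpus>
__device__ void mini_start_sync(MiniSignal* const* peers, MiniSignal* self,
                                int rank) {
  if (threadIdx.x < ngpus) {
    // Tell peer threadIdx.x that this rank has arrived: one p2p write at
    // system scope so the store becomes visible on the other GPU.
    __scoped_atomic_store_n(&peers[threadIdx.x]->start[rank], 1u,
                            __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
    __atomic_thread_fence(__ATOMIC_ACQ_REL);
    // Spin until every peer has signalled us; our own block only needs
    // device-scope loads. (The real code also resets the end flags here.)
    while (!__scoped_atomic_load_n(&self->start[threadIdx.x],
                                   __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE)) {
    }
  }
  __syncthreads();
}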

csrc/custom_all_reduce_test.cu

Lines changed: 7 additions & 0 deletions
@@ -330,10 +330,17 @@ int main(int argc, char** argv) {
  //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
  //   }
  // }
+#ifdef USE_ROCM
+  for (int sz = 512; sz <= (8 << 22); sz *= 2) {
+    run<half>(myRank, nRanks, comm, 512, 18, sz + 8 * 47, performance_test);
+  }
+#else
  for (int sz = 512; sz <= (8 << 20); sz *= 2) {
    run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
  }
+#endif

  cudaProfilerStop();
+  MPICHECK(MPI_Finalize());
  return EXIT_SUCCESS;
 }
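The widened ROCm sweep above is the "Increase sampling area to capture crossover" item from the commit message: the benchmark now runs up to 8 << 22 half elements per rank with block_limit 18, versus 8 << 20 with block_limit 36 on CUDA, so the size at which the custom all-reduce stops being profitable (the crossover) is actually sampled. A quick, illustrative check of what those endpoints mean in bytes of fp16 payload, and how the 16 MiB ROCm max_size chosen in parallel_state.py below (2 * 8192 * 1024 bytes) sits inside the extended sweep:

// Illustrative only: byte sizes of the sweep endpoints above for half
// (2-byte) elements, plus the ROCm max_size picked in parallel_state.py.
#include <cstdio>

int main() {
  const long long cuda_top = (8LL << 20) * 2;          // top of the CUDA sweep
  const long long rocm_top = (8LL << 22) * 2;          // top of the ROCm sweep
  const long long rocm_max_size = 2LL * 8192 * 1024;   // ROCm cutoff
  std::printf("CUDA sweep top: %lld MiB\n", cuda_top >> 20);       // 16 MiB
  std::printf("ROCm sweep top: %lld MiB\n", rocm_top >> 20);       // 64 MiB
  std::printf("ROCm max_size:  %lld MiB\n", rocm_max_size >> 20);  // 16 MiB
  return 0;
}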

vllm/distributed/parallel_state.py

Lines changed: 12 additions & 4 deletions
@@ -199,10 +199,18 @@ def initialize_model_parallel(
     if _ENABLE_CUSTOM_ALL_REDUCE:
         from vllm.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce)
-        _TP_CA_COMMUNICATOR = CustomAllreduce(
-            group=_TP_CPU_GROUP,
-            device=_LOCAL_RANK,
-        )
+
+        # max size defaults to 8 MiB, increase to 16 MiB on ROCm
+        # due to later crossover
+        if is_hip():
+            _TP_CA_COMMUNICATOR = CustomAllreduce(group=_TP_CPU_GROUP,
+                                                  device=_LOCAL_RANK,
+                                                  max_size=2 * 8192 * 1024)
+        else:
+            _TP_CA_COMMUNICATOR = CustomAllreduce(
+                group=_TP_CPU_GROUP,
+                device=_LOCAL_RANK,
+            )

     # Build the pipeline model-parallel groups.
     global _PP_DEVICE_GROUP, _PP_CPU_GROUP
