@@ -43,12 +43,7 @@ struct __align__(16) RankData { const void* ptrs[8]; };
 struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
 #endif
 
-struct __align__(16) RankSignals {
-#ifndef USE_ROCM
-  volatile
-#endif
-      Signal* signals[8];
-};
+struct __align__(16) RankSignals { volatile Signal* signals[8]; };
 
 // like std::array, but aligned
 template <typename T, int sz>
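For context when reading the synchronization changes below, here is a hedged sketch of the Signal layout that `RankSignals` points at, reconstructed only from the accesses this diff makes (`start`, `end`, and the ROCm-side `_flag` counter). The block count and alignments are assumptions, not the header's actual definition.

```cpp
// Hedged sketch, not the header's real definition: field names follow the
// accesses in this diff (self_sg->start[blockIdx.x][threadIdx.x], ->end[...],
// ->_flag[blockIdx.x]).
#include <cstdint>

constexpr int kAssumedMaxBlocks = 36;  // assumption; the real header defines its own kMaxBlocks

struct SignalSketch {
  // One slot per (block, peer rank). Peers write these over P2P; the owner spins on them.
  alignas(128) uint32_t start[kAssumedMaxBlocks][8];
  alignas(128) uint32_t end[kAssumedMaxBlocks][8];
  // Per-block counter that only grows; the ROCm path below compares against it
  // instead of resetting start/end back to zero.
  uint32_t _flag[kAssumedMaxBlocks];
};
```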
@@ -141,29 +136,25 @@ DINLINE O downcast(array_t<float, O::size> val) {
 // This function is meant to be used as the first synchronization in the all
 // reduce kernel. Thus, it doesn't need to make any visibility guarantees for
 // prior memory accesses. Note: volatile writes will not be reordered against
-// other volatile writes (CUDA-only).
+// other volatile writes.
 template <int ngpus>
+DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
+                        int rank) {
 #ifdef USE_ROCM
-DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
   uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
-    __scoped_atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
-                            __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
-                            1, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
-    __atomic_thread_fence(__ATOMIC_ACQ_REL);
+    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
+                     __ATOMIC_RELAXED);
     // wait until we got true from all ranks
-    while (!__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
-                                   __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
+    while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
+                           __ATOMIC_RELAXED) < flag);
   }
   __syncthreads();
-}
+  // use one thread to update flag
+  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
 #else
-DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
-                        int rank) {
-  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
     // reset flag for next time
     self_sg->end[blockIdx.x][threadIdx.x] = 0;
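The net effect of the `start_sync` change above (and the matching `end_sync` change in the next hunks): instead of writing 1 and later resetting the slot to 0, each rank posts a per-block counter that only ever grows, and waiters spin on `observed < flag`. Below is a minimal host-side illustration of that reuse pattern, using `std::thread` and `std::atomic` in place of GPU threads and P2P-mapped signals; the names `kRanks`, `slots`, and `barrier` are illustrative, not from the header.

```cpp
#include <atomic>
#include <cstdint>
#include <thread>
#include <vector>

constexpr int kRanks = 4;
// slots[r][p]: the counter that "rank" p last posted into r's signal block.
std::atomic<uint32_t> slots[kRanks][kRanks];

void barrier(int rank, uint32_t flag) {
  for (int p = 0; p < kRanks; ++p)  // post my new counter to every peer (the "1 p2p write" step)
    slots[p][rank].store(flag, std::memory_order_relaxed);
  for (int p = 0; p < kRanks; ++p)  // spin until every peer caught up; no reset needed
    while (slots[rank][p].load(std::memory_order_relaxed) < flag) {
    }
}

int main() {
  std::vector<std::thread> ranks;
  for (int r = 0; r < kRanks; ++r)
    ranks.emplace_back([r] {
      // the same slots are reused every iteration because the flag only grows
      for (uint32_t flag = 1; flag <= 10000; ++flag) barrier(r, flag);
    });
  for (auto& t : ranks) t.join();
}
```

The relaxed ordering mirrors `start_sync`, which only needs the rendezvous itself; `end_sync` (next hunk) upgrades to release/acquire when data written before the barrier must become visible.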
@@ -174,38 +165,36 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->start[blockIdx.x][threadIdx.x]);
   }
   __syncthreads();
-}
 #endif
+}
 
 // This function is meant to be used as the second or the final synchronization
 // barrier in the all reduce kernel. If it's the final synchronization barrier,
 // we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
+DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
+                      int rank) {
 #ifdef USE_ROCM
-DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
   // testing. Might be the case that hardware provides stronger guarantee than
   // the memory model.
+  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
-    // reset flag for next time
-    __scoped_atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
-                            __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
-                            __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
-    __atomic_thread_fence(__ATOMIC_ACQ_REL);
+    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
+                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
     // wait until we got true from all ranks
-    while (!__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
-                                   __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
+    while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
+                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
+           flag);
   }
-  if constexpr (!final_sync) __syncthreads();
-}
+  __syncthreads();
+  // use one thread to update flag
+  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
 #else
-DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
-                      int rank) {
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
@@ -222,8 +211,8 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->end[blockIdx.x][threadIdx.x]);
   }
   if constexpr (!final_sync) __syncthreads();
-}
 #endif
+}
 
 template <typename P, int ngpus, typename A>
 DINLINE P packed_reduce(const P* ptrs[], int idx) {
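One subtlety in the new ROCm `end_sync` above: the store/load pair is relaxed only when `final_sync` is true; otherwise the writer uses `__ATOMIC_RELEASE` and the spinner `__ATOMIC_ACQUIRE`, so data written before the barrier is guaranteed visible to the rank that observes the signal. A small host-side sketch of that pairing, with `std::atomic` standing in for the device builtins and illustrative names:

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <thread>

int payload = 0;                  // stands in for data written before the barrier
std::atomic<uint32_t> signal{0};  // stands in for one end[block][rank] slot

int main() {
  std::thread writer([] {
    payload = 42;                                // work done before signalling
    signal.store(1, std::memory_order_release);  // non-final end_sync: RELEASE store
  });
  std::thread waiter([] {
    while (signal.load(std::memory_order_acquire) < 1) {  // ACQUIRE spin, like "< flag"
    }
    assert(payload == 42);  // visible because the acquire synchronizes with the release
  });
  writer.join();
  waiter.join();
  return 0;
}
```

For the final barrier there is no later read of peer data (as the comment above the function notes), which is why the diff relaxes both sides in that case.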
@@ -238,11 +227,8 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
-#ifndef USE_ROCM
-                               volatile
-#endif
-                               Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+                               volatile Signal* self_sg, T* __restrict__ result,
+                               int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
@@ -258,22 +244,15 @@ __global__ void __launch_bounds__(512, 1)
 }
 
 template <typename P>
-DINLINE P* get_tmp_buf(
-#ifndef USE_ROCM
-    volatile
-#endif
-    Signal* sg) {
+DINLINE P* get_tmp_buf(volatile Signal* sg) {
   return (P*)(((Signal*)sg) + 1);
 }
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
-#ifndef USE_ROCM
-                               volatile
-#endif
-                               Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+                               volatile Signal* self_sg, T* __restrict__ result,
+                               int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
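`get_tmp_buf` above drops the same `#ifndef USE_ROCM` volatile dance and still just returns the address one `Signal` past `sg`, i.e. it relies on each rank's shared region being laid out as a Signal header immediately followed by scratch space for the 2-stage kernel. A hedged sketch of that layout; the allocation itself is not part of this diff, and the stand-in type, buffer size, and names here are illustrative.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Stand-in for the real Signal type; only its size matters for the layout trick.
struct alignas(128) SignalStandIn {
  uint32_t start[36][8];
  uint32_t end[36][8];
  uint32_t _flag[36];
};

template <typename P>
P* tmp_buf_from(SignalStandIn* sg) {
  // Same pointer arithmetic as get_tmp_buf: scratch space starts right after the header.
  return reinterpret_cast<P*>(sg + 1);
}

int main() {
  // One contiguous region per rank: [ Signal header | scratch space for partial sums ]
  alignas(SignalStandIn) static unsigned char region[sizeof(SignalStandIn) + 4096];
  auto* sg = reinterpret_cast<SignalStandIn*>(region);
  float* tmp = tmp_buf_from<float>(sg);
  std::printf("scratch starts %zu bytes into the region\n",
              static_cast<std::size_t>(reinterpret_cast<unsigned char*>(tmp) - region));
  return 0;
}
```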
@@ -476,41 +455,37 @@ class CustomAllreduce {
    */
   template <typename T>
   void allreduce(cudaStream_t stream, T* input, T* output, int size,
-#ifdef USE_ROCM
-                 int threads = 512, int block_limit = 18){
-#else
                  int threads = 512, int block_limit = 36) {
-#endif
-  auto d = packed_t<T>::P::size;
-  if (size % d != 0)
-    throw std::runtime_error(
-        "custom allreduce currently requires input length to be multiple "
-        "of " +
-        std::to_string(d));
-  if (block_limit > kMaxBlocks)
-    throw std::runtime_error("max supported block limit is " +
-                             std::to_string(kMaxBlocks) + ". Got " +
-                             std::to_string(block_limit));
-
-  RankData* ptrs;
-  cudaStreamCaptureStatus status;
-  CUDACHECK(cudaStreamIsCapturing(stream, &status));
-  if (status == cudaStreamCaptureStatusActive) {
-    ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
-    graph_unreg_buffers_.push_back(input);
-  } else {
-    auto it = buffers_.find(input);
-    if (it == buffers_.end())
+    auto d = packed_t<T>::P::size;
+    if (size % d != 0)
       throw std::runtime_error(
-          "buffer address " +
-          std::to_string(reinterpret_cast<uint64_t>(input)) +
-          " is not registered!");
-    ptrs = it->second;
-  }
+          "custom allreduce currently requires input length to be multiple "
+          "of " +
+          std::to_string(d));
+    if (block_limit > kMaxBlocks)
+      throw std::runtime_error("max supported block limit is " +
+                               std::to_string(kMaxBlocks) + ". Got " +
+                               std::to_string(block_limit));
+
+    RankData* ptrs;
+    cudaStreamCaptureStatus status;
+    CUDACHECK(cudaStreamIsCapturing(stream, &status));
+    if (status == cudaStreamCaptureStatusActive) {
+      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
+      graph_unreg_buffers_.push_back(input);
+    } else {
+      auto it = buffers_.find(input);
+      if (it == buffers_.end())
+        throw std::runtime_error(
+            "buffer address " +
+            std::to_string(reinterpret_cast<uint64_t>(input)) +
+            " is not registered!");
+      ptrs = it->second;
+    }
 
-  size /= d;
-  auto bytes = size * sizeof(typename packed_t<T>::P);
-  int blocks = std::min(block_limit, (size + threads - 1) / threads);
+    size /= d;
+    auto bytes = size * sizeof(typename packed_t<T>::P);
+    int blocks = std::min(block_limit, (size + threads - 1) / threads);
 #define KL(ngpus, name) \
   name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                  rank_, size);
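The re-indented `allreduce` body above keeps the same launch math: the element count must be a multiple of the pack width `d`, the kernels then index packed elements (`size /= d`), and the grid is the ceiling division of packed elements by `threads`, capped at `block_limit`. A standalone restatement of that arithmetic, with `d` and `k_max_blocks` passed as parameters because their values come from elsewhere in the header:

```cpp
#include <algorithm>
#include <stdexcept>
#include <string>

// Mirrors the checks and grid computation in allreduce() above.
int allreduce_blocks(int size, int d, int threads, int block_limit, int k_max_blocks) {
  if (size % d != 0)
    throw std::runtime_error("input length must be a multiple of " + std::to_string(d));
  if (block_limit > k_max_blocks)
    throw std::runtime_error("block limit " + std::to_string(block_limit) +
                             " exceeds max " + std::to_string(k_max_blocks));
  int packed = size / d;  // kernels work on packed (vectorized) elements
  return std::min(block_limit, (packed + threads - 1) / threads);  // ceil-div, capped
}
```

For example, with 512 threads and the default block_limit of 36, an input of 4096 packed elements launches min(36, 8) = 8 blocks.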
@@ -529,27 +504,27 @@ class CustomAllreduce {
     break; \
   }
 
-  switch (world_size_) {
-    REDUCE_CASE(2)
-    REDUCE_CASE(4)
-    REDUCE_CASE(6)
-    REDUCE_CASE(8)
-    default:
-      throw std::runtime_error(
-          "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
-          "gpus = " +
-          std::to_string(world_size_));
-  }
+    switch (world_size_) {
+      REDUCE_CASE(2)
+      REDUCE_CASE(4)
+      REDUCE_CASE(6)
+      REDUCE_CASE(8)
+      default:
+        throw std::runtime_error(
+            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
+            "gpus = " +
+            std::to_string(world_size_));
+    }
 #undef REDUCE_CASE
 #undef KL
-}
+  }
 
-~CustomAllreduce() {
-  for (auto [_, ptr] : ipc_handles_) {
-    CUDACHECK(cudaIpcCloseMemHandle(ptr));
+  ~CustomAllreduce() {
+    for (auto [_, ptr] : ipc_handles_) {
+      CUDACHECK(cudaIpcCloseMemHandle(ptr));
+    }
   }
-}
-}; // namespace vllm
+};
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation: