@@ -43,7 +43,12 @@ struct __align__(16) RankData { const void* ptrs[8]; };
 struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
 #endif

-struct __align__(16) RankSignals { volatile Signal* signals[8]; };
+struct __align__(16) RankSignals {
+#ifndef USE_ROCM
+  volatile
+#endif
+      Signal* signals[8];
+};

 // like std::array, but aligned
 template <typename T, int sz>
@@ -136,25 +141,29 @@ DINLINE O downcast(array_t<float, O::size> val) {
 // This function is meant to be used as the first synchronization in the all
 // reduce kernel. Thus, it doesn't need to make any visibility guarantees for
 // prior memory accesses. Note: volatile writes will not be reordered against
-// other volatile writes.
+// other volatile writes (CUDA-only).
 template <int ngpus>
-DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
-                        int rank) {
 #ifdef USE_ROCM
+DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
   uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
+    __scoped_atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
+                            __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
-                     __ATOMIC_RELAXED);
+    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
+                            1, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
+    __atomic_thread_fence(__ATOMIC_ACQ_REL);
     // wait until we got true from all ranks
-    while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
-                           __ATOMIC_RELAXED) < flag);
+    while (!__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
+                                   __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
   }
   __syncthreads();
-  // use one thread to update flag
-  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
+}
 #else
+DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
+                        int rank) {
+  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
     // reset flag for next time
     self_sg->end[blockIdx.x][threadIdx.x] = 0;
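As a side note, the barrier pattern the ROCm branch introduces (post a flag to every peer with a system-scoped relaxed store, fence, then spin on the local mailbox at device scope) can be sketched in isolation roughly as below. This is an illustration only, assuming clang's __scoped_atomic_* builtins; flag_barrier, peer_flags and my_flags are hypothetical names, not part of this patch.

// Hypothetical sketch: each of the first `ngpus` threads posts a "rank arrived"
// flag into every peer's mailbox, then spins on its own mailbox entry.
__device__ void flag_barrier(uint32_t* peer_flags[], uint32_t* my_flags,
                             int ngpus, int rank) {
  if (threadIdx.x < ngpus) {
    // one p2p store per peer, system scope so remote GPUs observe it
    __scoped_atomic_store_n(&peer_flags[threadIdx.x][rank], 1, __ATOMIC_RELAXED,
                            __MEMORY_SCOPE_SYSTEM);
    __atomic_thread_fence(__ATOMIC_ACQ_REL);
    // device-scope spin until the peer handled by this thread has arrived
    while (!__scoped_atomic_load_n(&my_flags[threadIdx.x], __ATOMIC_RELAXED,
                                   __MEMORY_SCOPE_DEVICE));
  }
  __syncthreads();
}

The stores use __MEMORY_SCOPE_SYSTEM because they must be visible to peer GPUs over P2P, while the polling loads only need __MEMORY_SCOPE_DEVICE since each rank spins on its own signal block.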
@@ -165,36 +174,38 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->start[blockIdx.x][threadIdx.x]);
   }
   __syncthreads();
-#endif
 }
+#endif

 // This function is meant to be used as the second or the final synchronization
 // barrier in the all reduce kernel. If it's the final synchronization barrier,
 // we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
-DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
-                      int rank) {
 #ifdef USE_ROCM
+DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
   // testing. Might be the case that hardware provides stronger guarantee than
   // the memory model.
-  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
+    // reset flag for next time
+    __scoped_atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
+                            __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
-                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
+    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
+                            __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
+    __atomic_thread_fence(__ATOMIC_ACQ_REL);
     // wait until we got true from all ranks
-    while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
-                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
-           flag);
+    while (!__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
+                                   __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
   }
-  __syncthreads();
-  // use one thread to update flag
-  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
+  if constexpr (!final_sync) __syncthreads();
+}
 #else
+DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
+                      int rank) {
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
@@ -211,8 +222,8 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->end[blockIdx.x][threadIdx.x]);
   }
   if constexpr (!final_sync) __syncthreads();
-#endif
 }
+#endif

 template <typename P, int ngpus, typename A>
 DINLINE P packed_reduce(const P* ptrs[], int idx) {
@@ -227,8 +238,11 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
-                               volatile Signal* self_sg, T* __restrict__ result,
-                               int rank, int size) {
+#ifndef USE_ROCM
+                               volatile
+#endif
+                               Signal* self_sg,
+                               T* __restrict__ result, int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
@@ -244,15 +258,22 @@ __global__ void __launch_bounds__(512, 1)
 }

 template <typename P>
-DINLINE P* get_tmp_buf(volatile Signal* sg) {
+DINLINE P* get_tmp_buf(
+#ifndef USE_ROCM
+    volatile
+#endif
+    Signal* sg) {
   return (P*)(((Signal*)sg) + 1);
 }

 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
-                               volatile Signal* self_sg, T* __restrict__ result,
-                               int rank, int size) {
+#ifndef USE_ROCM
+                               volatile
+#endif
+                               Signal* self_sg,
+                               T* __restrict__ result, int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
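The #ifndef USE_ROCM / volatile / #endif qualifier now appears in every kernel signature above. A hedged alternative, not what this patch does, would be to centralize the choice in a single alias; SignalArg is a hypothetical name used only for illustration.

// Sketch only: pick the qualifier once, keep the signatures unchanged.
#ifdef USE_ROCM
using SignalArg = Signal;           // HIP path: plain pointers + scoped atomics
#else
using SignalArg = volatile Signal;  // CUDA path keeps the volatile contract
#endif
// Signatures could then read, e.g.:
//   cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
//                              SignalArg* self_sg, T* __restrict__ result,
//                              int rank, int size);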
@@ -455,37 +476,41 @@ class CustomAllreduce {
  */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
+#ifdef USE_ROCM
+                 int threads = 512, int block_limit = 18) {
+#else
                  int threads = 512, int block_limit = 36) {
-    auto d = packed_t<T>::P::size;
-    if (size % d != 0)
+#endif
+    auto d = packed_t<T>::P::size;
+    if (size % d != 0)
+      throw std::runtime_error(
+          "custom allreduce currently requires input length to be multiple "
+          "of " +
+          std::to_string(d));
+    if (block_limit > kMaxBlocks)
+      throw std::runtime_error("max supported block limit is " +
+                               std::to_string(kMaxBlocks) + ". Got " +
+                               std::to_string(block_limit));
+
+    RankData* ptrs;
+    cudaStreamCaptureStatus status;
+    CUDACHECK(cudaStreamIsCapturing(stream, &status));
+    if (status == cudaStreamCaptureStatusActive) {
+      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
+      graph_unreg_buffers_.push_back(input);
+    } else {
+      auto it = buffers_.find(input);
+      if (it == buffers_.end())
         throw std::runtime_error(
-          "custom allreduce currently requires input length to be multiple "
-          "of " +
-          std::to_string(d));
-    if (block_limit > kMaxBlocks)
-      throw std::runtime_error("max supported block limit is " +
-                               std::to_string(kMaxBlocks) + ". Got " +
-                               std::to_string(block_limit));
-
-    RankData* ptrs;
-    cudaStreamCaptureStatus status;
-    CUDACHECK(cudaStreamIsCapturing(stream, &status));
-    if (status == cudaStreamCaptureStatusActive) {
-      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
-      graph_unreg_buffers_.push_back(input);
-    } else {
-      auto it = buffers_.find(input);
-      if (it == buffers_.end())
-        throw std::runtime_error(
-            "buffer address " +
-            std::to_string(reinterpret_cast<uint64_t>(input)) +
-            " is not registered!");
-      ptrs = it->second;
-    }
+            "buffer address " +
+            std::to_string(reinterpret_cast<uint64_t>(input)) +
+            " is not registered!");
+      ptrs = it->second;
+    }

-    size /= d;
-    auto bytes = size * sizeof(typename packed_t<T>::P);
-    int blocks = std::min(block_limit, (size + threads - 1) / threads);
+    size /= d;
+    auto bytes = size * sizeof(typename packed_t<T>::P);
+    int blocks = std::min(block_limit, (size + threads - 1) / threads);
 #define KL(ngpus, name) \
   name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                  rank_, size);
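For reference, KL is only shorthand for a templated kernel launch on the caller's stream. Assuming an instantiation with T = half and ngpus = 2 (not shown in this hunk), KL(2, cross_device_reduce_1stage) expands to roughly:

// Mechanical expansion of the KL macro for one case, illustration only.
cross_device_reduce_1stage<half, 2>
    <<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);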
@@ -504,27 +529,27 @@ class CustomAllreduce {
       break; \
   }

-    switch (world_size_) {
-      REDUCE_CASE(2)
-      REDUCE_CASE(4)
-      REDUCE_CASE(6)
-      REDUCE_CASE(8)
-      default:
-        throw std::runtime_error(
-            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
-            "gpus = " +
-            std::to_string(world_size_));
-    }
+    switch (world_size_) {
+      REDUCE_CASE(2)
+      REDUCE_CASE(4)
+      REDUCE_CASE(6)
+      REDUCE_CASE(8)
+      default:
+        throw std::runtime_error(
+            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
+            "gpus = " +
+            std::to_string(world_size_));
+    }
 #undef REDUCE_CASE
 #undef KL
-  }
+  }

-  ~CustomAllreduce() {
-    for (auto [_, ptr] : ipc_handles_) {
-      CUDACHECK(cudaIpcCloseMemHandle(ptr));
-    }
+  ~CustomAllreduce() {
+    for (auto [_, ptr] : ipc_handles_) {
+      CUDACHECK(cudaIpcCloseMemHandle(ptr));
     }
-};
+  }
+};  // namespace vllm
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation: