Skip to content

Commit 4fb02de

Browse files
authored
[OpenMP][OMPX] Add shfl_down_sync (#93311)
1 parent d07362f commit 4fb02de

File tree

5 files changed

+154
-6
lines changed

5 files changed

+154
-6
lines changed

offload/DeviceRTL/include/Utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane);
2525

2626
int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta, int32_t Width);
2727

28+
int64_t shuffleDown(uint64_t Mask, int64_t Var, uint32_t Delta, int32_t Width);
29+
2830
uint64_t ballotSync(uint64_t Mask, int32_t Pred);
2931

3032
/// Return \p LowBits and \p HighBits packed into a single 64 bit value.

offload/DeviceRTL/src/Mapping.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,8 +364,30 @@ _TGT_KERNEL_LANGUAGE(block_id, getBlockIdInKernel)
364364
_TGT_KERNEL_LANGUAGE(block_dim, getNumberOfThreadsInBlock)
365365
_TGT_KERNEL_LANGUAGE(grid_dim, getNumberOfBlocksInKernel)
366366

367-
extern "C" uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
367+
extern "C" {
368+
uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
368369
return utils::ballotSync(mask, pred);
369370
}
370371

372+
int ompx_shfl_down_sync_i(uint64_t mask, int var, unsigned delta, int width) {
373+
return utils::shuffleDown(mask, var, delta, width);
374+
}
375+
376+
float ompx_shfl_down_sync_f(uint64_t mask, float var, unsigned delta,
377+
int width) {
378+
return utils::convertViaPun<float>(utils::shuffleDown(
379+
mask, utils::convertViaPun<int32_t>(var), delta, width));
380+
}
381+
382+
long ompx_shfl_down_sync_l(uint64_t mask, long var, unsigned delta, int width) {
383+
return utils::shuffleDown(mask, var, delta, width);
384+
}
385+
386+
double ompx_shfl_down_sync_d(uint64_t mask, double var, unsigned delta,
387+
int width) {
388+
return utils::convertViaPun<double>(utils::shuffleDown(
389+
mask, utils::convertViaPun<int64_t>(var), delta, width));
390+
}
391+
}
392+
371393
#pragma omp end declare target

offload/DeviceRTL/src/Utils.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,15 @@ int32_t utils::shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta,
113113
return impl::shuffleDown(Mask, Var, Delta, Width);
114114
}
115115

116+
/// 64-bit shuffle-down built from two 32-bit shuffles.
///
/// \p Var is split with utils::unpack into two 32-bit halves, each half is
/// passed through impl::shuffleDown with the same \p Mask, \p Delta and
/// \p Width (high half first, then low half), and the two results are
/// recombined with utils::pack.
int64_t utils::shuffleDown(uint64_t Mask, int64_t Var, uint32_t Delta,
                           int32_t Width) {
  uint32_t LowHalf, HighHalf;
  utils::unpack(Var, LowHalf, HighHalf);

  // Keep the original issue order: high half first, then low half.
  const uint32_t ShuffledHigh =
      impl::shuffleDown(Mask, HighHalf, Delta, Width);
  const uint32_t ShuffledLow = impl::shuffleDown(Mask, LowHalf, Delta, Width);

  return utils::pack(ShuffledLow, ShuffledHigh);
}
124+
116125
uint64_t utils::ballotSync(uint64_t Mask, int32_t Pred) {
117126
return impl::ballotSync(Mask, Pred);
118127
}
@@ -125,11 +134,7 @@ int32_t __kmpc_shuffle_int32(int32_t Val, int16_t Delta, int16_t SrcLane) {
125134
}
126135

127136
int64_t __kmpc_shuffle_int64(int64_t Val, int16_t Delta, int16_t Width) {
128-
uint32_t lo, hi;
129-
utils::unpack(Val, lo, hi);
130-
hi = impl::shuffleDown(lanes::All, hi, Delta, Width);
131-
lo = impl::shuffleDown(lanes::All, lo, Delta, Width);
132-
return utils::pack(lo, hi);
137+
return utils::shuffleDown(lanes::All, Val, Delta, Width);
133138
}
134139
}
135140

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: %libomptarget-compilexx-run-and-check-generic
2+
//
3+
// UNSUPPORTED: x86_64-pc-linux-gnu
4+
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
5+
// UNSUPPORTED: aarch64-unknown-linux-gnu
6+
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
7+
// UNSUPPORTED: s390x-ibm-linux-gnu
8+
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
9+
10+
#ifdef __AMDGCN_WAVEFRONT_SIZE
11+
#define WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
12+
#else
13+
#define WARP_SIZE 32
14+
#endif
15+
16+
#include <cassert>
17+
#include <cmath>
18+
#include <cstdint>
19+
#include <cstdio>
20+
#include <limits>
21+
#include <ompx.h>
22+
#include <type_traits>
23+
24+
/// Equality check for integral types: true exactly when \p LHS == \p RHS.
template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
bool equal(T LHS, T RHS) {
  const bool Same = (LHS == RHS);
  return Same;
}
28+
29+
/// Equality check for floating-point types: true when the absolute difference
/// between \p LHS and \p RHS is strictly below std::numeric_limits<T>::epsilon().
template <typename T,
          std::enable_if_t<std::is_floating_point<T>::value, bool> = true>
bool equal(T LHS, T RHS) {
  const T Gap = std::abs(LHS - RHS);
  return Gap < std::numeric_limits<T>::epsilon();
}
34+
35+
template <typename T> void test() {
36+
constexpr const int num_blocks = 1;
37+
constexpr const int block_size = 256;
38+
constexpr const int N = num_blocks * block_size;
39+
T *data = new T[N];
40+
41+
for (int i = 0; i < N; ++i)
42+
data[i] = i;
43+
44+
#pragma omp target teams ompx_bare num_teams(num_blocks) \
45+
thread_limit(block_size) map(tofrom : data[0 : N])
46+
{
47+
int tid = ompx_thread_id_x();
48+
data[tid] = ompx::shfl_down_sync(~0U, data[tid], 1);
49+
}
50+
51+
for (int i = N - 1; i > 0; i -= WARP_SIZE)
52+
for (int j = i; j > i - WARP_SIZE; --j)
53+
assert(equal(data[i], data[i - 1]));
54+
55+
delete[] data;
56+
}
57+
58+
/// Test driver: runs the shuffle-down test once per supported element type
/// and prints the marker line that the CHECK directive below looks for.
int main(int argc, char *argv[]) {
  // The command line is not consulted.
  (void)argc;
  (void)argv;

  test<int32_t>();
  test<int64_t>();
  test<float>();
  test<double>();

  // CHECK: PASS
  printf("PASS\n");

  return 0;
}

openmp/runtime/src/include/ompx.h.var

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
#ifndef __OMPX_H
1010
#define __OMPX_H
1111

12+
#ifdef __AMDGCN_WAVEFRONT_SIZE
13+
#define __WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
14+
#else
15+
#define __WARP_SIZE 32
16+
#endif
17+
1218
typedef unsigned long uint64_t;
1319

1420
#ifdef __cplusplus
@@ -87,6 +93,22 @@ static inline uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
8793
__builtin_trap();
8894
}
8995

96+
/// ompx_shfl_down_sync_{i,f,l,d}
///
/// Host implementations, one per value type (int, float, long, double). Each
/// one accepts the full argument list and unconditionally traps via
/// __builtin_trap().
///{
#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(VALUE_TY, SUFFIX)        \
  static inline VALUE_TY ompx_shfl_down_sync_##SUFFIX(                         \
      uint64_t mask, VALUE_TY var, unsigned delta, int width) {                \
    __builtin_trap();                                                          \
  }

_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(int, i)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(float, f)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(long, l)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(double, d)

#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL
///}
111+
90112
#pragma omp end declare variant
91113

92114
/// ompx_{sync_block}_{,divergent}
@@ -117,6 +139,20 @@ _TGT_KERNEL_LANGUAGE_DECL_GRID_C(grid_dim)
117139

118140
uint64_t ompx_ballot_sync(uint64_t mask, int pred);
119141

142+
/// ompx_shfl_down_sync_{i,f,l,d}
///
/// Declarations of the shuffle-down entry points for int, float, long and
/// double.
///
/// These declarations sit in the part of the header that C also compiles, and
/// C has no default arguments. `width` therefore defaults to __WARP_SIZE only
/// when the header is compiled as C++; C callers must pass `width` explicitly.
///{
#ifdef __cplusplus
#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_WIDTH_DEFAULT = __WARP_SIZE
#else
#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_WIDTH_DEFAULT
#endif

#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY)                          \
  TYPE ompx_shfl_down_sync_##TY(                                               \
      uint64_t mask, TYPE var, unsigned delta,                                 \
      int width _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_WIDTH_DEFAULT);

_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(int, i)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(float, f)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(long, l)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(double, d)

#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC
#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_WIDTH_DEFAULT
///}
155+
120156
#ifdef __cplusplus
121157
}
122158
#endif
@@ -172,6 +208,22 @@ static inline uint64_t ballot_sync(uint64_t mask, int pred) {
172208
return ompx_ballot_sync(mask, pred);
173209
}
174210

211+
/// shfl_down_sync
212+
///{
213+
#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY) \
214+
static inline TYPE shfl_down_sync(uint64_t mask, TYPE var, unsigned delta, \
215+
int width = __WARP_SIZE) { \
216+
return ompx_shfl_down_sync_##TY(mask, var, delta, width); \
217+
}
218+
219+
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(int, i)
220+
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(float, f)
221+
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(long, l)
222+
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(double, d)
223+
224+
#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC
225+
///}
226+
175227
} // namespace ompx
176228
#endif
177229

0 commit comments

Comments
 (0)