-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[OpenMP][OMPX] Add shfl_down_sync #93311
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-offload Author: Shilei Tian (shiltian) ChangesFull diff: https://github.com/llvm/llvm-project/pull/93311.diff 5 Files Affected:
diff --git a/offload/DeviceRTL/include/Utils.h b/offload/DeviceRTL/include/Utils.h
index d43b7f5c95de1..82e2397b5958b 100644
--- a/offload/DeviceRTL/include/Utils.h
+++ b/offload/DeviceRTL/include/Utils.h
@@ -25,6 +25,8 @@ int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane);
int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta, int32_t Width);
+int64_t shuffleDown(uint64_t Mask, int64_t Var, uint32_t Delta, int32_t Width);
+
uint64_t ballotSync(uint64_t Mask, int32_t Pred);
/// Return \p LowBits and \p HighBits packed into a single 64 bit value.
diff --git a/offload/DeviceRTL/src/Mapping.cpp b/offload/DeviceRTL/src/Mapping.cpp
index 4f39d2a299ee6..c1ce878746a69 100644
--- a/offload/DeviceRTL/src/Mapping.cpp
+++ b/offload/DeviceRTL/src/Mapping.cpp
@@ -364,8 +364,30 @@ _TGT_KERNEL_LANGUAGE(block_id, getBlockIdInKernel)
_TGT_KERNEL_LANGUAGE(block_dim, getNumberOfThreadsInBlock)
_TGT_KERNEL_LANGUAGE(grid_dim, getNumberOfBlocksInKernel)
-extern "C" uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
+extern "C" {
+uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
return utils::ballotSync(mask, pred);
}
+int ompx_shfl_down_sync_i(uint64_t mask, int var, unsigned delta, int width) {
+ return utils::shuffleDown(mask, var, delta, width);
+}
+
+float ompx_shfl_down_sync_f(uint64_t mask, float var, unsigned delta,
+ int width) {
+ return utils::convertViaPun<float>(utils::shuffleDown(
+ mask, utils::convertViaPun<int32_t>(var), delta, width));
+}
+
+long ompx_shfl_down_sync_l(uint64_t mask, long var, unsigned delta, int width) {
+ return utils::shuffleDown(mask, var, delta, width);
+}
+
+double ompx_shfl_down_sync_d(uint64_t mask, double var, unsigned delta,
+ int width) {
+ return utils::convertViaPun<double>(utils::shuffleDown(
+ mask, utils::convertViaPun<int64_t>(var), delta, width));
+}
+}
+
#pragma omp end declare target
diff --git a/offload/DeviceRTL/src/Utils.cpp b/offload/DeviceRTL/src/Utils.cpp
index 606e3bec0d33c..4793e0b28df8c 100644
--- a/offload/DeviceRTL/src/Utils.cpp
+++ b/offload/DeviceRTL/src/Utils.cpp
@@ -113,6 +113,15 @@ int32_t utils::shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta,
return impl::shuffleDown(Mask, Var, Delta, Width);
}
+int64_t utils::shuffleDown(uint64_t Mask, int64_t Var, uint32_t Delta,
+ int32_t Width) {
+ uint32_t lo, hi;
+ utils::unpack(Var, lo, hi);
+ hi = impl::shuffleDown(Mask, hi, Delta, Width);
+ lo = impl::shuffleDown(Mask, lo, Delta, Width);
+ return utils::pack(lo, hi);
+}
+
uint64_t utils::ballotSync(uint64_t Mask, int32_t Pred) {
return impl::ballotSync(Mask, Pred);
}
@@ -125,11 +134,7 @@ int32_t __kmpc_shuffle_int32(int32_t Val, int16_t Delta, int16_t SrcLane) {
}
int64_t __kmpc_shuffle_int64(int64_t Val, int16_t Delta, int16_t Width) {
- uint32_t lo, hi;
- utils::unpack(Val, lo, hi);
- hi = impl::shuffleDown(lanes::All, hi, Delta, Width);
- lo = impl::shuffleDown(lanes::All, lo, Delta, Width);
- return utils::pack(lo, hi);
+ return utils::shuffleDown(lanes::All, Val, Delta, Width);
}
}
diff --git a/offload/test/offloading/ompx_bare_shfl_down_sync.cpp b/offload/test/offloading/ompx_bare_shfl_down_sync.cpp
new file mode 100644
index 0000000000000..fecc214176e97
--- /dev/null
+++ b/offload/test/offloading/ompx_bare_shfl_down_sync.cpp
@@ -0,0 +1,68 @@
+// RUN: %libomptarget-compilexx-run-and-check-generic
+//
+// UNSUPPORTED: x86_64-pc-linux-gnu
+// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: aarch64-unknown-linux-gnu
+// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
+// UNSUPPORTED: s390x-ibm-linux-gnu
+// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
+
+#ifdef __AMDGCN_WAVEFRONT_SIZE
+#define WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
+#else
+#define WARP_SIZE 32
+#endif
+
+#include <cassert>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <limits>
+#include <ompx.h>
+#include <type_traits>
+
+template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
+bool equal(T LHS, T RHS) {
+ return LHS == RHS;
+}
+
+template <typename T,
+ std::enable_if_t<std::is_floating_point<T>::value, bool> = true>
+bool equal(T LHS, T RHS) {
+ return std::abs(LHS - RHS) < std::numeric_limits<T>::epsilon();
+}
+
+template <typename T> void test() {
+ constexpr const int num_blocks = 1;
+ constexpr const int block_size = 256;
+ constexpr const int N = num_blocks * block_size;
+ T *data = new T[N];
+
+ for (int i = 0; i < N; ++i)
+ data[i] = i;
+
+#pragma omp target teams ompx_bare num_teams(num_blocks) thread_limit(block_size) map(tofrom : data[0 : N])
+ {
+ int tid = ompx_thread_id_x();
+ data[tid] = ompx::shfl_down_sync(~0U, data[tid], 1);
+ }
+
+ for (int i = N - 1; i > 0; i -= WARP_SIZE) {
+ assert(equal(data[i], static_cast<T>(i)));
+ for (int j = i; j > i - WARP_SIZE; --j)
+ assert(equal(data[i], data[i - 1]));
+ }
+
+ delete[] data;
+}
+
+int main(int argc, char *argv[]) {
+ test<int32_t>();
+ test<int64_t>();
+ test<float>();
+ test<double>();
+ // CHECK: PASS
+ printf("PASS\n");
+
+ return 0;
+}
diff --git a/openmp/runtime/src/include/ompx.h.var b/openmp/runtime/src/include/ompx.h.var
index 19851880c3ac3..7f41d6ef92219 100644
--- a/openmp/runtime/src/include/ompx.h.var
+++ b/openmp/runtime/src/include/ompx.h.var
@@ -9,6 +9,12 @@
#ifndef __OMPX_H
#define __OMPX_H
+#ifdef __AMDGCN_WAVEFRONT_SIZE
+#define __WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
+#else
+#define __WARP_SIZE 32
+#endif
+
typedef unsigned long uint64_t;
#ifdef __cplusplus
@@ -87,6 +93,22 @@ static inline uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
__builtin_trap();
}
+/// ompx_shfl_down_sync_{i,f,l,d}
+///{
+#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(TYPE, TY) \
+ static inline TYPE ompx_shfl_down_sync_##TY(uint64_t mask, TYPE var, \
+ unsigned delta, int width) { \
+ __builtin_trap(); \
+ }
+
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(int, i)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(float, f)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(long, l)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(double, d)
+
+#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL
+///}
+
#pragma omp end declare variant
/// ompx_{sync_block}_{,divergent}
@@ -117,6 +139,20 @@ _TGT_KERNEL_LANGUAGE_DECL_GRID_C(grid_dim)
uint64_t ompx_ballot_sync(uint64_t mask, int pred);
+/// ompx_shfl_down_sync_{i,f,l,d}
+///{
+#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY) \
+ TYPE ompx_shfl_down_sync_##TY(uint64_t mask, TYPE var, unsigned delta, \
+ int width = __WARP_SIZE);
+
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(int, i)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(float, f)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(long, l)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(double, d)
+
+#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC
+///}
+
#ifdef __cplusplus
}
#endif
@@ -172,6 +208,22 @@ static inline uint64_t ballot_sync(uint64_t mask, int pred) {
return ompx_ballot_sync(mask, pred);
}
+/// shfl_down_sync
+///{
+#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY) \
+ static inline TYPE shfl_down_sync(uint64_t mask, TYPE var, unsigned delta, \
+ int width = __WARP_SIZE) { \
+ return ompx_shfl_down_sync_##TY(mask, var, delta, width); \
+ }
+
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(int, i)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(float, f)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(long, l)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(double, d)
+
+#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC
+///}
+
} // namespace ompx
#endif
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
41ca64c
to
9c59d0c
Compare
jhuber6
approved these changes
May 24, 2024
jhuber6
added a commit
that referenced
this pull request
May 25, 2024
shiltian
added a commit
to shiltian/llvm-project
that referenced
this pull request
Jun 2, 2024
shiltian
added a commit
to shiltian/llvm-project
that referenced
this pull request
Jun 2, 2024
shiltian
added a commit
that referenced
this pull request
Jun 3, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
No description provided.