Skip to content

[OpenMP][OMPX] Add shfl_down_sync #93311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 24, 2024
Merged

[OpenMP][OMPX] Add shfl_down_sync #93311

merged 1 commit into from
May 24, 2024

Conversation

shiltian
Copy link
Contributor

No description provided.

@shiltian shiltian requested review from jdoerfert and jhuber6 May 24, 2024 15:12
@llvmbot llvmbot added openmp:libomp OpenMP host runtime offload labels May 24, 2024
@llvmbot
Copy link
Member

llvmbot commented May 24, 2024

@llvm/pr-subscribers-offload

Author: Shilei Tian (shiltian)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/93311.diff

5 Files Affected:

  • (modified) offload/DeviceRTL/include/Utils.h (+2)
  • (modified) offload/DeviceRTL/src/Mapping.cpp (+23-1)
  • (modified) offload/DeviceRTL/src/Utils.cpp (+10-5)
  • (added) offload/test/offloading/ompx_bare_shfl_down_sync.cpp (+68)
  • (modified) openmp/runtime/src/include/ompx.h.var (+52)
diff --git a/offload/DeviceRTL/include/Utils.h b/offload/DeviceRTL/include/Utils.h
index d43b7f5c95de1..82e2397b5958b 100644
--- a/offload/DeviceRTL/include/Utils.h
+++ b/offload/DeviceRTL/include/Utils.h
@@ -25,6 +25,8 @@ int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane);
 
 int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta, int32_t Width);
 
+int64_t shuffleDown(uint64_t Mask, int64_t Var, uint32_t Delta, int32_t Width);
+
 uint64_t ballotSync(uint64_t Mask, int32_t Pred);
 
 /// Return \p LowBits and \p HighBits packed into a single 64 bit value.
diff --git a/offload/DeviceRTL/src/Mapping.cpp b/offload/DeviceRTL/src/Mapping.cpp
index 4f39d2a299ee6..c1ce878746a69 100644
--- a/offload/DeviceRTL/src/Mapping.cpp
+++ b/offload/DeviceRTL/src/Mapping.cpp
@@ -364,8 +364,30 @@ _TGT_KERNEL_LANGUAGE(block_id, getBlockIdInKernel)
 _TGT_KERNEL_LANGUAGE(block_dim, getNumberOfThreadsInBlock)
 _TGT_KERNEL_LANGUAGE(grid_dim, getNumberOfBlocksInKernel)
 
-extern "C" uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
+extern "C" {
+uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
   return utils::ballotSync(mask, pred);
 }
 
+int ompx_shfl_down_sync_i(uint64_t mask, int var, unsigned delta, int width) {
+  return utils::shuffleDown(mask, var, delta, width);
+}
+
+float ompx_shfl_down_sync_f(uint64_t mask, float var, unsigned delta,
+                            int width) {
+  return utils::convertViaPun<float>(utils::shuffleDown(
+      mask, utils::convertViaPun<int32_t>(var), delta, width));
+}
+
+long ompx_shfl_down_sync_l(uint64_t mask, long var, unsigned delta, int width) {
+  return utils::shuffleDown(mask, var, delta, width);
+}
+
+double ompx_shfl_down_sync_d(uint64_t mask, double var, unsigned delta,
+                             int width) {
+  return utils::convertViaPun<double>(utils::shuffleDown(
+      mask, utils::convertViaPun<int64_t>(var), delta, width));
+}
+}
+
 #pragma omp end declare target
diff --git a/offload/DeviceRTL/src/Utils.cpp b/offload/DeviceRTL/src/Utils.cpp
index 606e3bec0d33c..4793e0b28df8c 100644
--- a/offload/DeviceRTL/src/Utils.cpp
+++ b/offload/DeviceRTL/src/Utils.cpp
@@ -113,6 +113,15 @@ int32_t utils::shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta,
   return impl::shuffleDown(Mask, Var, Delta, Width);
 }
 
+int64_t utils::shuffleDown(uint64_t Mask, int64_t Var, uint32_t Delta,
+                           int32_t Width) {
+  uint32_t lo, hi;
+  utils::unpack(Var, lo, hi);
+  hi = impl::shuffleDown(Mask, hi, Delta, Width);
+  lo = impl::shuffleDown(Mask, lo, Delta, Width);
+  return utils::pack(lo, hi);
+}
+
 uint64_t utils::ballotSync(uint64_t Mask, int32_t Pred) {
   return impl::ballotSync(Mask, Pred);
 }
@@ -125,11 +134,7 @@ int32_t __kmpc_shuffle_int32(int32_t Val, int16_t Delta, int16_t SrcLane) {
 }
 
 int64_t __kmpc_shuffle_int64(int64_t Val, int16_t Delta, int16_t Width) {
-  uint32_t lo, hi;
-  utils::unpack(Val, lo, hi);
-  hi = impl::shuffleDown(lanes::All, hi, Delta, Width);
-  lo = impl::shuffleDown(lanes::All, lo, Delta, Width);
-  return utils::pack(lo, hi);
+  return utils::shuffleDown(lanes::All, Val, Delta, Width);
 }
 }
 
diff --git a/offload/test/offloading/ompx_bare_shfl_down_sync.cpp b/offload/test/offloading/ompx_bare_shfl_down_sync.cpp
new file mode 100644
index 0000000000000..fecc214176e97
--- /dev/null
+++ b/offload/test/offloading/ompx_bare_shfl_down_sync.cpp
@@ -0,0 +1,68 @@
+// RUN: %libomptarget-compilexx-run-and-check-generic
+//
+// UNSUPPORTED: x86_64-pc-linux-gnu
+// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: aarch64-unknown-linux-gnu
+// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
+// UNSUPPORTED: s390x-ibm-linux-gnu
+// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
+
+#ifdef __AMDGCN_WAVEFRONT_SIZE
+#define WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
+#else
+#define WARP_SIZE 32
+#endif
+
+#include <cassert>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <limits>
+#include <ompx.h>
+#include <type_traits>
+
+template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
+bool equal(T LHS, T RHS) {
+  return LHS == RHS;
+}
+
+template <typename T,
+          std::enable_if_t<std::is_floating_point<T>::value, bool> = true>
+bool equal(T LHS, T RHS) {
+  return std::abs(LHS - RHS) < std::numeric_limits<T>::epsilon();
+}
+
+template <typename T> void test() {
+  constexpr const int num_blocks = 1;
+  constexpr const int block_size = 256;
+  constexpr const int N = num_blocks * block_size;
+  T *data = new T[N];
+
+  for (int i = 0; i < N; ++i)
+    data[i] = i;
+
+#pragma omp target teams ompx_bare num_teams(num_blocks) thread_limit(block_size) map(tofrom : data[0 : N])
+  {
+    int tid = ompx_thread_id_x();
+    data[tid] = ompx::shfl_down_sync(~0U, data[tid], 1);
+  }
+
+  for (int i = N - 1; i > 0; i -= WARP_SIZE) {
+    assert(equal(data[i], static_cast<T>(i)));
+    for (int j = i; j > i - WARP_SIZE; --j)
+      assert(equal(data[i], data[i - 1]));
+  }
+
+  delete[] data;
+}
+
+int main(int argc, char *argv[]) {
+  test<int32_t>();
+  test<int64_t>();
+  test<float>();
+  test<double>();
+  // CHECK: PASS
+  printf("PASS\n");
+
+  return 0;
+}
diff --git a/openmp/runtime/src/include/ompx.h.var b/openmp/runtime/src/include/ompx.h.var
index 19851880c3ac3..7f41d6ef92219 100644
--- a/openmp/runtime/src/include/ompx.h.var
+++ b/openmp/runtime/src/include/ompx.h.var
@@ -9,6 +9,12 @@
 #ifndef __OMPX_H
 #define __OMPX_H
 
+#ifdef __AMDGCN_WAVEFRONT_SIZE
+#define __WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
+#else
+#define __WARP_SIZE 32
+#endif
+
 typedef unsigned long uint64_t;
 
 #ifdef __cplusplus
@@ -87,6 +93,22 @@ static inline uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
   __builtin_trap();
 }
 
+/// ompx_shfl_down_sync_{i,f,l,d}
+///{
+#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(TYPE, TY)                \
+  static inline TYPE ompx_shfl_down_sync_##TY(uint64_t mask, TYPE var,         \
+                                              unsigned delta, int width) {     \
+    __builtin_trap();                                                          \
+  }
+
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(int, i)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(float, f)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(long, l)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(double, d)
+
+#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL
+///}
+
 #pragma omp end declare variant
 
 /// ompx_{sync_block}_{,divergent}
@@ -117,6 +139,20 @@ _TGT_KERNEL_LANGUAGE_DECL_GRID_C(grid_dim)
 
 uint64_t ompx_ballot_sync(uint64_t mask, int pred);
 
+/// ompx_shfl_down_sync_{i,f,l,d}
+///{
+#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY)                          \
+  TYPE ompx_shfl_down_sync_##TY(uint64_t mask, TYPE var, unsigned delta,       \
+                                int width = __WARP_SIZE);
+
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(int, i)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(float, f)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(long, l)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(double, d)
+
+#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC
+///}
+
 #ifdef __cplusplus
 }
 #endif
@@ -172,6 +208,22 @@ static inline uint64_t ballot_sync(uint64_t mask, int pred) {
   return ompx_ballot_sync(mask, pred);
 }
 
+/// shfl_down_sync
+///{
+#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY)                          \
+  static inline TYPE shfl_down_sync(uint64_t mask, TYPE var, unsigned delta,   \
+                                    int width = __WARP_SIZE) {                 \
+    return ompx_shfl_down_sync_##TY(mask, var, delta, width);                  \
+  }
+
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(int, i)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(float, f)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(long, l)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(double, d)
+
+#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC
+///}
+
 } // namespace ompx
 #endif
 

Copy link

github-actions bot commented May 24, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@shiltian shiltian force-pushed the shuffleDown branch 2 times, most recently from 41ca64c to 9c59d0c Compare May 24, 2024 15:19
@shiltian shiltian merged commit 4fb02de into llvm:main May 24, 2024
5 checks passed
@shiltian shiltian deleted the shuffleDown branch May 24, 2024 18:00
jhuber6 added a commit that referenced this pull request May 25, 2024
This reverts commit 098c6df.
This reverts commit 8c718a3.
This reverts commit 4fb02de.
shiltian added a commit that referenced this pull request May 26, 2024
shiltian added a commit that referenced this pull request May 26, 2024
shiltian added a commit to shiltian/llvm-project that referenced this pull request Jun 2, 2024
shiltian added a commit to shiltian/llvm-project that referenced this pull request Jun 2, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
offload openmp:libomp OpenMP host runtime
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants