
[OpenMP] Make each atomic helper take an atomic scope argument #122786


Merged: 2 commits into llvm:main, Jan 21, 2025

Conversation

@jhuber6 (Contributor) commented Jan 13, 2025

Summary:
Right now we just default to device scope for each type and mix an ad-hoc
scope enum with the one used by the compiler's builtins. Unifying these lets
each version take the scope as an optional argument.

For @ronlieb, this will remove the need for `add_system` in the fork, as
well as the extra `cas` with system scope; just pass `system`.
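
As a hedged illustration of the new API (not part of the PR itself; the enclosing `ompx` namespace and the device-side context are assumptions based on the DeviceRTL sources), a caller that previously needed a dedicated system-scope helper can now just pass the scope:

```cpp
#include "Synchronization.h" // offload/DeviceRTL/include, as modified here

#include <cstdint>

// Default scope is device, matching the previously hard-coded
// __MEMORY_SCOPE_DEVICE behavior of these helpers.
uint32_t addDeviceScope(uint32_t *Counter) {
  return ompx::atomic::add(Counter, uint32_t(1), ompx::atomic::seq_cst);
}

// System scope, coherent with the host; this is what replaces the
// downstream add_system helper mentioned above.
uint32_t addSystemScope(uint32_t *Counter) {
  return ompx::atomic::add(Counter, uint32_t(1), ompx::atomic::seq_cst,
                           ompx::atomic::MemScopeTy::system);
}
```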

@llvmbot (Member) commented Jan 13, 2025

@llvm/pr-subscribers-offload

Author: Joseph Huber (jhuber6)



Full diff: https://github.com/llvm/llvm-project/pull/122786.diff

2 Files Affected:

  • (modified) offload/DeviceRTL/include/Synchronization.h (+61-60)
  • (modified) offload/DeviceRTL/src/Synchronization.cpp (+7-3)
diff --git a/offload/DeviceRTL/include/Synchronization.h b/offload/DeviceRTL/include/Synchronization.h
index e1968675550d49..8b478a8a8723f2 100644
--- a/offload/DeviceRTL/include/Synchronization.h
+++ b/offload/DeviceRTL/include/Synchronization.h
@@ -28,23 +28,17 @@ enum OrderingTy {
   seq_cst = __ATOMIC_SEQ_CST,
 };
 
-enum ScopeTy {
+enum MemScopeTy {
   system = __MEMORY_SCOPE_SYSTEM,
-  device_ = __MEMORY_SCOPE_DEVICE,
+  device = __MEMORY_SCOPE_DEVICE,
   workgroup = __MEMORY_SCOPE_WRKGRP,
   wavefront = __MEMORY_SCOPE_WVFRNT,
   single = __MEMORY_SCOPE_SINGLE,
 };
 
-enum MemScopeTy {
-  all,    // All threads on all devices
-  device, // All threads on the device
-  cgroup  // All threads in the contention group, e.g. the team
-};
-
 /// Atomically increment \p *Addr and wrap at \p V with \p Ordering semantics.
 uint32_t inc(uint32_t *Addr, uint32_t V, OrderingTy Ordering,
-             MemScopeTy MemScope = MemScopeTy::all);
+             MemScopeTy MemScope = MemScopeTy::device);
 
 /// Atomically perform <op> on \p V and \p *Addr with \p Ordering semantics. The
 /// result is stored in \p *Addr;
@@ -52,120 +46,127 @@ uint32_t inc(uint32_t *Addr, uint32_t V, OrderingTy Ordering,
 
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
 bool cas(Ty *Address, V ExpectedV, V DesiredV, atomic::OrderingTy OrderingSucc,
-         atomic::OrderingTy OrderingFail) {
+         atomic::OrderingTy OrderingFail,
+         MemScopeTy MemScope = MemScopeTy::device) {
   return __scoped_atomic_compare_exchange(Address, &ExpectedV, &DesiredV, false,
-                                          OrderingSucc, OrderingFail,
-                                          __MEMORY_SCOPE_DEVICE);
+                                          OrderingSucc, OrderingFail, MemScope);
 }
 
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
-V add(Ty *Address, V Val, atomic::OrderingTy Ordering) {
-  return __scoped_atomic_fetch_add(Address, Val, Ordering,
-                                   __MEMORY_SCOPE_DEVICE);
+V add(Ty *Address, V Val, atomic::OrderingTy Ordering,
+      MemScopeTy MemScope = MemScopeTy::device) {
+  return __scoped_atomic_fetch_add(Address, Val, Ordering, MemScope);
 }
 
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
-V load(Ty *Address, atomic::OrderingTy Ordering) {
-  return add(Address, Ty(0), Ordering);
+V load(Ty *Address, atomic::OrderingTy Ordering,
+       MemScopeTy MemScope = MemScopeTy::device) {
+  return __scoped_atomic_load_n(Address, Ordering, MemScope);
 }
 
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
-void store(Ty *Address, V Val, atomic::OrderingTy Ordering) {
-  __scoped_atomic_store_n(Address, Val, Ordering, __MEMORY_SCOPE_DEVICE);
+void store(Ty *Address, V Val, atomic::OrderingTy Ordering,
+           MemScopeTy MemScope = MemScopeTy::device) {
+  __scoped_atomic_store_n(Address, Val, Ordering, MemScope);
 }
 
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
-V mul(Ty *Address, V Val, atomic::OrderingTy Ordering) {
+V mul(Ty *Address, V Val, atomic::OrderingTy Ordering,
+      MemScopeTy MemScope = MemScopeTy::device) {
   Ty TypedCurrentVal, TypedResultVal, TypedNewVal;
   bool Success;
   do {
     TypedCurrentVal = atomic::load(Address, Ordering);
     TypedNewVal = TypedCurrentVal * Val;
     Success = atomic::cas(Address, TypedCurrentVal, TypedNewVal, Ordering,
-                          atomic::relaxed);
+                          atomic::relaxed, MemScope);
   } while (!Success);
   return TypedResultVal;
 }
 
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
 utils::enable_if_t<!utils::is_floating_point_v<V>, V>
-max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
-  return __scoped_atomic_fetch_max(Address, Val, Ordering,
-                                   __MEMORY_SCOPE_DEVICE);
+max(Ty *Address, V Val, atomic::OrderingTy Ordering,
+    MemScopeTy MemScope = MemScopeTy::device) {
+  return __scoped_atomic_fetch_max(Address, Val, Ordering, MemScope);
 }
 
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
 utils::enable_if_t<utils::is_same_v<V, float>, V>
-max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
+max(Ty *Address, V Val, atomic::OrderingTy Ordering,
+    MemScopeTy MemScope = MemScopeTy::device) {
   if (Val >= 0)
-    return utils::bitCast<float>(
-        max((int32_t *)Address, utils::bitCast<int32_t>(Val), Ordering));
-  return utils::bitCast<float>(
-      min((uint32_t *)Address, utils::bitCast<uint32_t>(Val), Ordering));
+    return utils::bitCast<float>(max(
+        (int32_t *)Address, utils::bitCast<int32_t>(Val), Ordering, MemScope));
+  return utils::bitCast<float>(min(
+      (uint32_t *)Address, utils::bitCast<uint32_t>(Val), Ordering, MemScope));
 }
 
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
 utils::enable_if_t<utils::is_same_v<V, double>, V>
-max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
+max(Ty *Address, V Val, atomic::OrderingTy Ordering,
+    MemScopeTy MemScope = MemScopeTy::device) {
   if (Val >= 0)
-    return utils::bitCast<double>(
-        max((int64_t *)Address, utils::bitCast<int64_t>(Val), Ordering));
-  return utils::bitCast<double>(
-      min((uint64_t *)Address, utils::bitCast<uint64_t>(Val), Ordering));
+    return utils::bitCast<double>(max(
+        (int64_t *)Address, utils::bitCast<int64_t>(Val), Ordering, MemScope));
+  return utils::bitCast<double>(min(
+      (uint64_t *)Address, utils::bitCast<uint64_t>(Val), Ordering, MemScope));
 }
 
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
 utils::enable_if_t<!utils::is_floating_point_v<V>, V>
-min(Ty *Address, V Val, atomic::OrderingTy Ordering) {
-  return __scoped_atomic_fetch_min(Address, Val, Ordering,
-                                   __MEMORY_SCOPE_DEVICE);
+min(Ty *Address, V Val, atomic::OrderingTy Ordering,
+    MemScopeTy MemScope = MemScopeTy::device) {
+  return __scoped_atomic_fetch_min(Address, Val, Ordering, MemScope);
 }
 
 // TODO: Implement this with __atomic_fetch_max and remove the duplication.
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
 utils::enable_if_t<utils::is_same_v<V, float>, V>
-min(Ty *Address, V Val, atomic::OrderingTy Ordering) {
+min(Ty *Address, V Val, atomic::OrderingTy Ordering,
+    MemScopeTy MemScope = MemScopeTy::device) {
   if (Val >= 0)
-    return utils::bitCast<float>(
-        min((int32_t *)Address, utils::bitCast<int32_t>(Val), Ordering));
-  return utils::bitCast<float>(
-      max((uint32_t *)Address, utils::bitCast<uint32_t>(Val), Ordering));
+    return utils::bitCast<float>(min(
+        (int32_t *)Address, utils::bitCast<int32_t>(Val), Ordering, MemScope));
+  return utils::bitCast<float>(max(
+      (uint32_t *)Address, utils::bitCast<uint32_t>(Val), Ordering, MemScope));
 }
 
 // TODO: Implement this with __atomic_fetch_max and remove the duplication.
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
 utils::enable_if_t<utils::is_same_v<V, double>, V>
-min(Ty *Address, utils::remove_addrspace_t<Ty> Val,
-    atomic::OrderingTy Ordering) {
+min(Ty *Address, utils::remove_addrspace_t<Ty> Val, atomic::OrderingTy Ordering,
+    MemScopeTy MemScope = MemScopeTy::device) {
   if (Val >= 0)
-    return utils::bitCast<double>(
-        min((int64_t *)Address, utils::bitCast<int64_t>(Val), Ordering));
-  return utils::bitCast<double>(
-      max((uint64_t *)Address, utils::bitCast<uint64_t>(Val), Ordering));
+    return utils::bitCast<double>(min(
+        (int64_t *)Address, utils::bitCast<int64_t>(Val), Ordering, MemScope));
+  return utils::bitCast<double>(max(
+      (uint64_t *)Address, utils::bitCast<uint64_t>(Val), Ordering, MemScope));
 }
 
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
-V bit_or(Ty *Address, V Val, atomic::OrderingTy Ordering) {
-  return __scoped_atomic_fetch_or(Address, Val, Ordering,
-                                  __MEMORY_SCOPE_DEVICE);
+V bit_or(Ty *Address, V Val, atomic::OrderingTy Ordering,
+         MemScopeTy MemScope = MemScopeTy::device) {
+  return __scoped_atomic_fetch_or(Address, Val, Ordering, MemScope);
 }
 
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
-V bit_and(Ty *Address, V Val, atomic::OrderingTy Ordering) {
-  return __scoped_atomic_fetch_and(Address, Val, Ordering,
-                                   __MEMORY_SCOPE_DEVICE);
+V bit_and(Ty *Address, V Val, atomic::OrderingTy Ordering,
+          MemScopeTy MemScope = MemScopeTy::device) {
+  return __scoped_atomic_fetch_and(Address, Val, Ordering, MemScope);
 }
 
 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
-V bit_xor(Ty *Address, V Val, atomic::OrderingTy Ordering) {
-  return __scoped_atomic_fetch_xor(Address, Val, Ordering,
-                                   __MEMORY_SCOPE_DEVICE);
+V bit_xor(Ty *Address, V Val, atomic::OrderingTy Ordering,
+          MemScopeTy MemScope = MemScopeTy::device) {
+  return __scoped_atomic_fetch_xor(Address, Val, Ordering, MemScope);
 }
 
-static inline uint32_t atomicExchange(uint32_t *Address, uint32_t Val,
-                                      atomic::OrderingTy Ordering) {
+static inline uint32_t
+atomicExchange(uint32_t *Address, uint32_t Val, atomic::OrderingTy Ordering,
+               MemScopeTy MemScope = MemScopeTy::device) {
   uint32_t R;
-  __scoped_atomic_exchange(Address, &Val, &R, Ordering, __MEMORY_SCOPE_DEVICE);
+  __scoped_atomic_exchange(Address, &Val, &R, Ordering, MemScope);
   return R;
 }
 
diff --git a/offload/DeviceRTL/src/Synchronization.cpp b/offload/DeviceRTL/src/Synchronization.cpp
index 72a97ae3fcfb42..e0e277928fa910 100644
--- a/offload/DeviceRTL/src/Synchronization.cpp
+++ b/offload/DeviceRTL/src/Synchronization.cpp
@@ -64,12 +64,16 @@ uint32_t atomicInc(uint32_t *A, uint32_t V, atomic::OrderingTy Ordering,
 
 #define ScopeSwitch(ORDER)                                                     \
   switch (MemScope) {                                                          \
-  case atomic::MemScopeTy::all:                                                \
+  case atomic::MemScopeTy::system:                                             \
     return __builtin_amdgcn_atomic_inc32(A, V, ORDER, "");                     \
   case atomic::MemScopeTy::device:                                             \
     return __builtin_amdgcn_atomic_inc32(A, V, ORDER, "agent");                \
-  case atomic::MemScopeTy::cgroup:                                             \
+  case atomic::MemScopeTy::workgroup:                                          \
     return __builtin_amdgcn_atomic_inc32(A, V, ORDER, "workgroup");            \
+  case atomic::MemScopeTy::wavefront:                                          \
+    return __builtin_amdgcn_atomic_inc32(A, V, ORDER, "wavefront");            \
+  case atomic::MemScopeTy::single:                                             \
+    return __builtin_amdgcn_atomic_inc32(A, V, ORDER, "singlethread");         \
   }
 
 #define Case(ORDER)                                                            \
@@ -148,7 +152,7 @@ void fenceTeam(atomic::OrderingTy Ordering) {
 }
 
 void fenceKernel(atomic::OrderingTy Ordering) {
-  return __scoped_atomic_thread_fence(Ordering, atomic::device_);
+  return __scoped_atomic_thread_fence(Ordering, atomic::device);
 }
 
 void fenceSystem(atomic::OrderingTy Ordering) {
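
The `mul` helper above has no fetch-multiply builtin to lean on, so it spins on a compare-and-swap until the update lands. A minimal host-side sketch of the same retry pattern, written with std::atomic purely for illustration (the DeviceRTL version instead calls __scoped_atomic_compare_exchange and forwards the new MemScope argument):

```cpp
#include <atomic>
#include <cstdio>

// CAS retry loop: recompute the desired value from the last observed one
// until the exchange succeeds. compare_exchange_weak refreshes Current on
// failure, so each retry multiplies against a fresh snapshot.
int atomicMul(std::atomic<int> &Addr, int Val) {
  int Current = Addr.load(std::memory_order_relaxed);
  int Desired;
  do {
    Desired = Current * Val;
  } while (!Addr.compare_exchange_weak(Current, Desired,
                                       std::memory_order_seq_cst,
                                       std::memory_order_relaxed));
  return Desired;
}

int main() {
  std::atomic<int> X{3};
  std::printf("%d\n", atomicMul(X, 7)); // prints 21
}
```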

Review thread on offload/DeviceRTL/include/Synchronization.h, quoting the changed default:

 /// Atomically increment \p *Addr and wrap at \p V with \p Ordering semantics.
 uint32_t inc(uint32_t *Addr, uint32_t V, OrderingTy Ordering,
-             MemScopeTy MemScope = MemScopeTy::all);
+             MemScopeTy MemScope = MemScopeTy::device);
Contributor: What's the difference here?

Contributor Author (jhuber6): This changes the default scope to device instead of system, which affects reductions. However, I believe it's fine since we don't need that part to be coherent with the CPU's memory.
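
Concretely, per the ScopeSwitch hunk in the diff above, the default lowering of atomicInc on AMDGPU moves from the system sync scope to the agent (device) scope:

```cpp
// Old default (MemScopeTy::all): system scope, the empty sync-scope string.
__builtin_amdgcn_atomic_inc32(A, V, ORDER, "");
// New default (MemScopeTy::device): agent scope, device-wide visibility only.
__builtin_amdgcn_atomic_inc32(A, V, ORDER, "agent");
```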

Contributor: I ran this through some buildbot tests and it seems fine. I don't know about downstream impacts, though.

Contributor Author (jhuber6): I think @ronlieb has already applied a downstream patch that includes it.

@jplehr requested a review from dhruvachak, January 15, 2025 14:08
@jhuber6 requested a review from Meinersbur, January 20, 2025 20:06
@jhuber6 (Contributor Author) commented Jan 20, 2025

ping

@shiltian (Contributor) left a comment: The change looks fine.

@jhuber6 merged commit 3274bf6 into llvm:main Jan 21, 2025
6 checks passed