Skip to content

Commit 3274bf6

Browse files
authored
[OpenMP] Make each atomic helper take an atomic scope argument (#122786)
Summary: Right now we just default to device scope for each type, and mix an ad-hoc scope enum with the one used by the compiler's builtins. Unify these and make each version take the scope as an optional argument. For @ronlieb, this will remove the need for `add_system` in the fork as well as the extra `cas` with system scope; just pass `system` instead.
1 parent 2d9f406 commit 3274bf6

File tree

2 files changed

+68
-63
lines changed

2 files changed

+68
-63
lines changed

offload/DeviceRTL/include/Synchronization.h

Lines changed: 61 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -28,144 +28,145 @@ enum OrderingTy {
2828
seq_cst = __ATOMIC_SEQ_CST,
2929
};
3030

31-
enum ScopeTy {
31+
enum MemScopeTy {
3232
system = __MEMORY_SCOPE_SYSTEM,
33-
device_ = __MEMORY_SCOPE_DEVICE,
33+
device = __MEMORY_SCOPE_DEVICE,
3434
workgroup = __MEMORY_SCOPE_WRKGRP,
3535
wavefront = __MEMORY_SCOPE_WVFRNT,
3636
single = __MEMORY_SCOPE_SINGLE,
3737
};
3838

39-
enum MemScopeTy {
40-
all, // All threads on all devices
41-
device, // All threads on the device
42-
cgroup // All threads in the contention group, e.g. the team
43-
};
44-
4539
/// Atomically increment \p *Addr and wrap at \p V with \p Ordering semantics.
4640
uint32_t inc(uint32_t *Addr, uint32_t V, OrderingTy Ordering,
47-
MemScopeTy MemScope = MemScopeTy::all);
41+
MemScopeTy MemScope = MemScopeTy::device);
4842

4943
/// Atomically perform <op> on \p V and \p *Addr with \p Ordering semantics. The
5044
/// result is stored in \p *Addr;
5145
/// {
5246

5347
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
5448
bool cas(Ty *Address, V ExpectedV, V DesiredV, atomic::OrderingTy OrderingSucc,
55-
atomic::OrderingTy OrderingFail) {
49+
atomic::OrderingTy OrderingFail,
50+
MemScopeTy MemScope = MemScopeTy::device) {
5651
return __scoped_atomic_compare_exchange(Address, &ExpectedV, &DesiredV, false,
57-
OrderingSucc, OrderingFail,
58-
__MEMORY_SCOPE_DEVICE);
52+
OrderingSucc, OrderingFail, MemScope);
5953
}
6054

6155
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
62-
V add(Ty *Address, V Val, atomic::OrderingTy Ordering) {
63-
return __scoped_atomic_fetch_add(Address, Val, Ordering,
64-
__MEMORY_SCOPE_DEVICE);
56+
V add(Ty *Address, V Val, atomic::OrderingTy Ordering,
57+
MemScopeTy MemScope = MemScopeTy::device) {
58+
return __scoped_atomic_fetch_add(Address, Val, Ordering, MemScope);
6559
}
6660

6761
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
68-
V load(Ty *Address, atomic::OrderingTy Ordering) {
69-
return __scoped_atomic_load_n(Address, Ordering, __MEMORY_SCOPE_DEVICE);
62+
V load(Ty *Address, atomic::OrderingTy Ordering,
63+
MemScopeTy MemScope = MemScopeTy::device) {
64+
return __scoped_atomic_load_n(Address, Ordering, MemScope);
7065
}
7166

7267
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
73-
void store(Ty *Address, V Val, atomic::OrderingTy Ordering) {
74-
__scoped_atomic_store_n(Address, Val, Ordering, __MEMORY_SCOPE_DEVICE);
68+
void store(Ty *Address, V Val, atomic::OrderingTy Ordering,
69+
MemScopeTy MemScope = MemScopeTy::device) {
70+
__scoped_atomic_store_n(Address, Val, Ordering, MemScope);
7571
}
7672

7773
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
78-
V mul(Ty *Address, V Val, atomic::OrderingTy Ordering) {
74+
V mul(Ty *Address, V Val, atomic::OrderingTy Ordering,
75+
MemScopeTy MemScope = MemScopeTy::device) {
7976
Ty TypedCurrentVal, TypedResultVal, TypedNewVal;
8077
bool Success;
8178
do {
8279
TypedCurrentVal = atomic::load(Address, Ordering);
8380
TypedNewVal = TypedCurrentVal * Val;
8481
Success = atomic::cas(Address, TypedCurrentVal, TypedNewVal, Ordering,
85-
atomic::relaxed);
82+
atomic::relaxed, MemScope);
8683
} while (!Success);
8784
return TypedResultVal;
8885
}
8986

9087
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
9188
utils::enable_if_t<!utils::is_floating_point_v<V>, V>
92-
max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
93-
return __scoped_atomic_fetch_max(Address, Val, Ordering,
94-
__MEMORY_SCOPE_DEVICE);
89+
max(Ty *Address, V Val, atomic::OrderingTy Ordering,
90+
MemScopeTy MemScope = MemScopeTy::device) {
91+
return __scoped_atomic_fetch_max(Address, Val, Ordering, MemScope);
9592
}
9693

9794
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
9895
utils::enable_if_t<utils::is_same_v<V, float>, V>
99-
max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
96+
max(Ty *Address, V Val, atomic::OrderingTy Ordering,
97+
MemScopeTy MemScope = MemScopeTy::device) {
10098
if (Val >= 0)
101-
return utils::bitCast<float>(
102-
max((int32_t *)Address, utils::bitCast<int32_t>(Val), Ordering));
103-
return utils::bitCast<float>(
104-
min((uint32_t *)Address, utils::bitCast<uint32_t>(Val), Ordering));
99+
return utils::bitCast<float>(max(
100+
(int32_t *)Address, utils::bitCast<int32_t>(Val), Ordering, MemScope));
101+
return utils::bitCast<float>(min(
102+
(uint32_t *)Address, utils::bitCast<uint32_t>(Val), Ordering, MemScope));
105103
}
106104

107105
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
108106
utils::enable_if_t<utils::is_same_v<V, double>, V>
109-
max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
107+
max(Ty *Address, V Val, atomic::OrderingTy Ordering,
108+
MemScopeTy MemScope = MemScopeTy::device) {
110109
if (Val >= 0)
111-
return utils::bitCast<double>(
112-
max((int64_t *)Address, utils::bitCast<int64_t>(Val), Ordering));
113-
return utils::bitCast<double>(
114-
min((uint64_t *)Address, utils::bitCast<uint64_t>(Val), Ordering));
110+
return utils::bitCast<double>(max(
111+
(int64_t *)Address, utils::bitCast<int64_t>(Val), Ordering, MemScope));
112+
return utils::bitCast<double>(min(
113+
(uint64_t *)Address, utils::bitCast<uint64_t>(Val), Ordering, MemScope));
115114
}
116115

117116
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
118117
utils::enable_if_t<!utils::is_floating_point_v<V>, V>
119-
min(Ty *Address, V Val, atomic::OrderingTy Ordering) {
120-
return __scoped_atomic_fetch_min(Address, Val, Ordering,
121-
__MEMORY_SCOPE_DEVICE);
118+
min(Ty *Address, V Val, atomic::OrderingTy Ordering,
119+
MemScopeTy MemScope = MemScopeTy::device) {
120+
return __scoped_atomic_fetch_min(Address, Val, Ordering, MemScope);
122121
}
123122

124123
// TODO: Implement this with __atomic_fetch_max and remove the duplication.
125124
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
126125
utils::enable_if_t<utils::is_same_v<V, float>, V>
127-
min(Ty *Address, V Val, atomic::OrderingTy Ordering) {
126+
min(Ty *Address, V Val, atomic::OrderingTy Ordering,
127+
MemScopeTy MemScope = MemScopeTy::device) {
128128
if (Val >= 0)
129-
return utils::bitCast<float>(
130-
min((int32_t *)Address, utils::bitCast<int32_t>(Val), Ordering));
131-
return utils::bitCast<float>(
132-
max((uint32_t *)Address, utils::bitCast<uint32_t>(Val), Ordering));
129+
return utils::bitCast<float>(min(
130+
(int32_t *)Address, utils::bitCast<int32_t>(Val), Ordering, MemScope));
131+
return utils::bitCast<float>(max(
132+
(uint32_t *)Address, utils::bitCast<uint32_t>(Val), Ordering, MemScope));
133133
}
134134

135135
// TODO: Implement this with __atomic_fetch_max and remove the duplication.
136136
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
137137
utils::enable_if_t<utils::is_same_v<V, double>, V>
138-
min(Ty *Address, utils::remove_addrspace_t<Ty> Val,
139-
atomic::OrderingTy Ordering) {
138+
min(Ty *Address, utils::remove_addrspace_t<Ty> Val, atomic::OrderingTy Ordering,
139+
MemScopeTy MemScope = MemScopeTy::device) {
140140
if (Val >= 0)
141-
return utils::bitCast<double>(
142-
min((int64_t *)Address, utils::bitCast<int64_t>(Val), Ordering));
143-
return utils::bitCast<double>(
144-
max((uint64_t *)Address, utils::bitCast<uint64_t>(Val), Ordering));
141+
return utils::bitCast<double>(min(
142+
(int64_t *)Address, utils::bitCast<int64_t>(Val), Ordering, MemScope));
143+
return utils::bitCast<double>(max(
144+
(uint64_t *)Address, utils::bitCast<uint64_t>(Val), Ordering, MemScope));
145145
}
146146

147147
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
148-
V bit_or(Ty *Address, V Val, atomic::OrderingTy Ordering) {
149-
return __scoped_atomic_fetch_or(Address, Val, Ordering,
150-
__MEMORY_SCOPE_DEVICE);
148+
V bit_or(Ty *Address, V Val, atomic::OrderingTy Ordering,
149+
MemScopeTy MemScope = MemScopeTy::device) {
150+
return __scoped_atomic_fetch_or(Address, Val, Ordering, MemScope);
151151
}
152152

153153
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
154-
V bit_and(Ty *Address, V Val, atomic::OrderingTy Ordering) {
155-
return __scoped_atomic_fetch_and(Address, Val, Ordering,
156-
__MEMORY_SCOPE_DEVICE);
154+
V bit_and(Ty *Address, V Val, atomic::OrderingTy Ordering,
155+
MemScopeTy MemScope = MemScopeTy::device) {
156+
return __scoped_atomic_fetch_and(Address, Val, Ordering, MemScope);
157157
}
158158

159159
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
160-
V bit_xor(Ty *Address, V Val, atomic::OrderingTy Ordering) {
161-
return __scoped_atomic_fetch_xor(Address, Val, Ordering,
162-
__MEMORY_SCOPE_DEVICE);
160+
V bit_xor(Ty *Address, V Val, atomic::OrderingTy Ordering,
161+
MemScopeTy MemScope = MemScopeTy::device) {
162+
return __scoped_atomic_fetch_xor(Address, Val, Ordering, MemScope);
163163
}
164164

165-
static inline uint32_t atomicExchange(uint32_t *Address, uint32_t Val,
166-
atomic::OrderingTy Ordering) {
165+
static inline uint32_t
166+
atomicExchange(uint32_t *Address, uint32_t Val, atomic::OrderingTy Ordering,
167+
MemScopeTy MemScope = MemScopeTy::device) {
167168
uint32_t R;
168-
__scoped_atomic_exchange(Address, &Val, &R, Ordering, __MEMORY_SCOPE_DEVICE);
169+
__scoped_atomic_exchange(Address, &Val, &R, Ordering, MemScope);
169170
return R;
170171
}
171172

offload/DeviceRTL/src/Synchronization.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,16 @@ uint32_t atomicInc(uint32_t *A, uint32_t V, atomic::OrderingTy Ordering,
6464

6565
#define ScopeSwitch(ORDER) \
6666
switch (MemScope) { \
67-
case atomic::MemScopeTy::all: \
67+
case atomic::MemScopeTy::system: \
6868
return __builtin_amdgcn_atomic_inc32(A, V, ORDER, ""); \
6969
case atomic::MemScopeTy::device: \
7070
return __builtin_amdgcn_atomic_inc32(A, V, ORDER, "agent"); \
71-
case atomic::MemScopeTy::cgroup: \
71+
case atomic::MemScopeTy::workgroup: \
7272
return __builtin_amdgcn_atomic_inc32(A, V, ORDER, "workgroup"); \
73+
case atomic::MemScopeTy::wavefront: \
74+
return __builtin_amdgcn_atomic_inc32(A, V, ORDER, "wavefront"); \
75+
case atomic::MemScopeTy::single: \
76+
return __builtin_amdgcn_atomic_inc32(A, V, ORDER, "singlethread"); \
7377
}
7478

7579
#define Case(ORDER) \
@@ -148,7 +152,7 @@ void fenceTeam(atomic::OrderingTy Ordering) {
148152
}
149153

150154
void fenceKernel(atomic::OrderingTy Ordering) {
151-
return __scoped_atomic_thread_fence(Ordering, atomic::device_);
155+
return __scoped_atomic_thread_fence(Ordering, atomic::device);
152156
}
153157

154158
void fenceSystem(atomic::OrderingTy Ordering) {

0 commit comments

Comments
 (0)