Skip to content

Commit 2d9f406

Browse files
authored
[OpenMP] Adjust 'printf' handling in the OpenMP runtime (#123670)
Summary: We used to avoid a lot of this stuff because we didn't properly handle variadics in device code. That's been solved for now, so we can just make an internal printf handler that forwards to the external `vprintf` function, which is provided either by NVIDIA's SDK or by the GPU libc implementation. The main reason for doing this is that it prevents the stupid AMDGPU printf pass from mangling our beautiful printfs!
1 parent 585858a commit 2d9f406

File tree

6 files changed

+32
-44
lines changed

6 files changed

+32
-44
lines changed

offload/DeviceRTL/include/Debug.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,10 @@ void __assert_fail_internal(const char *expr, const char *msg, const char *file,
3535
__assert_assume(expr); \
3636
}
3737
#define UNREACHABLE(msg) \
38-
PRINT(msg); \
38+
printf(msg); \
3939
__builtin_trap(); \
4040
__builtin_unreachable();
4141

4242
///}
4343

44-
#define PRINTF(fmt, ...) (void)printf(fmt, ##__VA_ARGS__);
45-
#define PRINT(str) PRINTF("%s", str)
46-
47-
///}
48-
4944
#endif

offload/DeviceRTL/include/LibC.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,10 @@
1414

1515
#include "DeviceTypes.h"
1616

17-
extern "C" {
17+
namespace ompx {
1818

19-
int memcmp(const void *lhs, const void *rhs, size_t count);
20-
void memset(void *dst, int C, size_t count);
21-
int printf(const char *format, ...);
22-
}
19+
int printf(const char *Format, ...);
20+
21+
} // namespace ompx
2322

2423
#endif

offload/DeviceRTL/src/Debug.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ void __assert_assume(bool condition) { __builtin_assume(condition); }
3636
void __assert_fail_internal(const char *expr, const char *msg, const char *file,
3737
unsigned line, const char *function) {
3838
if (msg) {
39-
PRINTF("%s:%u: %s: Assertion %s (`%s`) failed.\n", file, line, function,
39+
printf("%s:%u: %s: Assertion %s (`%s`) failed.\n", file, line, function,
4040
msg, expr);
4141
} else {
42-
PRINTF("%s:%u: %s: Assertion `%s` failed.\n", file, line, function, expr);
42+
printf("%s:%u: %s: Assertion `%s` failed.\n", file, line, function, expr);
4343
}
4444
__builtin_trap();
4545
}

offload/DeviceRTL/src/LibC.cpp

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,11 @@
1010

1111
#pragma omp begin declare target device_type(nohost)
1212

13-
namespace impl {
14-
int32_t omp_vprintf(const char *Format, __builtin_va_list vlist);
15-
}
16-
17-
#ifndef OMPTARGET_HAS_LIBC
18-
namespace impl {
19-
#pragma omp begin declare variant match( \
20-
device = {arch(nvptx, nvptx64)}, \
21-
implementation = {extension(match_any)})
22-
extern "C" int vprintf(const char *format, ...);
23-
int omp_vprintf(const char *Format, __builtin_va_list vlist) {
24-
return vprintf(Format, vlist);
25-
}
26-
#pragma omp end declare variant
27-
28-
#pragma omp begin declare variant match(device = {arch(amdgcn)})
29-
int omp_vprintf(const char *Format, __builtin_va_list) { return -1; }
30-
#pragma omp end declare variant
31-
} // namespace impl
32-
33-
extern "C" int printf(const char *Format, ...) {
34-
__builtin_va_list vlist;
35-
__builtin_va_start(vlist, Format);
36-
return impl::omp_vprintf(Format, vlist);
37-
}
38-
#endif // OMPTARGET_HAS_LIBC
13+
#if defined(__AMDGPU__) && !defined(OMPTARGET_HAS_LIBC)
14+
extern "C" int vprintf(const char *format, __builtin_va_list) { return -1; }
15+
#else
16+
extern "C" int vprintf(const char *format, __builtin_va_list);
17+
#endif
3918

4019
extern "C" {
4120
[[gnu::weak]] int memcmp(const void *lhs, const void *rhs, size_t count) {
@@ -54,6 +33,20 @@ extern "C" {
5433
for (size_t I = 0; I < count; ++I)
5534
dstc[I] = C;
5635
}
36+
37+
[[gnu::weak]] int printf(const char *Format, ...) {
38+
__builtin_va_list vlist;
39+
__builtin_va_start(vlist, Format);
40+
return ::vprintf(Format, vlist);
41+
}
42+
}
43+
44+
namespace ompx {
45+
[[clang::no_builtin("printf")]] int printf(const char *Format, ...) {
46+
__builtin_va_list vlist;
47+
__builtin_va_start(vlist, Format);
48+
return ::vprintf(Format, vlist);
5749
}
50+
} // namespace ompx
5851

5952
#pragma omp end declare target

offload/DeviceRTL/src/Parallelism.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "DeviceTypes.h"
3737
#include "DeviceUtils.h"
3838
#include "Interface.h"
39+
#include "LibC.h"
3940
#include "Mapping.h"
4041
#include "State.h"
4142
#include "Synchronization.h"
@@ -74,7 +75,7 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
7475
switch (nargs) {
7576
#include "generated_microtask_cases.gen"
7677
default:
77-
PRINT("Too many arguments in kmp_invoke_microtask, aborting execution.\n");
78+
printf("Too many arguments in kmp_invoke_microtask, aborting execution.\n");
7879
__builtin_trap();
7980
}
8081
}

offload/DeviceRTL/src/State.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ void *SharedMemorySmartStackTy::push(uint64_t Bytes) {
138138
}
139139

140140
if (config::isDebugMode(DeviceDebugKind::CommonIssues))
141-
PRINT("Shared memory stack full, fallback to dynamic allocation of global "
142-
"memory will negatively impact performance.\n");
141+
printf("Shared memory stack full, fallback to dynamic allocation of global "
142+
"memory will negatively impact performance.\n");
143143
void *GlobalMemory = memory::allocGlobal(
144144
AlignedBytes, "Slow path shared memory allocation, insufficient "
145145
"shared memory stack memory!");
@@ -173,7 +173,7 @@ void memory::freeShared(void *Ptr, uint64_t Bytes, const char *Reason) {
173173
void *memory::allocGlobal(uint64_t Bytes, const char *Reason) {
174174
void *Ptr = malloc(Bytes);
175175
if (config::isDebugMode(DeviceDebugKind::CommonIssues) && Ptr == nullptr)
176-
PRINT("nullptr returned by malloc!\n");
176+
printf("nullptr returned by malloc!\n");
177177
return Ptr;
178178
}
179179

@@ -277,7 +277,7 @@ void state::enterDataEnvironment(IdentTy *Ident) {
277277
sizeof(ThreadStates[0]) * mapping::getNumberOfThreadsInBlock();
278278
void *ThreadStatesPtr =
279279
memory::allocGlobal(Bytes, "Thread state array allocation");
280-
memset(ThreadStatesPtr, 0, Bytes);
280+
__builtin_memset(ThreadStatesPtr, 0, Bytes);
281281
if (!atomic::cas(ThreadStatesBitsPtr, uintptr_t(0),
282282
reinterpret_cast<uintptr_t>(ThreadStatesPtr),
283283
atomic::seq_cst, atomic::seq_cst))

0 commit comments

Comments
 (0)