[OpenMP] Adjust 'printf' handling in the OpenMP runtime #123670

Merged
merged 1 commit into from
Jan 21, 2025


jhuber6
Contributor

@jhuber6 jhuber6 commented Jan 20, 2025

Summary:
We used to avoid a lot of this stuff because we didn't properly handle
variadics in device code. That's been solved for now, so we can just
make an internal `printf` handler that forwards to the external `vprintf`
function. This is either provided by NVIDIA's SDK or by the GPU libc
implementation.

The main reason for doing this is because it prevents the stupid AMDGPU
printf pass from mangling our beautiful printfs!
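
For readers unfamiliar with the pattern, here is a minimal host-side sketch of a variadic handler that forwards to `vprintf`. The names `demo::print` and `demo::example` are hypothetical and not part of the patch. The sketch uses the standard `<cstdarg>` macros, where the patch uses the `__builtin_va_*` forms in device code.

```cpp
#include <cassert>
#include <cstdarg>
#include <cstdio>

namespace demo {

// Hypothetical stand-in for an internal printf handler: it packs the
// variadic arguments into a va_list and forwards them to vprintf.
int print(const char *Format, ...) {
  va_list Args;
  va_start(Args, Format);
  int Written = std::vprintf(Format, Args);
  va_end(Args); // Pair every va_start with a va_end.
  return Written;
}

// Usage helper: vprintf returns the number of characters written,
// which is 12 for "file.cpp:42\n".
int example() {
  int Written = print("%s:%u\n", "file.cpp", 42u);
  assert(Written == 12);
  return Written;
}

} // namespace demo
```

Calling `demo::example()` from a `main` function prints `file.cpp:42`. On the device, the external `vprintf` would come from the vendor SDK or the GPU libc, as described above.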

@llvmbot
Member

llvmbot commented Jan 20, 2025

@llvm/pr-subscribers-offload

Author: Joseph Huber (jhuber6)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/123670.diff

6 Files Affected:

  • (modified) offload/DeviceRTL/include/Debug.h (+1-6)
  • (modified) offload/DeviceRTL/include/LibC.h (+4-5)
  • (modified) offload/DeviceRTL/src/Debug.cpp (+2-2)
  • (modified) offload/DeviceRTL/src/LibC.cpp (+19-26)
  • (modified) offload/DeviceRTL/src/Parallelism.cpp (+2-1)
  • (modified) offload/DeviceRTL/src/State.cpp (+4-4)
diff --git a/offload/DeviceRTL/include/Debug.h b/offload/DeviceRTL/include/Debug.h
index 22998f44a5bea5..98d0fa498d952b 100644
--- a/offload/DeviceRTL/include/Debug.h
+++ b/offload/DeviceRTL/include/Debug.h
@@ -35,15 +35,10 @@ void __assert_fail_internal(const char *expr, const char *msg, const char *file,
       __assert_assume(expr);                                                   \
   }
 #define UNREACHABLE(msg)                                                       \
-  PRINT(msg);                                                                  \
+  printf(msg);                                                                 \
   __builtin_trap();                                                            \
   __builtin_unreachable();
 
 ///}
 
-#define PRINTF(fmt, ...) (void)printf(fmt, ##__VA_ARGS__);
-#define PRINT(str) PRINTF("%s", str)
-
-///}
-
 #endif
diff --git a/offload/DeviceRTL/include/LibC.h b/offload/DeviceRTL/include/LibC.h
index 03febdb5083423..94b5e651960674 100644
--- a/offload/DeviceRTL/include/LibC.h
+++ b/offload/DeviceRTL/include/LibC.h
@@ -14,11 +14,10 @@
 
 #include "DeviceTypes.h"
 
-extern "C" {
+namespace ompx {
 
-int memcmp(const void *lhs, const void *rhs, size_t count);
-void memset(void *dst, int C, size_t count);
-int printf(const char *format, ...);
-}
+int printf(const char *Format, ...);
+
+} // namespace ompx
 
 #endif
diff --git a/offload/DeviceRTL/src/Debug.cpp b/offload/DeviceRTL/src/Debug.cpp
index b451f17c6bbd89..1d9c9628854222 100644
--- a/offload/DeviceRTL/src/Debug.cpp
+++ b/offload/DeviceRTL/src/Debug.cpp
@@ -36,10 +36,10 @@ void __assert_assume(bool condition) { __builtin_assume(condition); }
 void __assert_fail_internal(const char *expr, const char *msg, const char *file,
                             unsigned line, const char *function) {
   if (msg) {
-    PRINTF("%s:%u: %s: Assertion %s (`%s`) failed.\n", file, line, function,
+    printf("%s:%u: %s: Assertion %s (`%s`) failed.\n", file, line, function,
            msg, expr);
   } else {
-    PRINTF("%s:%u: %s: Assertion `%s` failed.\n", file, line, function, expr);
+    printf("%s:%u: %s: Assertion `%s` failed.\n", file, line, function, expr);
   }
   __builtin_trap();
 }
diff --git a/offload/DeviceRTL/src/LibC.cpp b/offload/DeviceRTL/src/LibC.cpp
index 291ceb023a69c5..e55008f46269fe 100644
--- a/offload/DeviceRTL/src/LibC.cpp
+++ b/offload/DeviceRTL/src/LibC.cpp
@@ -10,32 +10,11 @@
 
 #pragma omp begin declare target device_type(nohost)
 
-namespace impl {
-int32_t omp_vprintf(const char *Format, __builtin_va_list vlist);
-}
-
-#ifndef OMPTARGET_HAS_LIBC
-namespace impl {
-#pragma omp begin declare variant match(                                       \
-        device = {arch(nvptx, nvptx64)},                                       \
-            implementation = {extension(match_any)})
-extern "C" int vprintf(const char *format, ...);
-int omp_vprintf(const char *Format, __builtin_va_list vlist) {
-  return vprintf(Format, vlist);
-}
-#pragma omp end declare variant
-
-#pragma omp begin declare variant match(device = {arch(amdgcn)})
-int omp_vprintf(const char *Format, __builtin_va_list) { return -1; }
-#pragma omp end declare variant
-} // namespace impl
-
-extern "C" int printf(const char *Format, ...) {
-  __builtin_va_list vlist;
-  __builtin_va_start(vlist, Format);
-  return impl::omp_vprintf(Format, vlist);
-}
-#endif // OMPTARGET_HAS_LIBC
+#if defined(__AMDGPU__) && !defined(OMPTARGET_HAS_LIBC)
+extern "C" int vprintf(const char *format, __builtin_va_list) { return -1; }
+#else
+extern "C" int vprintf(const char *format, __builtin_va_list);
+#endif
 
 extern "C" {
 [[gnu::weak]] int memcmp(const void *lhs, const void *rhs, size_t count) {
@@ -54,6 +33,20 @@ extern "C" {
   for (size_t I = 0; I < count; ++I)
     dstc[I] = C;
 }
+
+[[gnu::weak]] int printf(const char *Format, ...) {
+  __builtin_va_list vlist;
+  __builtin_va_start(vlist, Format);
+  return ::vprintf(Format, vlist);
+}
+}
+
+namespace ompx {
+[[clang::no_builtin("printf")]] int printf(const char *Format, ...) {
+  __builtin_va_list vlist;
+  __builtin_va_start(vlist, Format);
+  return ::vprintf(Format, vlist);
 }
+} // namespace ompx
 
 #pragma omp end declare target
diff --git a/offload/DeviceRTL/src/Parallelism.cpp b/offload/DeviceRTL/src/Parallelism.cpp
index 5286d53b623f0a..a87e363349b1e5 100644
--- a/offload/DeviceRTL/src/Parallelism.cpp
+++ b/offload/DeviceRTL/src/Parallelism.cpp
@@ -36,6 +36,7 @@
 #include "DeviceTypes.h"
 #include "DeviceUtils.h"
 #include "Interface.h"
+#include "LibC.h"
 #include "Mapping.h"
 #include "State.h"
 #include "Synchronization.h"
@@ -74,7 +75,7 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
   switch (nargs) {
 #include "generated_microtask_cases.gen"
   default:
-    PRINT("Too many arguments in kmp_invoke_microtask, aborting execution.\n");
+    printf("Too many arguments in kmp_invoke_microtask, aborting execution.\n");
     __builtin_trap();
   }
 }
diff --git a/offload/DeviceRTL/src/State.cpp b/offload/DeviceRTL/src/State.cpp
index 855c74fa58e0a5..100bc8ab47983c 100644
--- a/offload/DeviceRTL/src/State.cpp
+++ b/offload/DeviceRTL/src/State.cpp
@@ -138,8 +138,8 @@ void *SharedMemorySmartStackTy::push(uint64_t Bytes) {
   }
 
   if (config::isDebugMode(DeviceDebugKind::CommonIssues))
-    PRINT("Shared memory stack full, fallback to dynamic allocation of global "
-          "memory will negatively impact performance.\n");
+    printf("Shared memory stack full, fallback to dynamic allocation of global "
+           "memory will negatively impact performance.\n");
   void *GlobalMemory = memory::allocGlobal(
       AlignedBytes, "Slow path shared memory allocation, insufficient "
                     "shared memory stack memory!");
@@ -173,7 +173,7 @@ void memory::freeShared(void *Ptr, uint64_t Bytes, const char *Reason) {
 void *memory::allocGlobal(uint64_t Bytes, const char *Reason) {
   void *Ptr = malloc(Bytes);
   if (config::isDebugMode(DeviceDebugKind::CommonIssues) && Ptr == nullptr)
-    PRINT("nullptr returned by malloc!\n");
+    printf("nullptr returned by malloc!\n");
   return Ptr;
 }
 
@@ -277,7 +277,7 @@ void state::enterDataEnvironment(IdentTy *Ident) {
         sizeof(ThreadStates[0]) * mapping::getNumberOfThreadsInBlock();
     void *ThreadStatesPtr =
         memory::allocGlobal(Bytes, "Thread state array allocation");
-    memset(ThreadStatesPtr, 0, Bytes);
+    __builtin_memset(ThreadStatesPtr, 0, Bytes);
     if (!atomic::cas(ThreadStatesBitsPtr, uintptr_t(0),
                      reinterpret_cast<uintptr_t>(ThreadStatesPtr),
                      atomic::seq_cst, atomic::seq_cst))

@arsenm
Contributor

arsenm commented Jan 21, 2025

> The main reason for doing this is because it prevents the stupid AMDGPU
> printf pass from mangling our beautiful printfs!

We should probably find a better solution to this part

@jhuber6
Contributor Author

jhuber6 commented Jan 21, 2025

> > The main reason for doing this is because it prevents the stupid AMDGPU
> > printf pass from mangling our beautiful printfs!
>
> We should probably find a better solution to this part

My preferred solution is to delete the pass and just call a function named `printf`, but that's probably not popular internally. So the easy workaround is to stop calling the C symbol `printf` in the runtime.
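
As a rough illustration of that workaround, the host-side sketch below declares a `printf` inside a namespace, so unqualified calls made from within that namespace bind to it rather than to the global C symbol. The namespace `demo`, the counter `CallCount`, and the function `report` are hypothetical names used only for this example. The patch itself places its function in `namespace ompx`.

```cpp
#include <cassert>
#include <cstdarg>
#include <cstdio>

namespace demo {

int CallCount = 0; // Illustration only: counts calls reaching demo::printf.

// A distinct function from the global C printf, despite the shared name.
int printf(const char *Format, ...) {
  ++CallCount;
  va_list Args;
  va_start(Args, Format);
  int Written = std::vprintf(Format, Args);
  va_end(Args);
  return Written;
}

// Unqualified lookup inside the namespace finds demo::printf first, so
// this call never references the global printf symbol directly.
int report(int Code) { return printf("code=%d\n", Code); }

} // namespace demo
```

After one call to `demo::report(7)`, `demo::CallCount` is 1, confirming that the namespaced function was selected. Whether a given compiler pass then leaves such a call alone depends on the pass, and this sketch does not demonstrate that.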

@jhuber6
Contributor Author

jhuber6 commented Jan 21, 2025

Alternatively, I could add an `-mllvm` option to turn it off. I can't use `-fno-builtin` on the compilation because that would horrifically mangle performance by preventing the OpenMP runtime from ever being inlined.

@arsenm
Contributor

arsenm commented Jan 21, 2025

The printf binding pass should be added by clang, not the backend. It's specific to that path for OpenCL. Plus, the pass does respect nobuiltin
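
For context, the sketch below shows a per-function alternative to `-fno-builtin`, using the same `[[clang::no_builtin("printf")]]` spelling that appears in the diff. It assumes Clang's documented behaviour, where the attribute stops calls to the named function inside the attributed definition from being treated as builtins. The macro `DEMO_NO_BUILTIN_PRINTF` and the function `demo::announce` are hypothetical. How the AMDGPU pass reacts to the attribute is not demonstrated here.

```cpp
#include <cassert>
#include <cstdio>

// Use the Clang attribute when the compiler reports support for it;
// expand to nothing elsewhere so the example stays portable.
#if defined(__has_cpp_attribute)
#if __has_cpp_attribute(clang::no_builtin)
#define DEMO_NO_BUILTIN_PRINTF [[clang::no_builtin("printf")]]
#endif
#endif
#ifndef DEMO_NO_BUILTIN_PRINTF
#define DEMO_NO_BUILTIN_PRINTF
#endif

namespace demo {

// Only this definition opts out of builtin handling for printf; the rest
// of the translation unit keeps its usual optimizations.
DEMO_NO_BUILTIN_PRINTF int announce(const char *Msg) {
  return std::printf("%s\n", Msg);
}

} // namespace demo
```

Scoping the opt-out to one function avoids the translation-unit-wide cost that `-fno-builtin` carries, which is the concern raised in the comment above.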

@jhuber6
Contributor Author

jhuber6 commented Jan 21, 2025

> The printf binding pass should be added by clang, not the backend. It's specific to that path for OpenCL. Plus, the pass does respect nobuiltin

I thought it was used for HIP as well, but I definitely wouldn't be opposed since it screws with my C stuff. Though in general this patch does clean up the printf handling in the OpenMP runtime a lot so it should land as well.

@arsenm
Contributor

arsenm commented Jan 21, 2025

> I thought it was used for HIP as well, but I definitely wouldn't be opposed since it screws with my C stuff. Though in general this patch does clean up the printf handling in the OpenMP runtime a lot so it should land as well.

No, see clang/test/CodeGenHIP/printf-builtin.hip. It has 2 options and both directly emit the inlined implementation with __ockl functions

@jhuber6
Contributor Author

jhuber6 commented Jan 21, 2025

> > I thought it was used for HIP as well, but I definitely wouldn't be opposed since it screws with my C stuff. Though in general this patch does clean up the printf handling in the OpenMP runtime a lot so it should land as well.
>
> No, see clang/test/CodeGenHIP/printf-builtin.hip. It has 2 options and both directly emit the inlined implementation with __ockl functions

Alright, I can make another patch to do that, since I always get bitten by this while trying to use my infra.

@jhuber6 jhuber6 merged commit 2d9f406 into llvm:main Jan 21, 2025
8 checks passed
3 participants