From e663bc98d19b80519460788bd5793316259cfc8c Mon Sep 17 00:00:00 2001 From: Joseph Huber Date: Mon, 20 Jan 2025 15:20:59 -0600 Subject: [PATCH] [OpenMP] Adjust 'printf' handling in the OpenMP runtime Summary: We used to avoid a lot of this stuff because we didn't properly handle variadics in device code. That's been solved for now, so we can just make an internal printf handler that forwards to the external `vprintf` function. This is either provided by NVIDIA's SDK or by the GPU libc implementation. The main reason for doing this is because it prevents the stupid AMDGPU printf pass from mangling our beautiful printfs! --- offload/DeviceRTL/include/Debug.h | 7 +---- offload/DeviceRTL/include/LibC.h | 9 +++--- offload/DeviceRTL/src/Debug.cpp | 4 +-- offload/DeviceRTL/src/LibC.cpp | 45 +++++++++++---------------- offload/DeviceRTL/src/Parallelism.cpp | 3 +- offload/DeviceRTL/src/State.cpp | 8 ++--- 6 files changed, 32 insertions(+), 44 deletions(-) diff --git a/offload/DeviceRTL/include/Debug.h b/offload/DeviceRTL/include/Debug.h index 22998f44a5bea..98d0fa498d952 100644 --- a/offload/DeviceRTL/include/Debug.h +++ b/offload/DeviceRTL/include/Debug.h @@ -35,15 +35,10 @@ void __assert_fail_internal(const char *expr, const char *msg, const char *file, __assert_assume(expr); \ } #define UNREACHABLE(msg) \ - PRINT(msg); \ + printf(msg); \ __builtin_trap(); \ __builtin_unreachable(); ///} -#define PRINTF(fmt, ...) (void)printf(fmt, ##__VA_ARGS__); -#define PRINT(str) PRINTF("%s", str) - -///} - #endif diff --git a/offload/DeviceRTL/include/LibC.h b/offload/DeviceRTL/include/LibC.h index 03febdb508342..94b5e65196067 100644 --- a/offload/DeviceRTL/include/LibC.h +++ b/offload/DeviceRTL/include/LibC.h @@ -14,11 +14,10 @@ #include "DeviceTypes.h" -extern "C" { +namespace ompx { -int memcmp(const void *lhs, const void *rhs, size_t count); -void memset(void *dst, int C, size_t count); -int printf(const char *format, ...); -} +int printf(const char *Format, ...); + +} // namespace ompx #endif diff --git a/offload/DeviceRTL/src/Debug.cpp b/offload/DeviceRTL/src/Debug.cpp index b451f17c6bbd8..1d9c962885422 100644 --- a/offload/DeviceRTL/src/Debug.cpp +++ b/offload/DeviceRTL/src/Debug.cpp @@ -36,10 +36,10 @@ void __assert_assume(bool condition) { __builtin_assume(condition); } void __assert_fail_internal(const char *expr, const char *msg, const char *file, unsigned line, const char *function) { if (msg) { - PRINTF("%s:%u: %s: Assertion %s (`%s`) failed.\n", file, line, function, + printf("%s:%u: %s: Assertion %s (`%s`) failed.\n", file, line, function, msg, expr); } else { - PRINTF("%s:%u: %s: Assertion `%s` failed.\n", file, line, function, expr); + printf("%s:%u: %s: Assertion `%s` failed.\n", file, line, function, expr); } __builtin_trap(); } diff --git a/offload/DeviceRTL/src/LibC.cpp b/offload/DeviceRTL/src/LibC.cpp index 291ceb023a69c..e55008f46269f 100644 --- a/offload/DeviceRTL/src/LibC.cpp +++ b/offload/DeviceRTL/src/LibC.cpp @@ -10,32 +10,11 @@ #pragma omp begin declare target device_type(nohost) -namespace impl { -int32_t omp_vprintf(const char *Format, __builtin_va_list vlist); -} - -#ifndef OMPTARGET_HAS_LIBC -namespace impl { -#pragma omp begin declare variant match( \ - device = {arch(nvptx, nvptx64)}, \ - implementation = {extension(match_any)}) -extern "C" int vprintf(const char *format, ...); -int omp_vprintf(const char *Format, __builtin_va_list vlist) { - return vprintf(Format, vlist); -} -#pragma omp end declare variant - -#pragma omp begin declare variant match(device = {arch(amdgcn)}) -int omp_vprintf(const char *Format, __builtin_va_list) { return -1; } -#pragma omp end declare variant -} // namespace impl - -extern "C" int printf(const char *Format, ...) { - __builtin_va_list vlist; - __builtin_va_start(vlist, Format); - return impl::omp_vprintf(Format, vlist); -} -#endif // OMPTARGET_HAS_LIBC +#if defined(__AMDGPU__) && !defined(OMPTARGET_HAS_LIBC) +extern "C" int vprintf(const char *format, __builtin_va_list) { return -1; } +#else +extern "C" int vprintf(const char *format, __builtin_va_list); +#endif extern "C" { [[gnu::weak]] int memcmp(const void *lhs, const void *rhs, size_t count) { @@ -54,6 +33,20 @@ extern "C" { for (size_t I = 0; I < count; ++I) dstc[I] = C; } + +[[gnu::weak]] int printf(const char *Format, ...) { + __builtin_va_list vlist; + __builtin_va_start(vlist, Format); + return ::vprintf(Format, vlist); +} +} + +namespace ompx { +[[clang::no_builtin("printf")]] int printf(const char *Format, ...) { + __builtin_va_list vlist; + __builtin_va_start(vlist, Format); + return ::vprintf(Format, vlist); } +} // namespace ompx #pragma omp end declare target diff --git a/offload/DeviceRTL/src/Parallelism.cpp b/offload/DeviceRTL/src/Parallelism.cpp index 5286d53b623f0..a87e363349b1e 100644 --- a/offload/DeviceRTL/src/Parallelism.cpp +++ b/offload/DeviceRTL/src/Parallelism.cpp @@ -36,6 +36,7 @@ #include "DeviceTypes.h" #include "DeviceUtils.h" #include "Interface.h" +#include "LibC.h" #include "Mapping.h" #include "State.h" #include "Synchronization.h" @@ -74,7 +75,7 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) { switch (nargs) { #include "generated_microtask_cases.gen" default: - PRINT("Too many arguments in kmp_invoke_microtask, aborting execution.\n"); + printf("Too many arguments in kmp_invoke_microtask, aborting execution.\n"); __builtin_trap(); } } diff --git a/offload/DeviceRTL/src/State.cpp b/offload/DeviceRTL/src/State.cpp index 855c74fa58e0a..100bc8ab47983 100644 --- a/offload/DeviceRTL/src/State.cpp +++ b/offload/DeviceRTL/src/State.cpp @@ -138,8 +138,8 @@ void *SharedMemorySmartStackTy::push(uint64_t Bytes) { } if (config::isDebugMode(DeviceDebugKind::CommonIssues)) - PRINT("Shared memory stack full, fallback to dynamic allocation of global " - "memory will negatively impact performance.\n"); + printf("Shared memory stack full, fallback to dynamic allocation of global " + "memory will negatively impact performance.\n"); void *GlobalMemory = memory::allocGlobal( AlignedBytes, "Slow path shared memory allocation, insufficient " "shared memory stack memory!"); @@ -173,7 +173,7 @@ void memory::freeShared(void *Ptr, uint64_t Bytes, const char *Reason) { void *memory::allocGlobal(uint64_t Bytes, const char *Reason) { void *Ptr = malloc(Bytes); if (config::isDebugMode(DeviceDebugKind::CommonIssues) && Ptr == nullptr) - PRINT("nullptr returned by malloc!\n"); + printf("nullptr returned by malloc!\n"); return Ptr; } @@ -277,7 +277,7 @@ void state::enterDataEnvironment(IdentTy *Ident) { sizeof(ThreadStates[0]) * mapping::getNumberOfThreadsInBlock(); void *ThreadStatesPtr = memory::allocGlobal(Bytes, "Thread state array allocation"); - memset(ThreadStatesPtr, 0, Bytes); + __builtin_memset(ThreadStatesPtr, 0, Bytes); if (!atomic::cas(ThreadStatesBitsPtr, uintptr_t(0), reinterpret_cast(ThreadStatesPtr), atomic::seq_cst, atomic::seq_cst))