From 96cc59a4eae2a950f956dc6c1a65d634059e8845 Mon Sep 17 00:00:00 2001 From: Luke Drummond Date: Fri, 9 Jun 2023 18:26:46 +0100 Subject: [PATCH] Add prefetch for HIP USM allocations This change is necessary to work around a delightful bug in either HIP runtime, or the HIP spec. It's discussed at length in github.com/intel/llvm/issues/7252 but for the purposes of this patch, it suffices to say that a call to `hipMemPrefetchAsync` is *required* for correctness in the face of global atomic operations on (*at least*) shared USM allocations. The architecture of this change is slightly strange on first sight in that we redundantly track allocation information in several places. The context now keeps track of all USM mappings. We require a mapping of pointers to the allocated size, but these allocations aren't pinned to any particular queue or HIP stream. The `hipMemPrefetchAsync`, however, requires the associated HIP stream object, and the size of the allocation. The stream comes hot-off-the-queue *only* just before a kernel is launched, so we need to defer the prefetch until we have that information. Finally, the kernel itself keeps track of pointer arguments in a more accessible way so we can determine which of the kernel's pointer arguments do, in-fact, point to USM allocations. 
--- .../ur/adapters/hip/context.hpp | 54 +++++++++++++++++++ .../ur/adapters/hip/enqueue.cpp | 14 ++++- .../ur/adapters/hip/kernel.cpp | 2 +- .../ur/adapters/hip/kernel.hpp | 15 ++++++ .../unified_runtime/ur/adapters/hip/usm.cpp | 13 ++--- 5 files changed, 89 insertions(+), 9 deletions(-) diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp b/sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp index f504bb01ce0bf..7d4fe0c26a424 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp @@ -7,6 +7,8 @@ //===-----------------------------------------------------------------===// #pragma once +#include <unordered_map> + #include "common.hpp" #include "device.hpp" #include "platform.hpp" @@ -93,9 +95,61 @@ struct ur_context_handle_t_ { uint32_t getReferenceCount() const noexcept { return RefCount; } + /// We need to keep track of USM mappings in AMD HIP, as certain extra + /// synchronization *is* actually required for correctness. + /// During kernel enqueue we must dispatch a prefetch for each kernel argument + /// that points to a USM mapping to ensure the mapping is correctly + /// populated on the device (https://github.com/intel/llvm/issues/7252). Thus, + /// we keep track of mappings in the context, and then check against them just + /// before the kernel is launched. The stream against which the kernel is + /// launched is not known until enqueue time, but the USM mappings can happen + /// at any time. Thus, they are tracked on the context used for the urUSM* + /// mapping. + /// + /// The three utility functions are simple wrappers around a mapping from a + /// pointer to a size. 
+ void addUSMMapping(void *Ptr, size_t Size) { + std::lock_guard<std::mutex> Guard(Mutex); + assert(USMMappings.find(Ptr) == USMMappings.end() && + "mapping already exists"); + USMMappings[Ptr] = Size; + } + + void removeUSMMapping(const void *Ptr) { + std::lock_guard<std::mutex> guard(Mutex); + auto It = USMMappings.find(Ptr); + if (It != USMMappings.end()) + USMMappings.erase(It); + } + + std::pair<const void *, size_t> getUSMMapping(const void *Ptr) { + std::lock_guard<std::mutex> Guard(Mutex); + auto It = USMMappings.find(Ptr); + // The simple case is the fast case... + if (It != USMMappings.end()) + return *It; + + // ... but in the failure case we have to fall back to a full scan to search + // for "offset" pointers in case the user passes in the middle of an + // allocation. We have to do some not-so-ordained-by-the-standard ordered + // comparisons of pointers here, but it'll work on all platforms we support. + uintptr_t PtrVal = (uintptr_t)Ptr; + for (std::pair<const void *const, size_t> Pair : USMMappings) { + uintptr_t BaseAddr = (uintptr_t)Pair.first; + uintptr_t EndAddr = BaseAddr + Pair.second; + if (PtrVal > BaseAddr && PtrVal < EndAddr) { + // If we've found something now, offset *must* be nonzero + assert(Pair.second); + return Pair; + } + } + return {nullptr, 0}; + } + private: std::mutex Mutex; std::vector<deleter_data> ExtendedDeleters; + std::unordered_map<const void *, size_t> USMMappings; }; namespace { diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp b/sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp index 1b0b2acc2a3f8..7b36c1863fcc8 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp @@ -252,7 +252,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr}; try { - ScopedContext Active(hQueue->getDevice()); + ur_device_handle_t Dev = hQueue->getDevice(); + ScopedContext Active(Dev); + ur_context_handle_t Ctx = hQueue->getContext(); uint32_t StreamToken; ur_stream_quard Guard; @@ -260,6 +262,14 @@ 
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( numEventsInWaitList, phEventWaitList, Guard, &StreamToken); hipFunction_t HIPFunc = hKernel->get(); + hipDevice_t HIPDev = Dev->get(); + for (const void *P : hKernel->getPtrArgs()) { + auto [Addr, Size] = Ctx->getUSMMapping(P); + if (!Addr) + continue; + if (hipMemPrefetchAsync(Addr, Size, HIPDev, HIPStream) != hipSuccess) + return UR_RESULT_ERROR_INVALID_KERNEL_ARGS; + } Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, phEventWaitList); @@ -301,7 +311,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( int DeviceMaxLocalMem = 0; Result = UR_CHECK_ERROR(hipDeviceGetAttribute( &DeviceMaxLocalMem, hipDeviceAttributeMaxSharedMemoryPerBlock, - hQueue->getDevice()->get())); + HIPDev)); static const int EnvVal = std::atoi(LocalMemSzPtr); if (EnvVal <= 0 || EnvVal > DeviceMaxLocalMem) { diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.cpp b/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.cpp index 8da2d969c2c55..93d431989617c 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.cpp @@ -256,7 +256,7 @@ urKernelGetSubGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer( ur_kernel_handle_t hKernel, uint32_t argIndex, const ur_kernel_arg_pointer_properties_t *, const void *pArgValue) { - hKernel->setKernelArg(argIndex, sizeof(pArgValue), pArgValue); + hKernel->setKernelPtrArg(argIndex, sizeof(pArgValue), pArgValue); return UR_RESULT_SUCCESS; } diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.hpp b/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.hpp index 0e4f3c0ea8bd0..1e2bd03a0f1ea 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.hpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.hpp @@ -12,6 +12,7 @@ #include #include #include +#include <set> #include "program.hpp" @@ -55,6 +56,7 @@ 
struct ur_kernel_handle_t_ { args_size_t ParamSizes; args_index_t Indices; args_size_t OffsetPerIndex; + std::set<const void *> PtrArgs; std::uint32_t ImplicitOffsetArgs[3] = {0, 0, 0}; @@ -175,6 +177,19 @@ struct ur_kernel_handle_t_ { Args.addArg(Index, Size, Arg); } + /// We track all pointer arguments to be able to issue prefetches at enqueue + /// time + void setKernelPtrArg(int Index, size_t Size, const void *PtrArg) { + Args.PtrArgs.insert(*static_cast<const void *const *>(PtrArg)); + setKernelArg(Index, Size, PtrArg); + } + + bool isPtrArg(const void *ptr) { + return Args.PtrArgs.find(ptr) != Args.PtrArgs.end(); + } + + std::set<const void *> &getPtrArgs() { return Args.PtrArgs; } + void setKernelLocalArg(int Index, size_t Size) { Args.addLocalArg(Index, Size); } diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp b/sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp index 296954268a818..f7699441143d3 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp @@ -28,14 +28,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc( ScopedContext Active(hContext->getDevice()); Result = UR_CHECK_ERROR(hipHostMalloc(ppMem, size)); } catch (ur_result_t Error) { - Result = Error; + return Error; } if (Result == UR_RESULT_SUCCESS) { assert((!pUSMDesc || pUSMDesc->align == 0 || reinterpret_cast<std::uintptr_t>(*ppMem) % pUSMDesc->align == 0)); + hContext->addUSMMapping(*ppMem, size); } - return Result; } @@ -53,14 +53,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc( ScopedContext Active(hContext->getDevice()); Result = UR_CHECK_ERROR(hipMalloc(ppMem, size)); } catch (ur_result_t Error) { - Result = Error; + return Error; } if (Result == UR_RESULT_SUCCESS) { assert((!pUSMDesc || pUSMDesc->align == 0 || reinterpret_cast<std::uintptr_t>(*ppMem) % pUSMDesc->align == 0)); + hContext->addUSMMapping(*ppMem, size); } - return Result; } @@ -84,8 +84,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc( if (Result == UR_RESULT_SUCCESS) { assert((!pUSMDesc || 
pUSMDesc->align == 0 || reinterpret_cast<std::uintptr_t>(*ppMem) % pUSMDesc->align == 0)); + hContext->addUSMMapping(*ppMem, size); } - return Result; } @@ -109,8 +109,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext, Result = UR_CHECK_ERROR(hipFreeHost(pMem)); } } catch (ur_result_t Error) { - Result = Error; + return Error; } + hContext->removeUSMMapping(pMem); return Result; }