diff --git a/sycl/include/CL/sycl/detail/cg.hpp b/sycl/include/CL/sycl/detail/cg.hpp
index 2ef0ff5170b74..8feb61a2f0568 100644
--- a/sycl/include/CL/sycl/detail/cg.hpp
+++ b/sycl/include/CL/sycl/detail/cg.hpp
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include <CL/sycl/interop_handle.hpp>
 #include
 #include
 #include
@@ -217,12 +218,17 @@ class InteropTask {
 class HostTask {
   std::function<void()> MHostTask;
+  std::function<void(interop_handle)> MInteropTask;
 
 public:
   HostTask() : MHostTask([]() {}) {}
   HostTask(std::function<void()> &&Func) : MHostTask(Func) {}
+  HostTask(std::function<void(interop_handle)> &&Func) : MInteropTask(Func) {}
+
+  bool isInteropTask() const { return !!MInteropTask; }
 
   void call() { MHostTask(); }
+  void call(interop_handle handle) { MInteropTask(handle); }
 };
 
 // Class which stores specific lambda object.
@@ -645,9 +651,16 @@ class CGInteropTask : public CG {
 class CGHostTask : public CG {
 public:
   std::unique_ptr<HostTask> MHostTask;
+  // queue for host-interop task
+  shared_ptr_class<detail::queue_impl> MQueue;
+  // context for host-interop task
+  shared_ptr_class<detail::context_impl> MContext;
   vector_class<ArgDesc> MArgs;
 
-  CGHostTask(std::unique_ptr<HostTask> HostTask, vector_class<ArgDesc> Args,
+  CGHostTask(std::unique_ptr<HostTask> HostTask,
+             std::shared_ptr<detail::queue_impl> Queue,
+             std::shared_ptr<detail::context_impl> Context,
+             vector_class<ArgDesc> Args,
              std::vector<std::vector<char>> ArgsStorage,
              std::vector<detail::AccessorImplPtr> AccStorage,
              std::vector<std::shared_ptr<const void>> SharedPtrStorage,
@@ -657,7 +670,8 @@ class CGHostTask : public CG {
       : CG(Type, std::move(ArgsStorage), std::move(AccStorage),
            std::move(SharedPtrStorage), std::move(Requirements),
            std::move(Events), std::move(loc)),
-        MHostTask(std::move(HostTask)), MArgs(std::move(Args)) {}
+        MHostTask(std::move(HostTask)), MQueue(Queue), MContext(Context),
+        MArgs(std::move(Args)) {}
 };
 
 } // namespace detail
diff --git a/sycl/include/CL/sycl/handler.hpp b/sycl/include/CL/sycl/handler.hpp
index c65eea75bb7ba..1320df6d2906b 100644
--- a/sycl/include/CL/sycl/handler.hpp
+++ b/sycl/include/CL/sycl/handler.hpp
@@ -16,6 +16,7 @@
 #include
 #include
 #include
+#include <CL/sycl/interop_handle.hpp>
 #include
 #include
 #include
@@ -825,6 +826,21 @@ class __SYCL_EXPORT handler {
     MCGType = detail::CG::CODEPLAY_HOST_TASK;
   }
 
+  template <typename FuncT>
+  typename std::enable_if<
+      detail::check_fn_signature<typename std::remove_reference<FuncT>::type,
+                                 void(interop_handle)>::value>::type
+  codeplay_host_task(FuncT &&Func) {
+    throwIfActionIsCreated();
+
+    MNDRDesc.set(range<1>(1));
+    MArgs = std::move(MAssociatedAccesors);
+
+    MHostTask.reset(new detail::HostTask(std::move(Func)));
+
+    MCGType = detail::CG::CODEPLAY_HOST_TASK;
+  }
+
   /// Defines and invokes a SYCL kernel function for the specified range and
   /// offset.
   ///
diff --git a/sycl/include/CL/sycl/interop_handle.hpp b/sycl/include/CL/sycl/interop_handle.hpp
new file mode 100644
index 0000000000000..ba8704aa25b41
--- /dev/null
+++ b/sycl/include/CL/sycl/interop_handle.hpp
@@ -0,0 +1,117 @@
+//==------------ interop_handle.hpp --- SYCL interop handle ---------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+__SYCL_INLINE_NAMESPACE(cl) {
+namespace sycl {
+
+namespace detail {
+class AccessorBaseHost;
+class ExecCGCommand;
+class DispatchHostTask;
+} // namespace detail
+
+template <typename DataT, int Dims, access::mode AccMode,
+          access::target AccTarget, access::placeholder IsPlaceholder>
+class accessor;
+
+class interop_handle {
+public:
+  /// Receives a SYCL accessor that has been defined as a requirement for the
+  /// command group, and returns the underlying OpenCL memory object that is
+  /// used by the SYCL runtime. If the accessor passed as parameter is not part
+  /// of the command group requirements (e.g. it is an unregistered placeholder
+  /// accessor), the exception `cl::sycl::invalid_object` is thrown
+  /// asynchronously.
+  template <typename DataT, int Dims, access::mode AccMode,
+            access::target AccTarget, access::placeholder IsPlaceholder>
+  typename std::enable_if<AccTarget != access::target::host_buffer,
+                          cl_mem>::type
+  get_native_mem(const accessor<DataT, Dims, AccMode, AccTarget,
+                                IsPlaceholder> &Acc) const {
+#ifndef __SYCL_DEVICE_ONLY__
+    // employ reinterpret_cast instead of static_cast due to cycle in includes
+    // involving CL/sycl/accessor.hpp
+    auto *AccBase = const_cast<detail::AccessorBaseHost *>(
+        reinterpret_cast<const detail::AccessorBaseHost *>(&Acc));
+    return getMemImpl(detail::getSyclObjImpl(*AccBase).get());
+#else
+    (void)Acc;
+    // we believe this won't ever be called on the device side
+    return static_cast<cl_mem>(0x0);
+#endif
+  }
+
+  template <typename DataT, int Dims, access::mode AccMode,
+            access::target AccTarget, access::placeholder IsPlaceholder>
+  typename std::enable_if<AccTarget == access::target::host_buffer,
+                          cl_mem>::type
+  get_native_mem(const accessor<DataT, Dims, AccMode, AccTarget,
+                                IsPlaceholder> &) const {
+    throw invalid_object_error("Getting memory object out of host accessor is "
+                               "not allowed",
+                               PI_INVALID_MEM_OBJECT);
+  }
+
+  /// Returns an underlying OpenCL queue for the SYCL queue used to submit the
+  /// command group, or the fallback queue if this command group is re-trying
+  /// execution on an OpenCL queue. The OpenCL command queue returned is
+  /// implementation-defined in cases where the SYCL queue maps to multiple
+  /// underlying OpenCL objects. It is the responsibility of the SYCL runtime
+  /// to ensure the OpenCL queue returned is in a state that can be used to
+  /// dispatch work, and that other potential OpenCL command queues associated
+  /// with the same SYCL command queue are not executing commands while the
+  /// host task is executing.
+  cl_command_queue get_native_queue() const noexcept { return MQueue; }
+
+  /// Returns an underlying OpenCL device associated with the SYCL queue used
+  /// to submit the command group, or the fallback queue if this command group
+  /// is re-trying execution on an OpenCL queue.
+  cl_device_id get_native_device() const noexcept { return MDeviceId; }
+
+  /// Returns an underlying OpenCL context associated with the SYCL queue used
+  /// to submit the command group, or the fallback queue if this command group
+  /// is re-trying execution on an OpenCL queue.
+  cl_context get_native_context() const noexcept { return MContext; }
+
+private:
+  using ReqToMem = std::pair<detail::Requirement *, cl_mem>;
+
+  template <typename DataT, int Dims, access::mode AccMode,
+            access::target AccTarget, access::placeholder IsPlaceholder>
+  friend class accessor;
+  friend class detail::ExecCGCommand;
+  friend class detail::DispatchHostTask;
+
+public:
+  // TODO set c-tor private
+  interop_handle(std::vector<ReqToMem> MemObjs, cl_command_queue Queue,
+                 cl_device_id DeviceId, cl_context Context)
+      : MQueue(Queue), MDeviceId(DeviceId), MContext(Context),
+        MMemObjs(std::move(MemObjs)) {}
+
+private:
+  cl_mem getMemImpl(detail::Requirement *Req) const;
+
+  cl_command_queue MQueue;
+  cl_device_id MDeviceId;
+  cl_context MContext;
+  std::vector<ReqToMem> MMemObjs;
+};
+
+} // namespace sycl
+} // __SYCL_INLINE_NAMESPACE(cl)
diff --git a/sycl/source/CMakeLists.txt b/sycl/source/CMakeLists.txt
index 17ff24db14b00..95501d4a616ab 100644
--- a/sycl/source/CMakeLists.txt
+++ b/sycl/source/CMakeLists.txt
@@ -150,6 +150,7 @@ set(SYCL_SOURCES
     "sampler.cpp"
     "stream.cpp"
     "spirv_ops.cpp"
+    "interop_handle.cpp"
     "$<$<PLATFORM_ID:Windows>:detail/windows_pi.cpp>"
     "$<$<OR:$<PLATFORM_ID:Linux>,$<PLATFORM_ID:Darwin>>:detail/posix_pi.cpp>"
 )
diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp
index eb40bcdf82044..7281ca2747bec 100644
--- a/sycl/source/detail/scheduler/commands.cpp
+++ b/sycl/source/detail/scheduler/commands.cpp
@@ -159,6 +159,7 @@ getPiEvents(const std::vector<EventImplPtr> &EventImpls) {
 class DispatchHostTask {
   ExecCGCommand *MThisCmd;
+  std::vector<interop_handle::ReqToMem> MReqToMem;
 
   void waitForEvents() const {
     std::map<const detail::plugin *, std::vector<EventImplPtr>>
@@ -187,7 +188,9 @@
   }
 
 public:
-  DispatchHostTask(ExecCGCommand *ThisCmd) : MThisCmd{ThisCmd} {}
+  DispatchHostTask(ExecCGCommand *ThisCmd,
+                   std::vector<interop_handle::ReqToMem> ReqToMem)
+      : MThisCmd{ThisCmd}, MReqToMem(std::move(ReqToMem)) {}
 
   void operator()() const {
     waitForEvents();
@@ -197,7 +200,17 @@
     CGHostTask &HostTask = static_cast<CGHostTask &>(MThisCmd->getCG());
 
     // we're ready to call the user-defined lambda now
-    HostTask.MHostTask->call();
+    if (HostTask.MHostTask->isInteropTask()) {
+      auto Queue = HostTask.MQueue->get();
+      auto DeviceId = HostTask.MQueue->get_device().get();
+      auto Context = HostTask.MQueue->get_context().get();
+
+      interop_handle IH{MReqToMem, Queue, DeviceId, Context};
+
+      HostTask.MHostTask->call(IH);
+    } else
+      HostTask.MHostTask->call();
+
     HostTask.MHostTask.reset();
 
     // unblock user empty command here
@@ -1913,8 +1926,21 @@ cl_int ExecCGCommand::enqueueImp() {
     }
   }
 
+    std::vector<interop_handle::ReqToMem> ReqToMem;
+
+    // Extract the Mem Objects for all Requirements, to ensure they are
+    // available if a user asks for them inside the interop task scope
+    const std::vector<Requirement *> &HandlerReq = HostTask->MRequirements;
+    auto ReqToMemConv = [&ReqToMem, this](Requirement *Req) {
+      AllocaCommandBase *AllocaCmd = getAllocaForReq(Req);
+      auto MemArg = reinterpret_cast<cl_mem>(AllocaCmd->getMemAllocation());
+      interop_handle::ReqToMem ReqToMemEl = std::make_pair(Req, MemArg);
+      ReqToMem.emplace_back(ReqToMemEl);
+    };
+    std::for_each(std::begin(HandlerReq), std::end(HandlerReq), ReqToMemConv);
+    std::sort(std::begin(ReqToMem), std::end(ReqToMem));
+
     MQueue->getThreadPool().submit<DispatchHostTask>(
-        std::move(DispatchHostTask(this)));
+        std::move(DispatchHostTask(this, std::move(ReqToMem))));
 
     MShouldCompleteEventIfPossible = false;
diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp
index 7bd3a346c096d..2309f14bd8a82 100644
--- a/sycl/source/detail/scheduler/graph_builder.cpp
+++ b/sycl/source/detail/scheduler/graph_builder.cpp
@@ -928,10 +928,11 @@ void Scheduler::GraphBuilder::connectDepEvent(Command *const Cmd,
   {
     std::unique_ptr<detail::HostTask> HT(new detail::HostTask);
     std::unique_ptr<detail::CG> ConnectCG(new detail::CGHostTask(
-        std::move(HT), /* Args = */ {}, /* ArgsStorage = */ {},
-        /* AccStorage = */ {}, /* SharedPtrStorage = */ {},
-        /* Requirements = */ {}, /* DepEvents = */ {DepEvent},
-        CG::CODEPLAY_HOST_TASK, /* Payload */ {}));
+        std::move(HT), /* Queue = */ {}, /* Context = */ {}, /* Args = */ {},
+        /* ArgsStorage = */ {}, /* AccStorage = */ {},
+        /* SharedPtrStorage = */ {}, /* Requirements = */ {},
+        /* DepEvents = */ {DepEvent}, CG::CODEPLAY_HOST_TASK,
+        /* Payload */ {}));
     ConnectCmd = new ExecCGCommand(
         std::move(ConnectCG), Scheduler::getInstance().getDefaultHostQueue());
   }
diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp
index 195cc9c1fc072..8699e8ba06c0e 100644
--- a/sycl/source/handler.cpp
+++ b/sycl/source/handler.cpp
@@ -85,9 +85,10 @@ event handler::finalize() {
     break;
   case detail::CG::CODEPLAY_HOST_TASK:
     CommandGroup.reset(new detail::CGHostTask(
-        std::move(MHostTask), std::move(MArgs), std::move(MArgsStorage),
-        std::move(MAccStorage), std::move(MSharedPtrStorage),
-        std::move(MRequirements), std::move(MEvents), MCGType, MCodeLoc));
+        std::move(MHostTask), MQueue, MQueue->getContextImplPtr(),
+        std::move(MArgs), std::move(MArgsStorage), std::move(MAccStorage),
+        std::move(MSharedPtrStorage), std::move(MRequirements),
+        std::move(MEvents), MCGType, MCodeLoc));
     break;
   case detail::CG::NONE:
     throw runtime_error("Command group submitted without a kernel or a "
diff --git a/sycl/source/interop_handle.cpp b/sycl/source/interop_handle.cpp
new file mode 100644
index 0000000000000..c1df4993c700b
--- /dev/null
+++ b/sycl/source/interop_handle.cpp
@@ -0,0 +1,28 @@
+//==------------ interop_handle.cpp --- SYCL interop handle ---------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <CL/sycl/detail/pi.hpp>
+#include <CL/sycl/interop_handle.hpp>
+
+#include <algorithm>
+
+__SYCL_INLINE_NAMESPACE(cl) {
+namespace sycl {
+
+cl_mem interop_handle::getMemImpl(detail::Requirement *Req) const {
+  auto Iter = std::find_if(std::begin(MMemObjs), std::end(MMemObjs),
+                           [=](ReqToMem Elem) { return (Elem.first == Req); });
+
+  if (Iter == std::end(MMemObjs))
+    throw("Invalid memory object used inside interop");
+
+  return detail::pi::cast<cl_mem>(Iter->second);
+}
+
+} // namespace sycl
+} // __SYCL_INLINE_NAMESPACE(cl)
diff --git a/sycl/test/host-interop-task/interop-task-dependency.cpp b/sycl/test/host-interop-task/interop-task-dependency.cpp
new file mode 100644
index 0000000000000..15153c27967c6
--- /dev/null
+++ b/sycl/test/host-interop-task/interop-task-dependency.cpp
@@ -0,0 +1,202 @@
+// RUN: %clangxx -fsycl %s -o %t.out %threads_lib
+// RUN: %CPU_RUN_PLACEHOLDER SYCL_PI_TRACE=-1 %t.out 2>&1 %CPU_CHECK_PLACEHOLDER
+// RUN: %GPU_RUN_PLACEHOLDER SYCL_PI_TRACE=-1 %t.out 2>&1 %GPU_CHECK_PLACEHOLDER
+// RUN: %ACC_RUN_PLACEHOLDER SYCL_PI_TRACE=-1 %t.out 2>&1 %ACC_CHECK_PLACEHOLDER
+
+#include <atomic>
+#include <condition_variable>
+#include <future>
+#include <mutex>
+#include <thread>
+
+#include <CL/sycl.hpp>
+
+namespace S = cl::sycl;
+
+struct Context {
+  std::atomic_bool Flag;
+  S::queue &Queue;
+  S::buffer<int, 1> Buf1;
+  S::buffer<int, 1> Buf2;
+  S::buffer<int, 1> Buf3;
+  std::mutex Mutex;
+  std::condition_variable CV;
+};
+
+void Thread1Fn(Context *Ctx) {
+  // 0. initialize resulting buffers with a priori wrong results
+  {
+    S::accessor<int, 1, S::access::mode::write,
+                S::access::target::host_buffer>
+        Acc(Ctx->Buf2);
+
+    for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx)
+      Acc[Idx] = -1;
+  }
+
+  {
+    S::accessor<int, 1, S::access::mode::write,
+                S::access::target::host_buffer>
+        Acc(Ctx->Buf2);
+
+    for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx)
+      Acc[Idx] = -2;
+  }
+
+  {
+    S::accessor<int, 1, S::access::mode::write,
+                S::access::target::host_buffer>
+        Acc(Ctx->Buf3);
+
+    for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx)
+      Acc[Idx] = -3;
+  }
+
+  // 1. submit task writing to buffer 1
+  Ctx->Queue.submit([&](S::handler &CGH) {
+    S::accessor<int, 1, S::access::mode::write,
+                S::access::target::global_buffer>
+        GeneratorAcc(Ctx->Buf1, CGH);
+
+    auto GeneratorKernel = [GeneratorAcc]() {
+      for (size_t Idx = 0; Idx < GeneratorAcc.get_count(); ++Idx)
+        GeneratorAcc[Idx] = Idx;
+    };
+
+    CGH.single_task<class GeneratorTask>(GeneratorKernel);
+  });
+
+  // 2. submit host task writing from buf 1 to buf 2
+  auto HostTaskEvent = Ctx->Queue.submit([&](S::handler &CGH) {
+    S::accessor<int, 1, S::access::mode::read,
+                S::access::target::host_buffer>
+        CopierSrcAcc(Ctx->Buf1, CGH);
+    S::accessor<int, 1, S::access::mode::write,
+                S::access::target::host_buffer>
+        CopierDstAcc(Ctx->Buf2, CGH);
+
+    auto CopierHostTask = [CopierSrcAcc, CopierDstAcc,
+                           &Ctx](S::interop_handle IH) {
+      // TODO write through interop handle objects
+      //(void)IH.get_native_mem(CopierSrcAcc);
+      //(void)IH.get_native_mem(CopierDstAcc);
+      (void)IH.get_native_queue();
+      (void)IH.get_native_device();
+      (void)IH.get_native_context();
+      for (size_t Idx = 0; Idx < CopierDstAcc.get_count(); ++Idx)
+        CopierDstAcc[Idx] = CopierSrcAcc[Idx];
+
+      bool Expected = false;
+      bool Desired = true;
+      assert(Ctx->Flag.compare_exchange_strong(Expected, Desired));
+
+      {
+        std::lock_guard<std::mutex> Lock(Ctx->Mutex);
+        Ctx->CV.notify_all();
+      }
+    };
+
+    CGH.codeplay_host_task(CopierHostTask);
+  });
+
+  // 3. submit simple task to move data between two buffers
+  Ctx->Queue.submit([&](S::handler &CGH) {
+    S::accessor<int, 1, S::access::mode::read,
+                S::access::target::global_buffer>
+        SrcAcc(Ctx->Buf2, CGH);
+    S::accessor<int, 1, S::access::mode::write,
+                S::access::target::global_buffer>
+        DstAcc(Ctx->Buf3, CGH);
+
+    CGH.depends_on(HostTaskEvent);
+
+    auto CopierKernel = [SrcAcc, DstAcc]() {
+      for (size_t Idx = 0; Idx < DstAcc.get_count(); ++Idx)
+        DstAcc[Idx] = SrcAcc[Idx];
+    };
+
+    CGH.single_task<class CopierTask>(CopierKernel);
+  });
+
+  // 4. check data in buffer #3
+  {
+    S::accessor<int, 1, S::access::mode::read,
+                S::access::target::host_buffer>
+        Acc(Ctx->Buf3);
+
+    bool Failure = false;
+
+    for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx) {
+      fprintf(stderr, "Third buffer [%3zu] = %i\n", Idx, Acc[Idx]);
+
+      Failure |= (Acc[Idx] != Idx);
+    }
+
+    assert(!Failure && "Invalid data in third buffer");
+  }
+}
+
+void Thread2Fn(Context *Ctx) {
+  std::unique_lock<std::mutex> Lock(Ctx->Mutex);
+
+  // T2.1. Wait until flag F is set to true.
+  Ctx->CV.wait(Lock, [&Ctx] { return Ctx->Flag.load(); });
+
+  assert(Ctx->Flag.load());
+}
+
+void test() {
+  auto EH = [](S::exception_list EL) {
+    for (const std::exception_ptr &E : EL) {
+      throw E;
+    }
+  };
+
+  S::queue Queue(EH);
+
+  Context Ctx{{false}, Queue, {10}, {10}, {10}, {}, {}};
+
+  // 0. setup: thread 1 (T1) executes work; thread 2 (T2) waits; flag F = false
+  auto A1 = std::async(std::launch::async, Thread1Fn, &Ctx);
+  auto A2 = std::async(std::launch::async, Thread2Fn, &Ctx);
+
+  A1.get();
+  A2.get();
+
+  assert(Ctx.Flag.load());
+
+  // 3. check via host accessor that buf 2 contains valid data
+  {
+    S::accessor<int, 1, S::access::mode::read,
+                S::access::target::host_buffer>
+        ResultAcc(Ctx.Buf2);
+
+    bool Failure = false;
+
+    for (size_t Idx = 0; Idx < ResultAcc.get_count(); ++Idx) {
+      fprintf(stderr, "Second buffer [%3zu] = %i\n", Idx, ResultAcc[Idx]);
+
+      Failure |= (ResultAcc[Idx] != Idx);
+    }
+
+    assert(!Failure && "Invalid data in result buffer");
+  }
+}
+
+int main() {
+  test();
+
+  return 0;
+}
+
+// launch of GeneratorTask kernel
+// CHECK:---> piKernelCreate(
+// CHECK: GeneratorTask
+// CHECK:---> piEnqueueKernelLaunch(
+// prepare for host task
+// CHECK:---> piEnqueueMemBufferMap(
+// launch of CopierTask kernel
+// CHECK:---> piKernelCreate(
+// CHECK: CopierTask
+// CHECK:---> piEnqueueKernelLaunch(
+// TODO need to check for piEventsWait as "wait on dependencies of host task".
+// At the same time this piEventsWait may occur anywhere after
+// piEnqueueMemBufferMap ("prepare for host task").
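
The snippet below is a minimal usage sketch of the API this patch adds; it is not part of the patch itself. It assumes an OpenCL backend and uses the standard OpenCL call clEnqueueWriteBuffer on the native objects exposed by interop_handle. All names (Queue, Buf, Size, HostData) are illustrative only.

#include <CL/cl.h>
#include <CL/sycl.hpp>

#include <vector>

namespace S = cl::sycl;

int main() {
  constexpr size_t Size = 16;
  std::vector<int> HostData(Size, 42);

  S::queue Queue;
  S::buffer<int, 1> Buf{S::range<1>{Size}};

  Queue.submit([&](S::handler &CGH) {
    auto Acc = Buf.get_access<S::access::mode::write>(CGH);

    // The host task runs on the runtime's host thread pool once its
    // dependencies are satisfied. The interop_handle exposes the cl_mem
    // backing the accessor and the native queue of the SYCL queue this
    // command group was submitted to.
    CGH.codeplay_host_task([=](S::interop_handle IH) {
      cl_mem Mem = IH.get_native_mem(Acc);
      cl_command_queue NativeQueue = IH.get_native_queue();

      // Fill the buffer with a plain OpenCL call; a blocking write keeps the
      // sketch simple.
      clEnqueueWriteBuffer(NativeQueue, Mem, /*blocking_write=*/CL_TRUE,
                           /*offset=*/0, Size * sizeof(int), HostData.data(),
                           0, nullptr, nullptr);
    });
  });

  Queue.wait_and_throw();

  return 0;
}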