diff --git a/sycl/source/detail/scheduler/scheduler.cpp b/sycl/source/detail/scheduler/scheduler.cpp index d896dfc6be08b..c4f2ab6534551 100644 --- a/sycl/source/detail/scheduler/scheduler.cpp +++ b/sycl/source/detail/scheduler/scheduler.cpp @@ -156,18 +156,33 @@ void Scheduler::cleanupFinishedCommands(EventImplPtr FinishedEvent) { } void Scheduler::removeMemoryObject(detail::SYCLMemObjI *MemObj) { + MemObjRecord *Record = nullptr; std::unique_lock Lock(MGraphLock, std::defer_lock); - lockSharedTimedMutex(Lock); - MemObjRecord *Record = MGraphBuilder.getMemObjRecord(MemObj); - if (!Record) - // No operations were performed on the mem object - return; + { + lockSharedTimedMutex(Lock); + + Record = MGraphBuilder.getMemObjRecord(MemObj); + if (!Record) + // No operations were performed on the mem object + return; - waitForRecordToFinish(Record); - MGraphBuilder.decrementLeafCountersForRecord(Record); - MGraphBuilder.cleanupCommandsForRecord(Record); - MGraphBuilder.removeRecordForMemObj(MemObj); + Lock.unlock(); + } + + { + // This only needs a shared mutex as it only involves enqueueing and + // awaiting for events + std::shared_lock Lock(MGraphLock); + waitForRecordToFinish(Record); + } + + { + lockSharedTimedMutex(Lock); + MGraphBuilder.decrementLeafCountersForRecord(Record); + MGraphBuilder.cleanupCommandsForRecord(Record); + MGraphBuilder.removeRecordForMemObj(MemObj); + } } EventImplPtr Scheduler::addHostAccessor(Requirement *Req) { diff --git a/sycl/test/host-interop-task/host-task-failure.cpp b/sycl/test/host-interop-task/host-task-failure.cpp new file mode 100644 index 0000000000000..423082b53198d --- /dev/null +++ b/sycl/test/host-interop-task/host-task-failure.cpp @@ -0,0 +1,58 @@ +// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out +// RUN: %ACC_RUN_PLACEHOLDER %t.out + +#include + +using namespace cl::sycl; +using namespace cl::sycl::access; + +static constexpr size_t BUFFER_SIZE = 1024; + +template +class Modifier; + +template +class Init; + +template +void copy(buffer &Src, buffer &Dst, queue &Q) { + Q.submit([&](handler &CGH) { + auto SrcA = Src.template get_access(CGH); + auto DstA = Dst.template get_access(CGH); + + CGH.codeplay_host_task([=]() { + for (size_t Idx = 0; Idx < SrcA.get_count(); ++Idx) + DstA[Idx] = SrcA[Idx]; + }); + }); +} + +template +void init(buffer &B1, buffer &B2, queue &Q) { + Q.submit([&](handler &CGH) { + auto Acc1 = B1.template get_access(CGH); + auto Acc2 = B2.template get_access(CGH); + + CGH.parallel_for>(BUFFER_SIZE, [=](item<1> Id) { + Acc1[Id] = -1; + Acc2[Id] = -2; + }); + }); +} + +void test() { + queue Q; + buffer Buffer1{BUFFER_SIZE}; + buffer Buffer2{BUFFER_SIZE}; + + init(Buffer1, Buffer2, Q); + + copy(Buffer1, Buffer2, Q); +} + +int main() { + test(); + return 0; +}