diff --git a/sycl/source/detail/event_impl.cpp b/sycl/source/detail/event_impl.cpp index e22a7385e22b9..fe1f63899798c 100644 --- a/sycl/source/detail/event_impl.cpp +++ b/sycl/source/detail/event_impl.cpp @@ -38,8 +38,10 @@ void event_impl::initContextIfNeeded() { return; const device SyclDevice; - this->setContextImpl( - detail::queue_impl::getDefaultOrNew(*detail::getSyclObjImpl(SyclDevice))); + MIsHostEvent = false; + MContext = + detail::queue_impl::getDefaultOrNew(*detail::getSyclObjImpl(SyclDevice)); + assert(MContext); } event_impl::~event_impl() { @@ -140,9 +142,10 @@ void event_impl::setHandle(const ur_event_handle_t &UREvent) { MEvent.store(UREvent); } -const ContextImplPtr &event_impl::getContextImpl() { +context_impl &event_impl::getContextImpl() { initContextIfNeeded(); - return MContext; + assert(MContext && "Trying to get context from a host event!"); + return *MContext; } const AdapterPtr &event_impl::getAdapter() { @@ -152,9 +155,9 @@ const AdapterPtr &event_impl::getAdapter() { void event_impl::setStateIncomplete() { MState = HES_NotComplete; } -void event_impl::setContextImpl(const ContextImplPtr &Context) { - MIsHostEvent = Context == nullptr; - MContext = Context; +void event_impl::setContextImpl(context_impl &Context) { + MIsHostEvent = false; + MContext = Context.shared_from_this(); } event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext, @@ -178,7 +181,7 @@ event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext, event_impl::event_impl(queue_impl &Queue, private_tag) : MQueue{Queue.weak_from_this()}, MIsProfilingEnabled{Queue.MIsProfilingEnabled} { - this->setContextImpl(Queue.getContextImplPtr()); + this->setContextImpl(Queue.getContextImpl()); MState.store(HES_Complete); } diff --git a/sycl/source/detail/event_impl.hpp b/sycl/source/detail/event_impl.hpp index 245e218ff4112..79d66f336a4f0 100644 --- a/sycl/source/detail/event_impl.hpp +++ b/sycl/source/detail/event_impl.hpp @@ -173,9 +173,7 @@ class event_impl : public std::enable_shared_from_this { void setHandle(const ur_event_handle_t &UREvent); /// Returns context that is associated with this event. - /// - /// \return a shared pointer to a valid context_impl. - const ContextImplPtr &getContextImpl(); + context_impl &getContextImpl(); /// \return the Adapter associated with the context of this event. /// Should be called when this is not a Host Event. @@ -183,11 +181,9 @@ class event_impl : public std::enable_shared_from_this { /// Associate event with the context. /// - /// Provided UrContext inside ContextImplPtr must be associated + /// Provided UrContext inside Context must be associated /// with the UrEvent object stored in this class - /// - /// @param Context is a shared pointer to an instance of valid context_impl. - void setContextImpl(const ContextImplPtr &Context); + void setContextImpl(context_impl &Context); /// Clear the event state void setStateIncomplete(); diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index a9a46a0199ec0..16650328fdb82 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -1037,7 +1037,7 @@ exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue, auto CreateNewEvent([&]() { auto NewEvent = sycl::detail::event_impl::create_device_event(Queue); - NewEvent->setContextImpl(Queue.getContextImplPtr()); + NewEvent->setContextImpl(Queue.getContextImpl()); NewEvent->setStateIncomplete(); return NewEvent; }); diff --git a/sycl/source/detail/queue_impl.cpp b/sycl/source/detail/queue_impl.cpp index 8b610d46e80a8..f9a9f41450ec4 100644 --- a/sycl/source/detail/queue_impl.cpp +++ b/sycl/source/detail/queue_impl.cpp @@ -121,7 +121,7 @@ queue_impl::get_backend_info() const { static event prepareSYCLEventAssociatedWithQueue( const std::shared_ptr &QueueImpl) { auto EventImpl = detail::event_impl::create_device_event(*QueueImpl); - EventImpl->setContextImpl(detail::getSyclObjImpl(QueueImpl->get_context())); + EventImpl->setContextImpl(QueueImpl->getContextImpl()); EventImpl->setStateIncomplete(); return detail::createSyclObjFromImpl(EventImpl); } diff --git a/sycl/source/detail/reduction.cpp b/sycl/source/detail/reduction.cpp index 51bee331945e3..84a8722c96e76 100644 --- a/sycl/source/detail/reduction.cpp +++ b/sycl/source/detail/reduction.cpp @@ -208,7 +208,7 @@ __SYCL_EXPORT void addCounterInit(handler &CGH, std::shared_ptr &Queue, std::shared_ptr &Counter) { auto EventImpl = detail::event_impl::create_device_event(*Queue); - EventImpl->setContextImpl(detail::getSyclObjImpl(Queue->get_context())); + EventImpl->setContextImpl(Queue->getContextImpl()); EventImpl->setStateIncomplete(); ur_event_handle_t UREvent = nullptr; MemoryManager::fill_usm(Counter.get(), *Queue, sizeof(int), {0}, {}, diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index b0ec9f015eebe..b54d6b91aefc1 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -533,10 +533,8 @@ void Command::waitForEvents(queue_impl *Queue, RequiredEventsPerContext; for (const EventImplPtr &Event : EventImpls) { - ContextImplPtr Context = Event->getContextImpl(); - assert(Context.get() && - "Only non-host events are expected to be waited for here"); - RequiredEventsPerContext[Context.get()].push_back(Event); + context_impl &Context = Event->getContextImpl(); + RequiredEventsPerContext[&Context].push_back(Event); } for (auto &CtxWithEvents : RequiredEventsPerContext) { @@ -576,7 +574,7 @@ Command::Command( MEvent->setSubmittedQueue(MWorkerQueue); MEvent->setCommand(this); if (MQueue) - MEvent->setContextImpl(MQueue->getContextImplPtr()); + MEvent->setContextImpl(MQueue->getContextImpl()); MEvent->setStateIncomplete(); MEnqueueStatus = EnqueueResultT::SyclEnqueueReady; @@ -781,9 +779,9 @@ Command *Command::processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep, Command *ConnectionCmd = nullptr; - ContextImplPtr DepEventContext = DepEvent->getContextImpl(); + context_impl &DepEventContext = DepEvent->getContextImpl(); // If contexts don't match we'll connect them using host task - if (DepEventContext != WorkerContext && WorkerContext) { + if (&DepEventContext != WorkerContext.get() && WorkerContext) { Scheduler::GraphBuilder &GB = Scheduler::getInstance().MGraphBuilder; ConnectionCmd = GB.connectDepEvent(this, DepEvent, Dep, ToCleanUp); } else @@ -1298,7 +1296,7 @@ ur_result_t ReleaseCommand::enqueueImp() { std::shared_ptr UnmapEventImpl = event_impl::create_device_event(*Queue); - UnmapEventImpl->setContextImpl(Queue->getContextImplPtr()); + UnmapEventImpl->setContextImpl(Queue->getContextImpl()); UnmapEventImpl->setStateIncomplete(); ur_event_handle_t UREvent = nullptr; @@ -1516,7 +1514,7 @@ MemCpyCommand::MemCpyCommand(Requirement SrcReq, MSrcReq(std::move(SrcReq)), MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(DstReq)), MDstAllocaCmd(DstAllocaCmd) { if (MSrcQueue) { - MEvent->setContextImpl(MSrcQueue->getContextImplPtr()); + MEvent->setContextImpl(MSrcQueue->getContextImpl()); } MWorkerQueue = !MQueue ? MSrcQueue : MQueue; @@ -1689,7 +1687,7 @@ MemCpyCommandHost::MemCpyCommandHost(Requirement SrcReq, MSrcReq(std::move(SrcReq)), MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(DstReq)), MDstPtr(DstPtr) { if (MSrcQueue) { - MEvent->setContextImpl(MSrcQueue->getContextImplPtr()); + MEvent->setContextImpl(MSrcQueue->getContextImpl()); } MWorkerQueue = !MQueue ? MSrcQueue : MQueue; diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp index 1d7df4e86f6a1..8f7bfb29fb109 100644 --- a/sycl/source/detail/scheduler/graph_builder.cpp +++ b/sycl/source/detail/scheduler/graph_builder.cpp @@ -1221,7 +1221,7 @@ void Scheduler::GraphBuilder::removeRecordForMemObj(SYCLMemObjI *MemObject) { Command *Scheduler::GraphBuilder::connectDepEvent( Command *const Cmd, const EventImplPtr &DepEvent, const DepDesc &Dep, std::vector &ToCleanUp) { - assert(Cmd->getWorkerContext() != DepEvent->getContextImpl()); + assert(Cmd->getWorkerContext().get() != &DepEvent->getContextImpl()); // construct Host Task type command manually and make it depend on DepEvent ExecCGCommand *ConnectCmd = nullptr; diff --git a/sycl/source/detail/scheduler/scheduler.cpp b/sycl/source/detail/scheduler/scheduler.cpp index b84500dc65a96..66d3974674282 100644 --- a/sycl/source/detail/scheduler/scheduler.cpp +++ b/sycl/source/detail/scheduler/scheduler.cpp @@ -691,7 +691,7 @@ bool Scheduler::CheckEventReadiness(context_impl &Context, return SyclEventImplPtr->isCompleted(); } // Cross-context dependencies can't be passed to the backend directly. - if (SyclEventImplPtr->getContextImpl().get() != &Context) + if (&SyclEventImplPtr->getContextImpl() != &Context) return false; // A nullptr here means that the commmand does not produce a UR event or it diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index 300a09a17b128..01b8ed88b94c6 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -610,7 +610,7 @@ event handler::finalize() { detail::queue_impl &Queue = impl->get_queue(); LastEventImpl->setQueue(Queue); LastEventImpl->setWorkerQueue(Queue.weak_from_this()); - LastEventImpl->setContextImpl(impl->get_context().shared_from_this()); + LastEventImpl->setContextImpl(impl->get_context()); LastEventImpl->setStateIncomplete(); LastEventImpl->setSubmissionTime(); diff --git a/sycl/unittests/scheduler/QueueFlushing.cpp b/sycl/unittests/scheduler/QueueFlushing.cpp index ea03cc8d61474..6f5e1fb4364da 100644 --- a/sycl/unittests/scheduler/QueueFlushing.cpp +++ b/sycl/unittests/scheduler/QueueFlushing.cpp @@ -150,7 +150,7 @@ TEST_F(SchedulerTest, QueueFlushing) { access::mode::read_write}; std::shared_ptr DepEvent = detail::event_impl::create_device_event(QueueImplB); - DepEvent->setContextImpl(QueueImplB.getContextImplPtr()); + DepEvent->setContextImpl(QueueImplB.getContextImpl()); ur_event_handle_t UREvent = mock::createDummyHandle(); @@ -170,7 +170,7 @@ TEST_F(SchedulerTest, QueueFlushing) { queue TempQueue{Ctx, default_selector_v}; detail::queue_impl &TempQueueImpl = *detail::getSyclObjImpl(TempQueue); DepEvent = detail::event_impl::create_device_event(TempQueueImpl); - DepEvent->setContextImpl(TempQueueImpl.getContextImplPtr()); + DepEvent->setContextImpl(TempQueueImpl.getContextImpl()); ur_event_handle_t UREvent = mock::createDummyHandle();