diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp index 116f3a3f98d14..9b0f79d24f731 100644 --- a/sycl/source/detail/scheduler/graph_builder.cpp +++ b/sycl/source/detail/scheduler/graph_builder.cpp @@ -1027,18 +1027,18 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord( AllocaCmd->MUsers.clear(); } - // Linked alloca's share dependencies. Unchain from deps linked alloca's. - // Any cmd of the alloca - linked_alloca may be used later on. + // Make sure the Linked Allocas are marked visited by the previous walk. + // Remove allocation commands from the users of their dependencies. for (AllocaCommandBase *AllocaCmd : AllocaCommands) { AllocaCommandBase *LinkedCmd = AllocaCmd->MLinkedAllocaCmd; if (LinkedCmd) { assert(LinkedCmd->MMarks.MVisited); - - for (DepDesc &Dep : AllocaCmd->MDeps) - if (Dep.MDepCommand) - Dep.MDepCommand->MUsers.erase(AllocaCmd); } + + for (DepDesc &Dep : AllocaCmd->MDeps) + if (Dep.MDepCommand) + Dep.MDepCommand->MUsers.erase(AllocaCmd); } // Traverse the graph using BFS diff --git a/sycl/unittests/scheduler/MemObjCommandCleanup.cpp b/sycl/unittests/scheduler/MemObjCommandCleanup.cpp index 17429831f4257..8eac4247b4a18 100644 --- a/sycl/unittests/scheduler/MemObjCommandCleanup.cpp +++ b/sycl/unittests/scheduler/MemObjCommandCleanup.cpp @@ -11,7 +11,7 @@ using namespace cl::sycl; -TEST_F(SchedulerTest, MemObjCommandCleanup) { +TEST_F(SchedulerTest, MemObjCommandCleanupAllocaUsers) { MockScheduler MS; buffer BufA(range<1>(1)); buffer BufB(range<1>(1)); @@ -51,3 +51,31 @@ TEST_F(SchedulerTest, MemObjCommandCleanup) { EXPECT_EQ(MockDirectUser->MDeps[0].MDepCommand, MockAllocaB.get()); EXPECT_TRUE(IndirectUserDeleted); } + +TEST_F(SchedulerTest, MemObjCommandCleanupAllocaDeps) { + MockScheduler MS; + buffer Buf(range<1>(1)); + detail::Requirement MockReq = getMockRequirement(Buf); + std::vector AuxCmds; + detail::MemObjRecord *MemObjRec = MS.getOrInsertMemObjRecord( + detail::getSyclObjImpl(MQueue), &MockReq, AuxCmds); + + // Create a fake alloca. + detail::AllocaCommand *MockAllocaCmd = + new detail::AllocaCommand(detail::getSyclObjImpl(MQueue), MockReq); + MemObjRec->MAllocaCommands.push_back(MockAllocaCmd); + + // Add another mock command and add MockAllocaCmd as its user. + MockCommand DepCmd(detail::getSyclObjImpl(MQueue), MockReq); + addEdge(MockAllocaCmd, &DepCmd, nullptr); + + // Check that DepCmd.MUsers size reflect the dependency properly. + ASSERT_EQ(DepCmd.MUsers.size(), 1U); + ASSERT_EQ(DepCmd.MUsers.count(MockAllocaCmd), 1U); + + MS.cleanupCommandsForRecord(MemObjRec); + MS.removeRecordForMemObj(detail::getSyclObjImpl(Buf).get()); + + // Check that DepCmd has its MUsers field cleared. + ASSERT_EQ(DepCmd.MUsers.size(), 0U); +}