diff --git a/sycl/source/detail/scheduler/commands.hpp b/sycl/source/detail/scheduler/commands.hpp index 1b551bce0c470..7406f82331bf4 100644 --- a/sycl/source/detail/scheduler/commands.hpp +++ b/sycl/source/detail/scheduler/commands.hpp @@ -122,7 +122,7 @@ class Command { /// \param Blocking if this argument is true, function will wait for the /// command to be unblocked before calling enqueueImp. /// \return true if the command is enqueued. - bool enqueue(EnqueueResultT &EnqueueResult, BlockingT Blocking); + virtual bool enqueue(EnqueueResultT &EnqueueResult, BlockingT Blocking); bool isFinished(); @@ -130,6 +130,10 @@ class Command { return MEnqueueStatus == EnqueueResultT::SyclEnqueueSuccess; } + bool isEnqueueBlocked() const { + return MEnqueueStatus == EnqueueResultT::SyclEnqueueBlocked; + } + std::shared_ptr getQueue() const { return MQueue; } std::shared_ptr getEvent() const { return MEvent; } diff --git a/sycl/source/detail/scheduler/graph_processor.cpp b/sycl/source/detail/scheduler/graph_processor.cpp index 7b9f10efef295..bc7f813069f39 100644 --- a/sycl/source/detail/scheduler/graph_processor.cpp +++ b/sycl/source/detail/scheduler/graph_processor.cpp @@ -58,33 +58,19 @@ bool Scheduler::GraphProcessor::enqueueCommand(Command *Cmd, if (!Cmd || Cmd->isSuccessfullyEnqueued()) return true; - // Indicates whether dependency cannot be enqueued - bool BlockedByDep = false; + // Exit early if the command is blocked and the enqueue type is non-blocking + if (Cmd->isEnqueueBlocked() && !Blocking) { + EnqueueResult = EnqueueResultT(EnqueueResultT::SyclEnqueueBlocked, Cmd); + return false; + } + // Recursively enqueue all the dependencies first and + // exit immediately if any of the commands cannot be enqueued. for (DepDesc &Dep : Cmd->MDeps) { - const bool Enqueued = - enqueueCommand(Dep.MDepCommand, EnqueueResult, Blocking); - if (!Enqueued) - switch (EnqueueResult.MResult) { - case EnqueueResultT::SyclEnqueueFailed: - default: - // Exit immediately if a command fails to avoid enqueueing commands - // result of which will be discarded. - return false; - case EnqueueResultT::SyclEnqueueBlocked: - // If some dependency is blocked from enqueueing remember that, but - // try to enqueue other dependencies(that can be ready for - // enqueueing). - BlockedByDep = true; - break; - } + if (!enqueueCommand(Dep.MDepCommand, EnqueueResult, Blocking)) + return false; } - // Exit if some command is blocked from enqueueing, the EnqueueResult is set - // by the latest dependency which was blocked. - if (BlockedByDep) - return false; - return Cmd->enqueue(EnqueueResult, Blocking); } diff --git a/sycl/unittests/scheduler/BlockedCommands.cpp b/sycl/unittests/scheduler/BlockedCommands.cpp index 6aa0ae05b92d7..5ed22606d495e 100644 --- a/sycl/unittests/scheduler/BlockedCommands.cpp +++ b/sycl/unittests/scheduler/BlockedCommands.cpp @@ -10,6 +10,7 @@ #include "SchedulerTestUtils.hpp" using namespace cl::sycl; +using namespace testing; TEST_F(SchedulerTest, BlockedCommands) { MockCommand MockCmd(detail::getSyclObjImpl(MQueue)); @@ -45,3 +46,87 @@ TEST_F(SchedulerTest, BlockedCommands) { Res.MResult == detail::EnqueueResultT::SyclEnqueueSuccess) << "The command is expected to be successfully enqueued.\n"; } + +TEST_F(SchedulerTest, DontEnqueueDepsIfOneOfThemIsBlocked) { + MockCommand A(detail::getSyclObjImpl(MQueue)); + A.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady; + A.MIsBlockable = true; + A.MRetVal = CL_SUCCESS; + + MockCommand B(detail::getSyclObjImpl(MQueue)); + B.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady; + B.MIsBlockable = true; + B.MRetVal = CL_SUCCESS; + + MockCommand C(detail::getSyclObjImpl(MQueue)); + C.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueBlocked; + C.MIsBlockable = true; + + MockCommand D(detail::getSyclObjImpl(MQueue)); + D.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady; + D.MIsBlockable = true; + D.MRetVal = CL_SUCCESS; + + addEdge(&A, &B, nullptr); + addEdge(&A, &C, nullptr); + addEdge(&A, &D, nullptr); + + // We have such a graph: + // + // A + // / | \ + // B C D + // + // If C is blocked, we should not try to enqueue D. + + EXPECT_CALL(A, enqueue(_, _)).Times(0); + EXPECT_CALL(B, enqueue(_, _)).Times(1); + EXPECT_CALL(C, enqueue(_, _)).Times(0); + EXPECT_CALL(D, enqueue(_, _)).Times(0); + + detail::EnqueueResultT Res; + bool Enqueued = MockScheduler::enqueueCommand(&A, Res, detail::NON_BLOCKING); + ASSERT_FALSE(Enqueued) << "Blocked command should not be enqueued\n"; + ASSERT_EQ(detail::EnqueueResultT::SyclEnqueueBlocked, Res.MResult) + << "Result of enqueueing blocked command should be BLOCKED.\n"; + ASSERT_EQ(&C, Res.MCmd) << "Expected different failed command.\n"; +} + +TEST_F(SchedulerTest, EnqueueBlockedCommandEarlyExit) { + MockCommand A(detail::getSyclObjImpl(MQueue)); + A.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueBlocked; + A.MIsBlockable = true; + + MockCommand B(detail::getSyclObjImpl(MQueue)); + B.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady; + B.MRetVal = CL_OUT_OF_RESOURCES; + + addEdge(&A, &B, nullptr); + + // We have such a graph: + // + // A -> B + // + // If A is blocked, we should not try to enqueue B. + + EXPECT_CALL(A, enqueue(_, _)).Times(0); + EXPECT_CALL(B, enqueue(_, _)).Times(0); + + detail::EnqueueResultT Res; + bool Enqueued = MockScheduler::enqueueCommand(&A, Res, detail::NON_BLOCKING); + ASSERT_FALSE(Enqueued) << "Blocked command should not be enqueued\n"; + ASSERT_EQ(detail::EnqueueResultT::SyclEnqueueBlocked, Res.MResult) + << "Result of enqueueing blocked command should be BLOCKED.\n"; + ASSERT_EQ(&A, Res.MCmd) << "Expected different failed command.\n"; + + // But if the enqueue type is blocking we should not exit early. + + EXPECT_CALL(A, enqueue(_, _)).Times(0); + EXPECT_CALL(B, enqueue(_, _)).Times(1); + + Enqueued = MockScheduler::enqueueCommand(&A, Res, detail::BLOCKING); + ASSERT_FALSE(Enqueued) << "Blocked command should not be enqueued\n"; + ASSERT_EQ(detail::EnqueueResultT::SyclEnqueueFailed, Res.MResult) + << "Result of enqueueing blocked command should be BLOCKED.\n"; + ASSERT_EQ(&B, Res.MCmd) << "Expected different failed command.\n"; +} diff --git a/sycl/unittests/scheduler/LeafLimit.cpp b/sycl/unittests/scheduler/LeafLimit.cpp index c1e22f37bf3f6..d840099a0048d 100644 --- a/sycl/unittests/scheduler/LeafLimit.cpp +++ b/sycl/unittests/scheduler/LeafLimit.cpp @@ -21,42 +21,45 @@ using namespace cl::sycl; // overflowed. TEST_F(SchedulerTest, LeafLimit) { MockScheduler MS; + std::vector> LeavesToAdd; + std::unique_ptr MockDepCmd; buffer Buf(range<1>(1)); detail::Requirement MockReq = getMockRequirement(Buf); - MockCommand *MockDepCmd = - new MockCommand(detail::getSyclObjImpl(MQueue), MockReq); + + MockDepCmd = + std::make_unique(detail::getSyclObjImpl(MQueue), MockReq); detail::MemObjRecord *Rec = MS.getOrInsertMemObjRecord(detail::getSyclObjImpl(MQueue), &MockReq); // Create commands that will be added as leaves exceeding the limit by 1 - std::vector LeavesToAdd; for (std::size_t i = 0; i < Rec->MWriteLeaves.genericCommandsCapacity() + 1; ++i) { LeavesToAdd.push_back( - new MockCommand(detail::getSyclObjImpl(MQueue), MockReq)); + std::make_unique(detail::getSyclObjImpl(MQueue), MockReq)); } // Create edges: all soon-to-be leaves are direct users of MockDep - for (auto Leaf : LeavesToAdd) { - MockDepCmd->addUser(Leaf); - Leaf->addDep(detail::DepDesc{MockDepCmd, Leaf->getRequirement(), nullptr}); + for (auto &Leaf : LeavesToAdd) { + MockDepCmd->addUser(Leaf.get()); + Leaf->addDep( + detail::DepDesc{MockDepCmd.get(), Leaf->getRequirement(), nullptr}); } // Add edges as leaves and exceed the leaf limit - for (auto LeafPtr : LeavesToAdd) { - MS.addNodeToLeaves(Rec, LeafPtr); + for (auto &LeafPtr : LeavesToAdd) { + MS.addNodeToLeaves(Rec, LeafPtr.get()); } // Check that the oldest leaf has been removed from the leaf list // and added as a dependency of the newest one instead const detail::CircularBuffer &Leaves = Rec->MWriteLeaves.getGenericCommands(); - ASSERT_TRUE(std::find(Leaves.begin(), Leaves.end(), LeavesToAdd.front()) == - Leaves.end()); + ASSERT_TRUE(std::find(Leaves.begin(), Leaves.end(), + LeavesToAdd.front().get()) == Leaves.end()); for (std::size_t i = 1; i < LeavesToAdd.size(); ++i) { - assert(std::find(Leaves.begin(), Leaves.end(), LeavesToAdd[i]) != + assert(std::find(Leaves.begin(), Leaves.end(), LeavesToAdd[i].get()) != Leaves.end()); } - MockCommand *OldestLeaf = LeavesToAdd.front(); - MockCommand *NewestLeaf = LeavesToAdd.back(); + MockCommand *OldestLeaf = LeavesToAdd.front().get(); + MockCommand *NewestLeaf = LeavesToAdd.back().get(); ASSERT_EQ(OldestLeaf->MUsers.size(), 1U); EXPECT_GT(OldestLeaf->MUsers.count(NewestLeaf), 0U); ASSERT_EQ(NewestLeaf->MDeps.size(), 2U); diff --git a/sycl/unittests/scheduler/SchedulerTestUtils.hpp b/sycl/unittests/scheduler/SchedulerTestUtils.hpp index 7c91cebec9ef5..e5348f401236c 100644 --- a/sycl/unittests/scheduler/SchedulerTestUtils.hpp +++ b/sycl/unittests/scheduler/SchedulerTestUtils.hpp @@ -13,6 +13,8 @@ #include #include +#include + // This header contains a few common classes/methods used in // execution graph testing. @@ -24,12 +26,22 @@ class MockCommand : public cl::sycl::detail::Command { cl::sycl::detail::Requirement Req, cl::sycl::detail::Command::CommandType Type = cl::sycl::detail::Command::RUN_CG) - : Command{Type, Queue}, MRequirement{std::move(Req)} {} + : Command{Type, Queue}, MRequirement{std::move(Req)} { + using namespace testing; + ON_CALL(*this, enqueue(_, _)) + .WillByDefault(Invoke(this, &MockCommand::enqueueOrigin)); + EXPECT_CALL(*this, enqueue(_, _)).Times(AnyNumber()); + } MockCommand(cl::sycl::detail::QueueImplPtr Queue, cl::sycl::detail::Command::CommandType Type = cl::sycl::detail::Command::RUN_CG) - : Command{Type, Queue}, MRequirement{std::move(getMockRequirement())} {} + : Command{Type, Queue}, MRequirement{std::move(getMockRequirement())} { + using namespace testing; + ON_CALL(*this, enqueue(_, _)) + .WillByDefault(Invoke(this, &MockCommand::enqueueOrigin)); + EXPECT_CALL(*this, enqueue(_, _)).Times(AnyNumber()); + } void printDot(std::ostream &) const override {} void emitInstrumentationData() override {} @@ -40,6 +52,13 @@ class MockCommand : public cl::sycl::detail::Command { cl_int enqueueImp() override { return MRetVal; } + MOCK_METHOD2(enqueue, bool(cl::sycl::detail::EnqueueResultT &, + cl::sycl::detail::BlockingT)); + bool enqueueOrigin(cl::sycl::detail::EnqueueResultT &EnqueueResult, + cl::sycl::detail::BlockingT Blocking) { + return cl::sycl::detail::Command::enqueue(EnqueueResult, Blocking); + } + cl_int MRetVal = CL_SUCCESS; void waitForEventsCall(