Skip to content

Commit bc8f0a4

Browse files
[SYCL] Exit early while trying to enqueue blocked tasks (#2347)
1 parent 0c220ca commit bc8f0a4

File tree

5 files changed

+137
-40
lines changed

5 files changed

+137
-40
lines changed

sycl/source/detail/scheduler/commands.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,18 @@ class Command {
122122
/// \param Blocking if this argument is true, function will wait for the
123123
/// command to be unblocked before calling enqueueImp.
124124
/// \return true if the command is enqueued.
125-
bool enqueue(EnqueueResultT &EnqueueResult, BlockingT Blocking);
125+
virtual bool enqueue(EnqueueResultT &EnqueueResult, BlockingT Blocking);
126126

127127
bool isFinished();
128128

129129
bool isSuccessfullyEnqueued() const {
130130
return MEnqueueStatus == EnqueueResultT::SyclEnqueueSuccess;
131131
}
132132

133+
bool isEnqueueBlocked() const {
134+
return MEnqueueStatus == EnqueueResultT::SyclEnqueueBlocked;
135+
}
136+
133137
std::shared_ptr<queue_impl> getQueue() const { return MQueue; }
134138

135139
std::shared_ptr<event_impl> getEvent() const { return MEvent; }

sycl/source/detail/scheduler/graph_processor.cpp

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -58,33 +58,19 @@ bool Scheduler::GraphProcessor::enqueueCommand(Command *Cmd,
5858
if (!Cmd || Cmd->isSuccessfullyEnqueued())
5959
return true;
6060

61-
// Indicates whether dependency cannot be enqueued
62-
bool BlockedByDep = false;
61+
// Exit early if the command is blocked and the enqueue type is non-blocking
62+
if (Cmd->isEnqueueBlocked() && !Blocking) {
63+
EnqueueResult = EnqueueResultT(EnqueueResultT::SyclEnqueueBlocked, Cmd);
64+
return false;
65+
}
6366

67+
// Recursively enqueue all the dependencies first and
68+
// exit immediately if any of the commands cannot be enqueued.
6469
for (DepDesc &Dep : Cmd->MDeps) {
65-
const bool Enqueued =
66-
enqueueCommand(Dep.MDepCommand, EnqueueResult, Blocking);
67-
if (!Enqueued)
68-
switch (EnqueueResult.MResult) {
69-
case EnqueueResultT::SyclEnqueueFailed:
70-
default:
71-
// Exit immediately if a command fails to avoid enqueueing commands
72-
// result of which will be discarded.
73-
return false;
74-
case EnqueueResultT::SyclEnqueueBlocked:
75-
// If some dependency is blocked from enqueueing remember that, but
76-
// try to enqueue other dependencies(that can be ready for
77-
// enqueueing).
78-
BlockedByDep = true;
79-
break;
80-
}
70+
if (!enqueueCommand(Dep.MDepCommand, EnqueueResult, Blocking))
71+
return false;
8172
}
8273

83-
// Exit if some command is blocked from enqueueing, the EnqueueResult is set
84-
// by the latest dependency which was blocked.
85-
if (BlockedByDep)
86-
return false;
87-
8874
return Cmd->enqueue(EnqueueResult, Blocking);
8975
}
9076

sycl/unittests/scheduler/BlockedCommands.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "SchedulerTestUtils.hpp"
1111

1212
using namespace cl::sycl;
13+
using namespace testing;
1314

1415
TEST_F(SchedulerTest, BlockedCommands) {
1516
MockCommand MockCmd(detail::getSyclObjImpl(MQueue));
@@ -45,3 +46,87 @@ TEST_F(SchedulerTest, BlockedCommands) {
4546
Res.MResult == detail::EnqueueResultT::SyclEnqueueSuccess)
4647
<< "The command is expected to be successfully enqueued.\n";
4748
}
49+
50+
TEST_F(SchedulerTest, DontEnqueueDepsIfOneOfThemIsBlocked) {
51+
MockCommand A(detail::getSyclObjImpl(MQueue));
52+
A.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady;
53+
A.MIsBlockable = true;
54+
A.MRetVal = CL_SUCCESS;
55+
56+
MockCommand B(detail::getSyclObjImpl(MQueue));
57+
B.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady;
58+
B.MIsBlockable = true;
59+
B.MRetVal = CL_SUCCESS;
60+
61+
MockCommand C(detail::getSyclObjImpl(MQueue));
62+
C.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueBlocked;
63+
C.MIsBlockable = true;
64+
65+
MockCommand D(detail::getSyclObjImpl(MQueue));
66+
D.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady;
67+
D.MIsBlockable = true;
68+
D.MRetVal = CL_SUCCESS;
69+
70+
addEdge(&A, &B, nullptr);
71+
addEdge(&A, &C, nullptr);
72+
addEdge(&A, &D, nullptr);
73+
74+
// We have such a graph:
75+
//
76+
// A
77+
// / | \
78+
// B C D
79+
//
80+
// If C is blocked, we should not try to enqueue D.
81+
82+
EXPECT_CALL(A, enqueue(_, _)).Times(0);
83+
EXPECT_CALL(B, enqueue(_, _)).Times(1);
84+
EXPECT_CALL(C, enqueue(_, _)).Times(0);
85+
EXPECT_CALL(D, enqueue(_, _)).Times(0);
86+
87+
detail::EnqueueResultT Res;
88+
bool Enqueued = MockScheduler::enqueueCommand(&A, Res, detail::NON_BLOCKING);
89+
ASSERT_FALSE(Enqueued) << "Blocked command should not be enqueued\n";
90+
ASSERT_EQ(detail::EnqueueResultT::SyclEnqueueBlocked, Res.MResult)
91+
<< "Result of enqueueing blocked command should be BLOCKED.\n";
92+
ASSERT_EQ(&C, Res.MCmd) << "Expected different failed command.\n";
93+
}
94+
95+
TEST_F(SchedulerTest, EnqueueBlockedCommandEarlyExit) {
96+
MockCommand A(detail::getSyclObjImpl(MQueue));
97+
A.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueBlocked;
98+
A.MIsBlockable = true;
99+
100+
MockCommand B(detail::getSyclObjImpl(MQueue));
101+
B.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady;
102+
B.MRetVal = CL_OUT_OF_RESOURCES;
103+
104+
addEdge(&A, &B, nullptr);
105+
106+
// We have such a graph:
107+
//
108+
// A -> B
109+
//
110+
// If A is blocked, we should not try to enqueue B.
111+
112+
EXPECT_CALL(A, enqueue(_, _)).Times(0);
113+
EXPECT_CALL(B, enqueue(_, _)).Times(0);
114+
115+
detail::EnqueueResultT Res;
116+
bool Enqueued = MockScheduler::enqueueCommand(&A, Res, detail::NON_BLOCKING);
117+
ASSERT_FALSE(Enqueued) << "Blocked command should not be enqueued\n";
118+
ASSERT_EQ(detail::EnqueueResultT::SyclEnqueueBlocked, Res.MResult)
119+
<< "Result of enqueueing blocked command should be BLOCKED.\n";
120+
ASSERT_EQ(&A, Res.MCmd) << "Expected different failed command.\n";
121+
122+
// But if the enqueue type is blocking we should not exit early.
123+
124+
EXPECT_CALL(A, enqueue(_, _)).Times(0);
125+
EXPECT_CALL(B, enqueue(_, _)).Times(1);
126+
127+
Enqueued = MockScheduler::enqueueCommand(&A, Res, detail::BLOCKING);
128+
ASSERT_FALSE(Enqueued) << "Blocked command should not be enqueued\n";
129+
ASSERT_EQ(detail::EnqueueResultT::SyclEnqueueFailed, Res.MResult)
130+
<< "Result of enqueueing blocked command should be BLOCKED.\n";
131+
ASSERT_EQ(&B, Res.MCmd) << "Expected different failed command.\n";
132+
}

sycl/unittests/scheduler/LeafLimit.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,42 +21,45 @@ using namespace cl::sycl;
2121
// overflowed.
2222
TEST_F(SchedulerTest, LeafLimit) {
2323
MockScheduler MS;
24+
std::vector<std::unique_ptr<MockCommand>> LeavesToAdd;
25+
std::unique_ptr<MockCommand> MockDepCmd;
2426

2527
buffer<int, 1> Buf(range<1>(1));
2628
detail::Requirement MockReq = getMockRequirement(Buf);
27-
MockCommand *MockDepCmd =
28-
new MockCommand(detail::getSyclObjImpl(MQueue), MockReq);
29+
30+
MockDepCmd =
31+
std::make_unique<MockCommand>(detail::getSyclObjImpl(MQueue), MockReq);
2932
detail::MemObjRecord *Rec =
3033
MS.getOrInsertMemObjRecord(detail::getSyclObjImpl(MQueue), &MockReq);
3134

3235
// Create commands that will be added as leaves exceeding the limit by 1
33-
std::vector<MockCommand *> LeavesToAdd;
3436
for (std::size_t i = 0; i < Rec->MWriteLeaves.genericCommandsCapacity() + 1;
3537
++i) {
3638
LeavesToAdd.push_back(
37-
new MockCommand(detail::getSyclObjImpl(MQueue), MockReq));
39+
std::make_unique<MockCommand>(detail::getSyclObjImpl(MQueue), MockReq));
3840
}
3941
// Create edges: all soon-to-be leaves are direct users of MockDep
40-
for (auto Leaf : LeavesToAdd) {
41-
MockDepCmd->addUser(Leaf);
42-
Leaf->addDep(detail::DepDesc{MockDepCmd, Leaf->getRequirement(), nullptr});
42+
for (auto &Leaf : LeavesToAdd) {
43+
MockDepCmd->addUser(Leaf.get());
44+
Leaf->addDep(
45+
detail::DepDesc{MockDepCmd.get(), Leaf->getRequirement(), nullptr});
4346
}
4447
// Add edges as leaves and exceed the leaf limit
45-
for (auto LeafPtr : LeavesToAdd) {
46-
MS.addNodeToLeaves(Rec, LeafPtr);
48+
for (auto &LeafPtr : LeavesToAdd) {
49+
MS.addNodeToLeaves(Rec, LeafPtr.get());
4750
}
4851
// Check that the oldest leaf has been removed from the leaf list
4952
// and added as a dependency of the newest one instead
5053
const detail::CircularBuffer<detail::Command *> &Leaves =
5154
Rec->MWriteLeaves.getGenericCommands();
52-
ASSERT_TRUE(std::find(Leaves.begin(), Leaves.end(), LeavesToAdd.front()) ==
53-
Leaves.end());
55+
ASSERT_TRUE(std::find(Leaves.begin(), Leaves.end(),
56+
LeavesToAdd.front().get()) == Leaves.end());
5457
for (std::size_t i = 1; i < LeavesToAdd.size(); ++i) {
55-
assert(std::find(Leaves.begin(), Leaves.end(), LeavesToAdd[i]) !=
58+
assert(std::find(Leaves.begin(), Leaves.end(), LeavesToAdd[i].get()) !=
5659
Leaves.end());
5760
}
58-
MockCommand *OldestLeaf = LeavesToAdd.front();
59-
MockCommand *NewestLeaf = LeavesToAdd.back();
61+
MockCommand *OldestLeaf = LeavesToAdd.front().get();
62+
MockCommand *NewestLeaf = LeavesToAdd.back().get();
6063
ASSERT_EQ(OldestLeaf->MUsers.size(), 1U);
6164
EXPECT_GT(OldestLeaf->MUsers.count(NewestLeaf), 0U);
6265
ASSERT_EQ(NewestLeaf->MDeps.size(), 2U);

sycl/unittests/scheduler/SchedulerTestUtils.hpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include <detail/scheduler/scheduler.hpp>
1414

1515
#include <functional>
16+
#include <gmock/gmock.h>
17+
1618
// This header contains a few common classes/methods used in
1719
// execution graph testing.
1820

@@ -24,12 +26,22 @@ class MockCommand : public cl::sycl::detail::Command {
2426
cl::sycl::detail::Requirement Req,
2527
cl::sycl::detail::Command::CommandType Type =
2628
cl::sycl::detail::Command::RUN_CG)
27-
: Command{Type, Queue}, MRequirement{std::move(Req)} {}
29+
: Command{Type, Queue}, MRequirement{std::move(Req)} {
30+
using namespace testing;
31+
ON_CALL(*this, enqueue(_, _))
32+
.WillByDefault(Invoke(this, &MockCommand::enqueueOrigin));
33+
EXPECT_CALL(*this, enqueue(_, _)).Times(AnyNumber());
34+
}
2835

2936
MockCommand(cl::sycl::detail::QueueImplPtr Queue,
3037
cl::sycl::detail::Command::CommandType Type =
3138
cl::sycl::detail::Command::RUN_CG)
32-
: Command{Type, Queue}, MRequirement{std::move(getMockRequirement())} {}
39+
: Command{Type, Queue}, MRequirement{std::move(getMockRequirement())} {
40+
using namespace testing;
41+
ON_CALL(*this, enqueue(_, _))
42+
.WillByDefault(Invoke(this, &MockCommand::enqueueOrigin));
43+
EXPECT_CALL(*this, enqueue(_, _)).Times(AnyNumber());
44+
}
3345

3446
void printDot(std::ostream &) const override {}
3547
void emitInstrumentationData() override {}
@@ -40,6 +52,13 @@ class MockCommand : public cl::sycl::detail::Command {
4052

4153
cl_int enqueueImp() override { return MRetVal; }
4254

55+
MOCK_METHOD2(enqueue, bool(cl::sycl::detail::EnqueueResultT &,
56+
cl::sycl::detail::BlockingT));
57+
bool enqueueOrigin(cl::sycl::detail::EnqueueResultT &EnqueueResult,
58+
cl::sycl::detail::BlockingT Blocking) {
59+
return cl::sycl::detail::Command::enqueue(EnqueueResult, Blocking);
60+
}
61+
4362
cl_int MRetVal = CL_SUCCESS;
4463

4564
void waitForEventsCall(

0 commit comments

Comments
 (0)