Skip to content

Commit 2348227

Browse files
authored
[SYCL] Update graph constructor/finalize to current spec (#140)
- Add device and context params to graph constructor
- Remove context from finalize
- Minor changes to graph_impl to support this
- Update all examples to use updated API
- Tidied up ordering of graph_impl declarations a little
1 parent 7e580c5 commit 2348227

20 files changed

+90
-60
lines changed

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ class __SYCL_EXPORT node {
5353
template <graph_state State = graph_state::modifiable>
5454
class __SYCL_EXPORT command_graph {
5555
public:
56-
command_graph(const property_list &propList = {});
56+
command_graph(const context &syclContext, const device &syclDevice,
57+
const property_list &propList = {});
5758

5859
// Adding empty node with [0..n] predecessors:
5960
node add(const std::vector<node> &dep = {}) { return add_impl(dep); }
@@ -67,8 +68,7 @@ class __SYCL_EXPORT command_graph {
6768
void make_edge(node sender, node receiver);
6869

6970
command_graph<graph_state::executable>
70-
finalize(const sycl::context &syclContext,
71-
const property_list &propList = {}) const;
71+
finalize(const property_list &propList = {}) const;
7272

7373
/// Change the state of a queue to be recording and associate this graph with
7474
/// it.
@@ -138,7 +138,6 @@ template <> class __SYCL_EXPORT command_graph<graph_state::executable> {
138138
void finalize_impl();
139139

140140
int MTag;
141-
const sycl::context &MCtx;
142141
std::shared_ptr<detail::exec_graph_impl> impl;
143142
};
144143
} // namespace experimental

sycl/source/detail/graph_impl.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -241,13 +241,12 @@ void exec_graph_impl::find_real_deps(std::vector<pi_ext_sync_point> &Deps,
241241
}
242242
}
243243

244-
void exec_graph_impl::create_pi_command_buffers(sycl::device D,
245-
const sycl::context &Ctx) {
244+
void exec_graph_impl::create_pi_command_buffers(sycl::device D) {
246245
// TODO we only have a single command-buffer per graph here, but
247246
// this will need to be multiple command-buffers for non-trivial graphs
248247
pi_ext_command_buffer OutCommandBuffer;
249248
pi_ext_command_buffer_desc Desc{};
250-
auto ContextImpl = sycl::detail::getSyclObjImpl(Ctx);
249+
auto ContextImpl = sycl::detail::getSyclObjImpl(MContext);
251250
const sycl::detail::plugin &Plugin = ContextImpl->getPlugin();
252251
auto DeviceImpl = sycl::detail::getSyclObjImpl(D);
253252
pi_result Res =
@@ -284,13 +283,13 @@ void exec_graph_impl::create_pi_command_buffers(sycl::device D,
284283
Node->MKernelName);
285284
}
286285

287-
auto SetFunc = [&Plugin, &PiKernel, &Ctx](sycl::detail::ArgDesc &Arg,
286+
auto SetFunc = [&Plugin, &PiKernel, this](sycl::detail::ArgDesc &Arg,
288287
size_t NextTrueIndex) {
289288
sycl::detail::SetArgBasedOnType(
290289
Plugin, PiKernel,
291290
nullptr /* TODO: Handle spec constants and pass device image here */,
292-
nullptr /* TODO: Pass getMemAllocation function for buffers */, Ctx,
293-
false, Arg, NextTrueIndex);
291+
nullptr /* TODO: Pass getMemAllocation function for buffers */,
292+
this->MContext, false, Arg, NextTrueIndex);
294293
};
295294
std::vector<sycl::detail::ArgDesc> Args;
296295
sycl::detail::applyFuncOnFilteredArgs(EliminatedArgMask, Node->MArgs,
@@ -421,8 +420,9 @@ sycl::event exec_graph_impl::enqueue(
421420

422421
template <>
423422
command_graph<graph_state::modifiable>::command_graph(
423+
const sycl::context &syclContext, const sycl::device &syclDevice,
424424
const sycl::property_list &)
425-
: impl(std::make_shared<detail::graph_impl>()) {}
425+
: impl(std::make_shared<detail::graph_impl>(syclContext, syclDevice)) {}
426426

427427
template <>
428428
node command_graph<graph_state::modifiable>::add_impl(
@@ -465,8 +465,9 @@ void command_graph<graph_state::modifiable>::make_edge(node Sender,
465465
template <>
466466
command_graph<graph_state::executable>
467467
command_graph<graph_state::modifiable>::finalize(
468-
const sycl::context &CTX, const sycl::property_list &) const {
469-
return command_graph<graph_state::executable>{this->impl, CTX};
468+
const sycl::property_list &) const {
469+
return command_graph<graph_state::executable>{this->impl,
470+
this->impl->get_context()};
470471
}
471472

472473
template <>
@@ -531,7 +532,7 @@ bool command_graph<graph_state::modifiable>::end_recording(
531532

532533
command_graph<graph_state::executable>::command_graph(
533534
const std::shared_ptr<detail::graph_impl> &Graph, const sycl::context &Ctx)
534-
: MTag(rand()), MCtx(Ctx),
535+
: MTag(rand()),
535536
impl(std::make_shared<detail::exec_graph_impl>(Ctx, Graph)) {
536537
finalize_impl(); // Create backend representation for executable graph
537538
}
@@ -540,8 +541,8 @@ void command_graph<graph_state::executable>::finalize_impl() {
540541
// Create PI command-buffers for each device in the finalized context
541542
impl->schedule();
542543
#if SYCL_EXT_ONEAPI_GRAPH
543-
for (auto device : MCtx.get_devices()) {
544-
impl->create_pi_command_buffers(device, MCtx);
544+
for (auto device : impl->get_context().get_devices()) {
545+
impl->create_pi_command_buffers(device);
545546
}
546547
#endif
547548
}

sycl/source/detail/graph_impl.hpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,10 @@ struct node_impl {
127127
};
128128

129129
struct graph_impl {
130-
std::set<std::shared_ptr<node_impl>> MRoots;
131130

132-
std::shared_ptr<graph_impl> MParent;
131+
graph_impl(const sycl::context &syclContext, const sycl::device &syclDevice)
132+
: MContext(syclContext), MDevice(syclDevice), MRecordingQueues(),
133+
MEventsMap() {}
133134

134135
void add_root(const std::shared_ptr<node_impl> &);
135136
void remove_root(const std::shared_ptr<node_impl> &);
@@ -155,8 +156,6 @@ struct graph_impl {
155156
std::shared_ptr<node_impl>
156157
add(const std::vector<std::shared_ptr<node_impl>> &Dep = {});
157158

158-
graph_impl() = default;
159-
160159
/// Add a queue to the set of queues which are currently recording to this
161160
/// graph.
162161
void
@@ -199,8 +198,18 @@ struct graph_impl {
199198
/// an empty node is used to schedule dependencies on this sub graph.
200199
std::shared_ptr<node_impl>
201200
add_subgraph_nodes(const std::list<std::shared_ptr<node_impl>> &NodeList);
201+
sycl::context get_context() const { return MContext; }
202+
203+
std::set<std::shared_ptr<node_impl>> MRoots;
204+
std::shared_ptr<graph_impl> MParent;
202205

203206
private:
207+
// Context associated with this graph.
208+
sycl::context MContext;
209+
// Device associated with this graph. All graph nodes will execute on this
210+
// device.
211+
sycl::device MDevice;
212+
// Unique set of queues which are currently recording to this graph.
204213
std::set<std::shared_ptr<sycl::detail::queue_impl>> MRecordingQueues;
205214
// Map of events to their associated recorded nodes.
206215
std::unordered_map<std::shared_ptr<sycl::detail::event_impl>,
@@ -224,7 +233,9 @@ class exec_graph_impl {
224233
sycl::event exec(const std::shared_ptr<sycl::detail::queue_impl> &);
225234
/// Turns our internal graph representation into PI command-buffers for a
226235
/// device
227-
void create_pi_command_buffers(sycl::device D, const sycl::context &Ctx);
236+
void create_pi_command_buffers(sycl::device D);
237+
238+
sycl::context get_context() const { return MContext; }
228239

229240
const std::list<std::shared_ptr<node_impl>> &get_schedule() const {
230241
return MSchedule;

sycl/test/graph/graph-explicit-dotp-buffer.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ int main() {
2626

2727
sycl::queue q{sycl::gpu_selector_v};
2828

29-
sycl::ext::oneapi::experimental::command_graph g;
29+
sycl::ext::oneapi::experimental::command_graph g{q.get_context(),
30+
q.get_device()};
3031

3132
float dotpData = 0.f;
3233
std::vector<float> xData(n);
@@ -93,7 +94,7 @@ int main() {
9394
#endif
9495
});
9596

96-
auto executable_graph = g.finalize(q.get_context());
97+
auto executable_graph = g.finalize();
9798

9899
// Using shortcut for executing a graph of commands
99100
q.ext_oneapi_graph(executable_graph).wait();

sycl/test/graph/graph-explicit-dotp-device-mem.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ int main() {
2626

2727
sycl::queue q{sycl::gpu_selector_v};
2828

29-
sycl::ext::oneapi::experimental::command_graph g;
29+
sycl::ext::oneapi::experimental::command_graph g{q.get_context(),
30+
q.get_device()};
3031

3132
float *dotp = sycl::malloc_device<float>(1, q);
3233

@@ -83,7 +84,7 @@ int main() {
8384
},
8485
{node_a, node_b});
8586

86-
auto executable_graph = g.finalize(q.get_context());
87+
auto executable_graph = g.finalize();
8788

8889
// Using shortcut for executing a graph of commands
8990
q.ext_oneapi_graph(executable_graph).wait();

sycl/test/graph/graph-explicit-dotp.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ int main() {
2626

2727
sycl::queue q{sycl::gpu_selector_v};
2828

29-
sycl::ext::oneapi::experimental::command_graph g;
29+
sycl::ext::oneapi::experimental::command_graph g{q.get_context(),
30+
q.get_device()};
3031

3132
float *dotp = sycl::malloc_shared<float>(1, q);
3233

@@ -83,7 +84,7 @@ int main() {
8384
},
8485
{node_a, node_b});
8586

86-
auto executable_graph = g.finalize(q.get_context());
87+
auto executable_graph = g.finalize();
8788

8889
// Using shortcut for executing a graph of commands
8990
q.ext_oneapi_graph(executable_graph).wait();

sycl/test/graph/graph-explicit-empty.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ int main() {
77

88
sycl::queue q{sycl::gpu_selector_v};
99

10-
sycl::ext::oneapi::experimental::command_graph g;
10+
sycl::ext::oneapi::experimental::command_graph g{q.get_context(),
11+
q.get_device()};
1112

1213
const size_t n = 10;
1314
float *arr = sycl::malloc_device<float>(n, q);
@@ -35,7 +36,7 @@ int main() {
3536
},
3637
{empty2});
3738

38-
auto executable_graph = g.finalize(q.get_context());
39+
auto executable_graph = g.finalize();
3940

4041
q.submit([&](sycl::handler &h) {
4142
h.ext_oneapi_graph(executable_graph);

sycl/test/graph/graph-explicit-multiple-exec-graphs.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ int main() {
2626

2727
sycl::queue q{sycl::gpu_selector_v};
2828

29-
sycl::ext::oneapi::experimental::command_graph g;
29+
sycl::ext::oneapi::experimental::command_graph g{q.get_context(),
30+
q.get_device()};
3031

3132
float *dotp = sycl::malloc_shared<float>(1, q);
3233

@@ -73,14 +74,14 @@ int main() {
7374
},
7475
{node_a, node_b});
7576

76-
auto executable_graph = g.finalize(q.get_context());
77+
auto executable_graph = g.finalize();
7778

7879
// Add an extra node for the second executable graph which modifies the output
7980
auto node_d =
8081
g.add([&](sycl::handler &h) { h.single_task([=]() { dotp[0] += 1; }); },
8182
{node_c});
8283

83-
auto executable_graph_2 = g.finalize(q.get_context());
84+
auto executable_graph_2 = g.finalize();
8485

8586
// Using shortcut for executing a graph of commands
8687
q.ext_oneapi_graph(executable_graph).wait();

sycl/test/graph/graph-explicit-node-ordering.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ int main() {
77

88
sycl::queue q{sycl::gpu_selector_v};
99

10-
sycl::ext::oneapi::experimental::command_graph g;
10+
sycl::ext::oneapi::experimental::command_graph g{q.get_context(),
11+
q.get_device()};
1112

1213
const size_t n = 10;
1314
float *x = sycl::malloc_shared<float>(n, q);
@@ -36,7 +37,7 @@ int main() {
3637
g.make_edge(init, mult);
3738
g.make_edge(mult, add);
3839

39-
auto executable_graph = g.finalize(q.get_context());
40+
auto executable_graph = g.finalize();
4041

4142
q.submit([&](sycl::handler &h) {
4243
h.ext_oneapi_graph(executable_graph);

sycl/test/graph/graph-explicit-queue-shortcuts.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ int main() {
99

1010
// Test passing empty property list, which is the default
1111
sycl::property_list empty_properties;
12-
sycl::ext::oneapi::experimental::command_graph g(empty_properties);
12+
sycl::ext::oneapi::experimental::command_graph g(
13+
q.get_context(), q.get_device(), empty_properties);
1314

1415
const size_t n = 10;
1516
float *arr = sycl::malloc_shared<float>(n, q);
@@ -21,7 +22,7 @@ int main() {
2122
});
2223
});
2324

24-
auto executable_graph = g.finalize(q.get_context(), empty_properties);
25+
auto executable_graph = g.finalize(empty_properties);
2526

2627
auto e1 = q.ext_oneapi_graph(executable_graph);
2728
auto e2 = q.ext_oneapi_graph(executable_graph, e1);

0 commit comments

Comments
 (0)