Skip to content

Commit df971e5

Browse files
authored
[SYCL] Minor graph classes refactor (#36)
- getSyclObjImpl and createSyclObjFromImpl support added - Minor renaming to enable this. - Adds basic results validation to dotp test - Minor fixes to address warnings etc.
1 parent f71ea49 commit df971e5

File tree

2 files changed

+73
-40
lines changed

2 files changed

+73
-40
lines changed

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

Lines changed: 66 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -130,22 +130,45 @@ struct graph_impl {
130130
MSchedule.clear();
131131
}
132132

133+
template <typename T>
134+
node_ptr add(graph_ptr impl, T cgf, const std::vector<node_ptr> &dep = {}) {
135+
node_ptr nodeImpl = std::make_shared<node_impl>(impl, cgf);
136+
if (!dep.empty()) {
137+
for (auto n : dep) {
138+
n->register_successor(nodeImpl); // register successor
139+
this->remove_root(nodeImpl); // remove receiver from root node
140+
// list
141+
}
142+
} else {
143+
this->add_root(nodeImpl);
144+
}
145+
return nodeImpl;
146+
}
147+
133148
graph_impl() : MFirst(true) {}
134149
};
135150

136151
} // namespace detail
137152

138-
struct node {
139-
detail::node_ptr MNode;
140-
detail::graph_ptr MGraph;
141-
153+
class node {
154+
public:
142155
template <typename T>
143156
node(detail::graph_ptr g, T cgf)
144-
: MGraph(g), MNode(new detail::node_impl(g, cgf)){};
145-
void register_successor(node n) { MNode->register_successor(n.MNode); }
146-
void exec(sycl::queue q, sycl::event = sycl::event()) { MNode->exec(q); }
157+
: MGraph(g), impl(new detail::node_impl(g, cgf)) {}
158+
void register_successor(node n) { impl->register_successor(n.impl); }
159+
void exec(sycl::queue q) { impl->exec(q); }
160+
161+
private:
162+
node(detail::node_ptr Impl) : impl(Impl) {}
163+
164+
template <class Obj>
165+
friend decltype(Obj::impl)
166+
sycl::detail::getSyclObjImpl(const Obj &SyclObject);
167+
template <class T>
168+
friend T sycl::detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);
147169

148-
void set_root() { MGraph->add_root(MNode); }
170+
detail::node_ptr impl;
171+
detail::graph_ptr MGraph;
149172
};
150173

151174
template <graph_state State = graph_state::modifiable> class command_graph {
@@ -165,60 +188,68 @@ template <graph_state State = graph_state::modifiable> class command_graph {
165188
command_graph<graph_state::executable>
166189
finalize(const sycl::context &syclContext) const;
167190

168-
command_graph() : MGraph(new detail::graph_impl()) {}
191+
command_graph() : impl(new detail::graph_impl()) {}
169192

170193
private:
171-
detail::graph_ptr MGraph;
194+
command_graph(detail::graph_ptr Impl) : impl(Impl) {}
195+
196+
template <class Obj>
197+
friend decltype(Obj::impl)
198+
sycl::detail::getSyclObjImpl(const Obj &SyclObject);
199+
template <class T>
200+
friend T sycl::detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);
201+
202+
detail::graph_ptr impl;
172203
};
173204

174205
template <> class command_graph<graph_state::executable> {
175206
public:
176-
int MTag;
177-
const sycl::context &MCtx;
178-
179207
void exec_and_wait(sycl::queue q);
180208

181209
command_graph() = delete;
182210

183211
command_graph(detail::graph_ptr g, const sycl::context &ctx)
184-
: MGraph(g), MCtx(ctx), MTag(rand()) {}
212+
: MTag(rand()), MCtx(ctx), impl(g) {}
185213

186214
private:
187-
detail::graph_ptr MGraph;
215+
int MTag;
216+
const sycl::context &MCtx;
217+
detail::graph_ptr impl;
188218
};
189219

190220
template <>
191221
template <typename T>
192-
node command_graph<graph_state::modifiable>::add(T cgf,
193-
const std::vector<node> &dep) {
194-
node ret_val(MGraph, cgf);
195-
if (!dep.empty()) {
196-
for (auto n : dep)
197-
this->make_edge(n, ret_val);
198-
} else {
199-
ret_val.set_root();
222+
inline node
223+
command_graph<graph_state::modifiable>::add(T cgf,
224+
const std::vector<node> &dep) {
225+
std::vector<detail::node_ptr> depImpls;
226+
for (auto &d : dep) {
227+
depImpls.push_back(sycl::detail::getSyclObjImpl(d));
200228
}
201-
return ret_val;
229+
230+
auto nodeImpl = impl->add(impl, cgf, depImpls);
231+
return sycl::detail::createSyclObjFromImpl<node>(nodeImpl);
202232
}
203233

204234
template <>
205-
void command_graph<graph_state::modifiable>::make_edge(node sender,
206-
node receiver) {
235+
inline void command_graph<graph_state::modifiable>::make_edge(node sender,
236+
node receiver) {
207237
sender.register_successor(receiver); // register successor
208-
MGraph->remove_root(receiver.MNode); // remove receiver from root node
209-
// list
238+
impl->remove_root(
239+
sycl::detail::getSyclObjImpl(receiver)); // remove receiver from root node
240+
// list
210241
}
211242

212243
template <>
213-
command_graph<graph_state::executable>
214-
command_graph<graph_state::modifiable>::finalize(
215-
const sycl::context &ctx) const {
216-
return command_graph<graph_state::executable>{this->MGraph, ctx};
244+
command_graph<graph_state::executable> inline command_graph<
245+
graph_state::modifiable>::finalize(const sycl::context &ctx) const {
246+
return command_graph<graph_state::executable>{this->impl, ctx};
217247
}
218248

219-
void command_graph<graph_state::executable>::exec_and_wait(sycl::queue q) {
220-
MGraph->exec_and_wait(q);
221-
};
249+
inline void
250+
command_graph<graph_state::executable>::exec_and_wait(sycl::queue q) {
251+
impl->exec_and_wait(q);
252+
}
222253

223254
} // namespace experimental
224255
} // namespace oneapi

sycl/test/graph/graph-explicit-dotp.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,8 @@ int main() {
3232
sycl::ext::oneapi::property::queue::lazy_execution{}
3333
};
3434

35-
sycl::gpu_selector device_selector;
36-
37-
sycl::queue q{device_selector, properties};
38-
35+
sycl::queue q{sycl::gpu_selector_v, properties};
36+
3937
sycl::ext::oneapi::experimental::command_graph g;
4038

4139
float *dotp = sycl::malloc_shared<float>(1, q);
@@ -80,7 +78,11 @@ int main() {
8078
auto exec_graph = g.finalize(q.get_context());
8179

8280
exec_graph.exec_and_wait(q);
83-
81+
82+
if (*dotp != host_gold_result()) {
83+
std::cout << "Error unexpected result!\n";
84+
}
85+
8486
sycl::free(dotp, q);
8587
sycl::free(x, q);
8688
sycl::free(y, q);

0 commit comments

Comments
 (0)