Skip to content

Commit e64923e

Browse files
[NFC][SYCL][Graph] Introduce nodes_range utility
It looks like lots of `std::vector<std::weak_ptr<node_impl>>` aren't actually used for lifetime management and aren't expected to ever have an "expired" `std::weak_ptr`. As such, it would make sense to transition many of them to raw pointers/references (e.g. `std::vector<node_impl *>`). This utility should help to make such transition a bit easier. However, I expect it to be useful even after the elimination of unnecessary `weak_ptr`s and that can be seen even as part of this PR already. * Scenario A: we need to process both `std::vector<node_impl *>` passed around as function parameters or created as a local variable on stack, and `std::vector<node>` or `std::vector<std::shared_ptr<node_impl>>` that *is* used for lifetime management (e.g. `handler`'s data member, or all the nodes in the graph). This utility would allow such an API to have a single overload to work with both scenarios. * Scenario B: no conversion between `std::set`->`std::vector<>` (already updated here) is necessary and no templates required to support them both via single overload.
1 parent e98d8a0 commit e64923e

File tree

3 files changed

+98
-24
lines changed

3 files changed

+98
-24
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ void graph_impl::markCGMemObjs(
453453
}
454454

455455
std::shared_ptr<node_impl>
456-
graph_impl::add(std::vector<std::shared_ptr<node_impl>> &Deps) {
456+
graph_impl::add(nodes_range Deps) {
457457
const std::shared_ptr<node_impl> &NodeImpl = std::make_shared<node_impl>();
458458

459459
MNodeStorage.push_back(NodeImpl);
@@ -542,12 +542,12 @@ graph_impl::add(std::function<void(handler &)> CGF,
542542
std::shared_ptr<node_impl>
543543
graph_impl::add(const std::vector<sycl::detail::EventImplPtr> Events) {
544544

545-
std::vector<std::shared_ptr<node_impl>> Deps;
545+
std::vector<node_impl *> Deps;
546546

547547
// Add any nodes specified by event dependencies into the dependency list
548548
for (const auto &Dep : Events) {
549549
if (auto NodeImpl = MEventsMap.find(Dep); NodeImpl != MEventsMap.end()) {
550-
Deps.push_back(NodeImpl->second);
550+
Deps.push_back(NodeImpl->second.get());
551551
} else {
552552
throw sycl::exception(sycl::make_error_code(errc::invalid),
553553
"Event dependency from handler::depends_on does "
@@ -561,23 +561,22 @@ graph_impl::add(const std::vector<sycl::detail::EventImplPtr> Events) {
561561
std::shared_ptr<node_impl>
562562
graph_impl::add(node_type NodeType,
563563
std::shared_ptr<sycl::detail::CG> CommandGroup,
564-
std::vector<std::shared_ptr<node_impl>> &Deps) {
564+
nodes_range Deps) {
565565

566566
// A unique set of dependencies obtained by checking requirements and events
567567
std::set<std::shared_ptr<node_impl>> UniqueDeps = getCGEdges(CommandGroup);
568568

569569
// Track and mark the memory objects being used by the graph.
570570
markCGMemObjs(CommandGroup);
571571

572-
// Add any deps determined from requirements and events into the dependency
573-
// list
574-
Deps.insert(Deps.end(), UniqueDeps.begin(), UniqueDeps.end());
575-
576572
const std::shared_ptr<node_impl> &NodeImpl =
577573
std::make_shared<node_impl>(NodeType, std::move(CommandGroup));
578574
MNodeStorage.push_back(NodeImpl);
579575

576+
// Add any deps determined from requirements and events into the dependency
577+
// list
580578
addDepsToNode(NodeImpl, Deps);
579+
addDepsToNode(NodeImpl, UniqueDeps);
581580

582581
if (NodeType == node_type::async_free) {
583582
auto AsyncFreeCG =
@@ -592,7 +591,7 @@ graph_impl::add(node_type NodeType,
592591

593592
std::shared_ptr<node_impl>
594593
graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
595-
std::vector<std::shared_ptr<detail::node_impl>> &Deps) {
594+
nodes_range Deps) {
596595
// Set of Dependent nodes based on CG event and accessor dependencies.
597596
std::set<std::shared_ptr<node_impl>> DynCGDeps =
598597
getCGEdges(DynCGImpl->MCommandGroups[0]);

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
147147
/// @return Created node in the graph.
148148
std::shared_ptr<node_impl> add(node_type NodeType,
149149
std::shared_ptr<sycl::detail::CG> CommandGroup,
150-
std::vector<std::shared_ptr<node_impl>> &Deps);
150+
nodes_range Deps);
151151

152152
/// Create a CGF node in the graph.
153153
/// @param CGF Command-group function to create node with.
@@ -161,7 +161,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
161161
/// Create an empty node in the graph.
162162
/// @param Deps List of predecessor nodes.
163163
/// @return Created node in the graph.
164-
std::shared_ptr<node_impl> add(std::vector<std::shared_ptr<node_impl>> &Deps);
164+
std::shared_ptr<node_impl> add(nodes_range Deps);
165165

166166
/// Create an empty node in the graph.
167167
/// @param Events List of events associated to this node.
@@ -174,8 +174,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
174174
/// @param Deps List of predecessor nodes.
175175
/// @return Created node in the graph.
176176
std::shared_ptr<node_impl>
177-
add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
178-
std::vector<std::shared_ptr<node_impl>> &Deps);
177+
add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl, nodes_range Deps);
179178

180179
/// Add a queue to the set of queues which are currently recording to this
181180
/// graph.
@@ -543,13 +542,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
543542
/// @param Node The node to add deps for
544543
/// @param Deps List of dependent nodes
545544
void addDepsToNode(const std::shared_ptr<node_impl> &Node,
546-
std::vector<std::shared_ptr<node_impl>> &Deps) {
547-
if (!Deps.empty()) {
548-
for (auto &N : Deps) {
549-
N->registerSuccessor(Node);
550-
this->removeRoot(Node);
551-
}
552-
} else {
545+
nodes_range Deps) {
546+
for (node_impl &N : Deps) {
547+
N.registerSuccessor(Node);
548+
this->removeRoot(Node);
549+
}
550+
if (Node->MPredecessors.empty()) {
553551
this->addRoot(Node);
554552
}
555553
}

sycl/source/detail/graph/node_impl.hpp

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
#include <sycl/detail/cg_types.hpp> // for CGType
1515
#include <sycl/detail/kernel_desc.hpp> // for kernel_param_kind_t
1616

17-
#include <cstring> // for memcpy
18-
#include <fstream> // for fstream, ostream
19-
#include <iomanip> // for setw, setfill
20-
#include <vector> // for vector
17+
#include <cstring>
18+
#include <fstream>
19+
#include <iomanip>
20+
#include <set>
21+
#include <vector>
2122

2223
namespace sycl {
2324
inline namespace _V1 {
@@ -753,6 +754,82 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
753754
return std::make_unique<CGT>(*static_cast<CGT *>(MCommandGroup.get()));
754755
}
755756
};
757+
758+
// Non-owning!
759+
class nodes_range {
760+
template <typename... Containers>
761+
using storage_iter_impl =
762+
std::variant<typename Containers::const_iterator...>;
763+
764+
using storage_iter = storage_iter_impl<
765+
std::vector<std::shared_ptr<node_impl>>, std::vector<node_impl *>,
766+
// Next one is temporary. It looks like `weak_ptr`s aren't
767+
// used for the actual lifetime management and the objects are
768+
// always guaranteed to be alive. Once the code is cleaned
769+
// from `weak_ptr`s this alternative should be removed too.
770+
std::vector<std::weak_ptr<node_impl>>,
771+
//
772+
std::set<std::shared_ptr<node_impl>>>;
773+
774+
storage_iter Begin;
775+
storage_iter End;
776+
const size_t Size;
777+
778+
public:
779+
nodes_range(const nodes_range &Other) = default;
780+
781+
template <
782+
typename ContainerTy,
783+
typename = std::enable_if_t<!std::is_same_v<nodes_range, ContainerTy>>>
784+
nodes_range(ContainerTy &Container)
785+
: Begin{Container.begin()}, End{Container.end()}, Size{Container.size()} {
786+
}
787+
788+
class iterator {
789+
storage_iter It;
790+
791+
iterator(storage_iter It) : It(It) {}
792+
friend class nodes_range;
793+
794+
public:
795+
iterator &operator++() {
796+
It = std::visit(
797+
[](auto &&It) {
798+
++It;
799+
return storage_iter{It};
800+
},
801+
It);
802+
return *this;
803+
}
804+
bool operator!=(const iterator &Other) const { return It != Other.It; }
805+
806+
node_impl &operator*() {
807+
return std::visit(
808+
[](auto &&It) -> node_impl & {
809+
auto &Elem = *It;
810+
if constexpr (std::is_same_v<std::decay_t<decltype(Elem)>,
811+
std::weak_ptr<node_impl>>) {
812+
// This assumes that weak_ptr doesn't actually manage lifetime and
813+
// the object is guaranteed to be alive (which seems to be the
814+
// assumption across all graph code).
815+
return *Elem.lock();
816+
} else {
817+
return *Elem;
818+
}
819+
},
820+
It);
821+
}
822+
};
823+
824+
iterator begin() const {
825+
return {std::visit([](auto &&It) { return storage_iter{It}; }, Begin)};
826+
}
827+
iterator end() const {
828+
return {std::visit([](auto &&It) { return storage_iter{It}; }, End)};
829+
}
830+
size_t size() const { return Size; }
831+
bool empty() const { return Size == 0; }
832+
};
756833
} // namespace detail
757834
} // namespace experimental
758835
} // namespace oneapi

0 commit comments

Comments
 (0)