Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][Graph] enable_shared_from_this refactor #15195

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 17 additions & 22 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,6 @@ graph_impl::~graph_impl() {
}

std::shared_ptr<node_impl> graph_impl::addNodesToExits(
const std::shared_ptr<graph_impl> &Impl,
const std::list<std::shared_ptr<node_impl>> &NodeList) {
// Find all input and output nodes from the node list
std::vector<std::shared_ptr<node_impl>> Inputs;
Expand All @@ -327,18 +326,18 @@ std::shared_ptr<node_impl> graph_impl::addNodesToExits(
for (auto &NodeImpl : MNodeStorage) {
if (NodeImpl->MSuccessors.size() == 0) {
for (auto &Input : Inputs) {
NodeImpl->registerSuccessor(Input, NodeImpl);
NodeImpl->registerSuccessor(Input);
}
}
}

// Add all the new nodes to the node storage
for (auto &Node : NodeList) {
MNodeStorage.push_back(Node);
addEventForNode(Impl, std::make_shared<sycl::detail::event_impl>(), Node);
addEventForNode(std::make_shared<sycl::detail::event_impl>(), Node);
}

return this->add(Impl, Outputs);
return this->add(Outputs);
}

void graph_impl::addRoot(const std::shared_ptr<node_impl> &Root) {
Expand All @@ -350,8 +349,7 @@ void graph_impl::removeRoot(const std::shared_ptr<node_impl> &Root) {
}

std::shared_ptr<node_impl>
graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
const std::vector<std::shared_ptr<node_impl>> &Dep) {
graph_impl::add(const std::vector<std::shared_ptr<node_impl>> &Dep) {
// Copy deps so we can modify them
auto Deps = Dep;

Expand All @@ -361,17 +359,16 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,

addDepsToNode(NodeImpl, Deps);
// Add an event associated with this explicit node for mixed usage
addEventForNode(Impl, std::make_shared<sycl::detail::event_impl>(), NodeImpl);
addEventForNode(std::make_shared<sycl::detail::event_impl>(), NodeImpl);
return NodeImpl;
}

std::shared_ptr<node_impl>
graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
std::function<void(handler &)> CGF,
graph_impl::add(std::function<void(handler &)> CGF,
const std::vector<sycl::detail::ArgDesc> &Args,
const std::vector<std::shared_ptr<node_impl>> &Dep) {
(void)Args;
sycl::handler Handler{Impl};
sycl::handler Handler{shared_from_this()};
CGF(Handler);

if (Handler.getType() == sycl::detail::CGType::Barrier) {
Expand All @@ -394,7 +391,7 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
this->add(NodeType, std::move(Handler.impl->MGraphNodeCG), Dep);
NodeImpl->MNDRangeUsed = Handler.impl->MNDRangeUsed;
// Add an event associated with this explicit node for mixed usage
addEventForNode(Impl, std::make_shared<sycl::detail::event_impl>(), NodeImpl);
addEventForNode(std::make_shared<sycl::detail::event_impl>(), NodeImpl);

// Retrieve any dynamic parameters which have been registered in the CGF and
// register the actual nodes with them.
Expand All @@ -414,8 +411,7 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
}

std::shared_ptr<node_impl>
graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
const std::vector<sycl::detail::EventImplPtr> Events) {
graph_impl::add(const std::vector<sycl::detail::EventImplPtr> Events) {

std::vector<std::shared_ptr<node_impl>> Deps;

Expand All @@ -430,7 +426,7 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
}
}

return this->add(Impl, Deps);
return this->add(Deps);
}

std::shared_ptr<node_impl>
Expand Down Expand Up @@ -584,7 +580,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
}

// We need to add the edges first before checking for cycles
Src->registerSuccessor(Dest, Src);
Src->registerSuccessor(Dest);

// We can skip cycle checks if either Dest has no successors (cycle not
// possible) or cycle checks have been disabled with the no_cycle_check
Expand Down Expand Up @@ -1050,7 +1046,7 @@ void exec_graph_impl::duplicateNodes() {
// register those as successors with the current copied node
for (auto &NextNode : OriginalNode->MSuccessors) {
auto Successor = NodesMap.at(NextNode.lock());
NodeCopy->registerSuccessor(Successor, NodeCopy);
NodeCopy->registerSuccessor(Successor);
}
}

Expand Down Expand Up @@ -1092,7 +1088,7 @@ void exec_graph_impl::duplicateNodes() {

for (auto &NextNode : SubgraphNode->MSuccessors) {
auto Successor = SubgraphNodesMap.at(NextNode.lock());
NodeCopy->registerSuccessor(Successor, NodeCopy);
NodeCopy->registerSuccessor(Successor);
}
}

Expand Down Expand Up @@ -1126,7 +1122,7 @@ void exec_graph_impl::duplicateNodes() {
// Add all input nodes from the subgraph as successors for this node
// instead
for (auto &Input : Inputs) {
PredNode->registerSuccessor(Input, PredNode);
PredNode->registerSuccessor(Input);
}
}

Expand All @@ -1146,7 +1142,7 @@ void exec_graph_impl::duplicateNodes() {
// Add all Output nodes from the subgraph as predecessors for this node
// instead
for (auto &Output : Outputs) {
Output->registerSuccessor(SuccNode, Output);
Output->registerSuccessor(SuccNode);
}
}

Expand Down Expand Up @@ -1520,7 +1516,7 @@ node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
}

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(impl, DepImpls);
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}

Expand All @@ -1533,8 +1529,7 @@ node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
}

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl =
impl->add(impl, CGF, {}, DepImpls);
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(CGF, {}, DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}

Expand Down
43 changes: 14 additions & 29 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inline node_type getNodeTypeFromCG(sycl::detail::CGType CGType) {
}

/// Implementation of node class from SYCL_EXT_ONEAPI_GRAPH.
class node_impl {
class node_impl : public std::enable_shared_from_this<node_impl> {
public:
using id_type = uint64_t;

Expand Down Expand Up @@ -112,20 +112,15 @@ class node_impl {

/// Add successor to the node.
/// @param Node Node to add as a successor.
/// @param Prev Predecessor to \p node being added as successor.
///
/// \p Prev should be a shared_ptr to an instance of this object, but can't
/// use a raw \p this pointer, so the extra \p Prev parameter is passed.
void registerSuccessor(const std::shared_ptr<node_impl> &Node,
const std::shared_ptr<node_impl> &Prev) {
void registerSuccessor(const std::shared_ptr<node_impl> &Node) {
if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
[Node](const std::weak_ptr<node_impl> &Ptr) {
return Ptr.lock() == Node;
}) != MSuccessors.end()) {
return;
}
MSuccessors.push_back(Node);
Node->registerPredecessor(Prev);
Node->registerPredecessor(shared_from_this());
}

/// Add predecessor to the node.
Expand Down Expand Up @@ -161,9 +156,10 @@ class node_impl {
/// Construct a node from another node. This will perform a deep-copy of the
/// command group object associated with this node.
node_impl(node_impl &Other)
: MSuccessors(Other.MSuccessors), MPredecessors(Other.MPredecessors),
MCGType(Other.MCGType), MNodeType(Other.MNodeType),
MCommandGroup(Other.getCGCopy()), MSubGraphImpl(Other.MSubGraphImpl) {}
: enable_shared_from_this(Other), MSuccessors(Other.MSuccessors),
MPredecessors(Other.MPredecessors), MCGType(Other.MCGType),
MNodeType(Other.MNodeType), MCommandGroup(Other.getCGCopy()),
MSubGraphImpl(Other.MSubGraphImpl) {}

/// Copy-assignment operator. This will perform a deep-copy of the
/// command group object associated with this node.
Expand Down Expand Up @@ -901,32 +897,26 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
const std::vector<std::shared_ptr<node_impl>> &Dep = {});

/// Create a CGF node in the graph.
/// @param Impl Graph implementation pointer to create a handler with.
/// @param CGF Command-group function to create node with.
/// @param Args Node arguments.
/// @param Dep Dependencies of the created node.
/// @return Created node in the graph.
std::shared_ptr<node_impl>
add(const std::shared_ptr<graph_impl> &Impl,
std::function<void(handler &)> CGF,
add(std::function<void(handler &)> CGF,
const std::vector<sycl::detail::ArgDesc> &Args,
const std::vector<std::shared_ptr<node_impl>> &Dep = {});

/// Create an empty node in the graph.
/// @param Impl Graph implementation pointer.
/// @param Dep List of predecessor nodes.
/// @return Created node in the graph.
std::shared_ptr<node_impl>
add(const std::shared_ptr<graph_impl> &Impl,
const std::vector<std::shared_ptr<node_impl>> &Dep = {});
add(const std::vector<std::shared_ptr<node_impl>> &Dep = {});

/// Create an empty node in the graph.
/// @param Impl Graph implementation pointer.
/// @param Events List of events associated to this node.
/// @return Created node in the graph.
std::shared_ptr<node_impl>
add(const std::shared_ptr<graph_impl> &Impl,
const std::vector<sycl::detail::EventImplPtr> Events);
add(const std::vector<sycl::detail::EventImplPtr> Events);

/// Add a queue to the set of queues which are currently recording to this
/// graph.
Expand All @@ -951,15 +941,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
bool clearQueues();

/// Associate a sycl event with a node in the graph.
/// @param GraphImpl shared_ptr to Graph impl associated with this event, aka
/// this.
/// @param EventImpl Event to associate with a node in map.
/// @param NodeImpl Node to associate with event in map.
void addEventForNode(std::shared_ptr<graph_impl> GraphImpl,
std::shared_ptr<sycl::detail::event_impl> EventImpl,
void addEventForNode(std::shared_ptr<sycl::detail::event_impl> EventImpl,
std::shared_ptr<node_impl> NodeImpl) {
if (!(EventImpl->getCommandGraph()))
EventImpl->setCommandGraph(GraphImpl);
EventImpl->setCommandGraph(shared_from_this());
MEventsMap[EventImpl] = NodeImpl;
}

Expand Down Expand Up @@ -1238,12 +1225,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
void addRoot(const std::shared_ptr<node_impl> &Root);

/// Adds nodes to the exit nodes of this graph.
/// @param Impl Graph implementation pointer.
/// @param NodeList List of nodes from sub-graph in schedule order.
/// @return An empty node is used to schedule dependencies on this sub-graph.
std::shared_ptr<node_impl>
addNodesToExits(const std::shared_ptr<graph_impl> &Impl,
const std::list<std::shared_ptr<node_impl>> &NodeList);
addNodesToExits(const std::list<std::shared_ptr<node_impl>> &NodeList);

/// Adds dependencies for a new node, if it has no deps it will be
/// added as a root node.
Expand All @@ -1253,7 +1238,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
const std::vector<std::shared_ptr<node_impl>> &Deps) {
if (!Deps.empty()) {
for (auto &N : Deps) {
N->registerSuccessor(Node, N);
N->registerSuccessor(Node);
this->removeRoot(Node);
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ event handler::finalize() {
}

// Associate an event with this new node and return the event.
GraphImpl->addEventForNode(GraphImpl, EventImpl, NodeImpl);
GraphImpl->addEventForNode(EventImpl, NodeImpl);

NodeImpl->MNDRangeUsed = impl->MNDRangeUsed;

Expand Down
Loading