Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify gating check for CUDA Graph usage #16491

Merged
merged 4 commits into from
Jun 28, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,17 @@ inline std::basic_string<T> GetCurrentTimeString() {

#if !defined(ORT_MINIMAL_BUILD)

// Reports whether `node` is assigned to the CPU execution provider.
// A node whose execution provider type string is empty is treated as being on the CPU EP.
static bool IsNodeOnCpuEP(const Node& node) {
  const auto& ep_type = node.GetExecutionProviderType();

  // Empty EP string means CPU EP
  if (ep_type.empty()) {
    return true;
  }

  return ep_type == onnxruntime::kCpuExecutionProvider;
}

/* This method returns true if all *compute* nodes are placed on the CUDA EP
and all shape nodes are placed on the CPU EP
*/
std::pair<bool, int> AreAllComputeNodesAssignedToCudaEp(const Graph& graph) {
static std::pair<bool, int> AreAllComputeNodesAssignedToCudaEp(const Graph& graph) {
InlinedHashSet<NodeIndex> shape_nodes;
InlinedHashSet<NodeIndex> bfs_visited;
std::queue<NodeIndex> bfs_queue;
Expand All @@ -143,9 +150,9 @@ std::pair<bool, int> AreAllComputeNodesAssignedToCudaEp(const Graph& graph) {
shape_nodes.insert(node_index);

for (auto iter = node->OutputNodesBegin(), end = node->OutputNodesEnd(); iter != end; ++iter) {
// If the child is not a Reshape node and we haven't processed/visited the node already,
// If the child is on CPU and we haven't processed/visited the node already,
// add the node for further processing
if (iter->OpType() != "Reshape" && (bfs_visited.find(iter->Index()) == bfs_visited.end())) {
if (IsNodeOnCpuEP(*iter) && (bfs_visited.find(iter->Index()) == bfs_visited.end())) {
bfs_visited.insert(iter->Index());
bfs_queue.push(iter->Index());
}
Expand Down Expand Up @@ -177,7 +184,7 @@ std::pair<bool, int> AreAllComputeNodesAssignedToCudaEp(const Graph& graph) {
return std::make_pair(true, static_cast<int>(shape_nodes.size()));
}

bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) {
static bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) {
for (const auto& node : graph.Nodes()) {
const auto& node_provider = node.GetExecutionProviderType();

Expand All @@ -189,7 +196,7 @@ bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType prov
return true;
}

bool HasControlflowNodes(const Graph& graph) {
static bool HasControlflowNodes(const Graph& graph) {
for (const auto& node : graph.Nodes()) {
if (node.ContainsSubgraph()) {
return true;
Expand Down