Simplify gating check for CUDA Graph usage #16491
Sometimes this check could miss a case that should not run with CUDA Graphs (for example, a node that has no CUDA implementation and consumes a shape input, but whose output is not constant).
Since we have a warning message, it is fine to apply fewer constraints to unblock most users (assuming that users will verify the accuracy themselves).
I think I will simplify this logic so that we allow CUDA Graph capture as long as there are no memcpy nodes, and we log a warning asking the user to check results if we see a Shape node and some nodes are assigned to the CPU EP (assuming that these nodes are shape subgraphs).
f49e546
Description
As part of relaxing the node EP check for CUDA Graphs in #16358, logic was introduced to collect all shape massaging nodes. This logic collected all nodes between Shape and Reshape nodes, which covers the most common shape massaging node pattern. However, this isn't exhaustive. A shape massaging subgraph may not end at a Reshape node. It may end in other nodes that consume shape info (like Expand, ConstantOfShape, etc.). In fact, a Reshape node may itself be part of the shape massaging nodes (see illustration below).

The gating check now is as follows:
(1) For CUDA and TRT EP: Ensure that there are no control flow nodes (same as before)
(2) For TRT EP: Ensure all nodes have been placed on the TRT EP (same as before)
(New logic below)
(3) For CUDA EP: Ensure that all nodes have been partitioned to the CUDA or CPU EP and that there are no memcpy nodes. The reasoning is that certain shape nodes will be forced onto the CPU, and the absence of memcpy nodes confirms that no compute nodes have been placed on the CPU EP.
Additionally, for the CUDA EP, we log a warning to let the user know that there are shape subgraphs that will execute on the CPU, so that they can decide whether they want to use CUDA Graphs. In most cases, shape subgraphs on the CPU should mean it is safe to use CUDA Graphs.
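
The gating rules described above can be sketched roughly as follows. This is an illustrative, standalone approximation and not the code in this PR: `NodeInfo`, `GraphDecision`, `CheckCudaEp`, and `CheckTrtEp` are hypothetical names, the EP labels are plain strings chosen for the example, and the warning condition follows the review comment (a Shape node is present and some nodes are assigned to the CPU EP).

```cpp
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a graph node (not the ONNX Runtime Node class).
struct NodeInfo {
  std::string op_type;  // e.g. "Shape", "MatMul", "MemcpyToHost"
  std::string ep;       // assumed EP label for this sketch: "CUDA", "CPU", or "TRT"
};

enum class GraphDecision { kReject, kAllow, kAllowWithWarning };

// Control flow ops block CUDA Graph capture for both EPs.
inline bool IsControlFlow(const std::string& op) {
  return op == "If" || op == "Loop" || op == "Scan";
}

// Memcpy nodes indicate a host/device boundary inside the graph.
inline bool IsMemcpy(const std::string& op) {
  return op == "MemcpyToHost" || op == "MemcpyFromHost";
}

// CUDA EP: every node must be on CUDA or CPU, with no control flow and no memcpy nodes.
// A warning is suggested when a Shape node exists and some nodes run on the CPU EP.
inline GraphDecision CheckCudaEp(const std::vector<NodeInfo>& nodes) {
  bool saw_cpu_node = false;
  bool saw_shape_node = false;
  for (const NodeInfo& node : nodes) {
    if (IsControlFlow(node.op_type) || IsMemcpy(node.op_type)) {
      return GraphDecision::kReject;
    }
    if (node.ep == "CPU") {
      saw_cpu_node = true;
    } else if (node.ep != "CUDA") {
      return GraphDecision::kReject;  // partitioned to some other EP
    }
    if (node.op_type == "Shape") {
      saw_shape_node = true;
    }
  }
  return (saw_cpu_node && saw_shape_node) ? GraphDecision::kAllowWithWarning
                                          : GraphDecision::kAllow;
}

// TRT EP: no control flow, and every node must be placed on the TRT EP.
inline GraphDecision CheckTrtEp(const std::vector<NodeInfo>& nodes) {
  for (const NodeInfo& node : nodes) {
    if (IsControlFlow(node.op_type) || node.ep != "TRT") {
      return GraphDecision::kReject;
    }
  }
  return GraphDecision::kAllow;
}
```

Under these assumptions, a graph with Shape and Gather on the CPU EP feeding a Reshape on the CUDA EP is allowed with a warning, while the same graph with a memcpy node inserted is rejected.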
Motivation and Context
Refine logic introduced in #16358