Use RTC for elementwise and broadcast ops #18622
Conversation
Remove elemwise_scatter_op.*
Fix BinaryScalar usage in NumPy
Reorganization
Add RTC to NumPy ops
Hey @ptrendx, thanks for submitting the PR.
CI supported jobs: [unix-cpu, windows-gpu, centos-cpu, unix-gpu, windows-cpu, edge, miscellaneous, website, centos-gpu, sanity, clang]
Thank you @ptrendx! As this makes the nvrtc feature mandatory, it may be necessary to also fix #17858 prior to the next release. It seems that a number of users try to open GPU builds on machines without libcuda.so, and this has been broken since 1.6 due to #17858 (though currently there is the workaround of disabling nvrtc).
Yeah, we will need to sort it out before the release. I actually thought that if we managed to get everything via RTC (which seems daunting), we could dynamically load both libcuda and libnvrtc and have a single build that supports everything, instead of the mxnet-cu* variants. That said, RTC for everything is a big task and I would definitely need help from the community to pull it off.
Also, I thought that you made the compilation use C++17 (so I use …)?
The reason is that CUDA does not support C++17 prior to CUDA 11, so CUDA files are compiled with C++14. We can consider requiring CUDA 11 for MXNet 2.
Fixes for mixed type gradient functions
Set the launch bounds on RTC kernels
I looked at the performance impact of RTC and it adds ~2us of CPU time to the launch, mostly due to string manipulation.
If we use a CUDA graph to capture all the kernel launches, would this overhead be mitigated?
@eric-haibin-lin Yes. The overhead comes from preparing a string with the kernel options (like the datatypes) and searching for the kernel function in the cache. A CUDA graph caches the resulting function, so the lookup does not occur anymore. That said, this overhead is lower than the overhead of …
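For context, the lookup described above can be sketched as a string-keyed kernel cache. This is only an illustration under assumed names (KernelCache, GetOrCompile), not MXNet's actual cache: the point is that building the key from the kernel options and probing the map happen on every launch unless something like a CUDA graph captures the launch once and replays it.

```cpp
#include <cuda.h>            // CUfunction from the CUDA driver API
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical cache keyed by kernel name plus compile options (e.g. dtypes).
// Building the key string and searching the map is the ~2 us of CPU overhead
// discussed above; the RTC compilation itself only runs on the first miss.
class KernelCache {
 public:
  template <typename CompileFn>
  CUfunction GetOrCompile(const std::string& kernel_name,
                          const std::string& dtype_options,
                          CompileFn compile_fn) {
    const std::string key = kernel_name + "|" + dtype_options;  // string work
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(key);                                 // map lookup
    if (it != cache_.end()) return it->second;
    CUfunction fn = compile_fn();  // slow path: RTC compile + load, once per key
    cache_.emplace(key, fn);
    return fn;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, CUfunction> cache_;
};
```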
@mxnet-bot run ci [unix-cpu]
Jenkins CI successfully triggered: [unix-cpu]
@@ -47,6 +47,12 @@ The following tutorials will help you learn how to customize MXNet.
How to create new MXNet operators in MXNet's backend using C++.
An example custom quadratic function op.

.. card::
   :title: Using runtime compilation (RTC) to write CUDA kernels in MXNet
   :link: https://mxnet.apache.org/api/faq/using_rtc
Use :link: /api/faq/using_rtc instead, as the documentation is versioned.
Ok - I was just copy-pasting from the other sections there (like add_op_in_backend). Will update those as well.
Hmm, the toctree also has the full links. Will the /api/faq/... approach work there too?
Yeah, I tried putting those relative links in the toctree as well, and the Python docs build complains with
/work/mxnet/docs/python_docs/python/build/tutorials/extend/index.rst:57: WARNING: toctree contains reference to nonexisting document 'api/faq/new_op'
/work/mxnet/docs/python_docs/python/build/tutorials/extend/index.rst:57: WARNING: toctree contains reference to nonexisting document 'api/faq/add_op_in_backend'
/work/mxnet/docs/python_docs/python/build/tutorials/extend/index.rst:57: WARNING: toctree contains reference to nonexisting document 'api/faq/using_rtc'
and those entries are not shown in the final website.
I'm currently fixing a number of issues in #18839. You may get a conflict from it.
Nice :-). I will push the changes to address your other comments then and will get back to the website part once your PR is merged.
After this PR, training the Electra model in gluon-nlp raises an error.
Setting
@ptrendx Could you please take a look?
Yes, I will look into it.
@ZiyueHuang Please try with PR #18984.
@ptrendx It works. Thanks for the fix.
Description
As described in #18280 (comment), MXNet currently contains too many CUDA kernels. This negatively affects compile time, the size of the resulting binary (leading to issues like #17045 and #18205), and GPU memory consumption (all of those kernels are loaded into GPU memory during the first GPU context creation).
The root cause of those problems is the number of templates that need to be instantiated, especially for NumPy operators, which need to accept different input/output types. This results in multiple nested MSHADOW_TYPE_SWITCH macros and a large increase in the number of generated kernels, most of which are practically never used. For example, counting the kernel symbols in the nightly build of mxnet-cu102 from 6/25 shows 69169 kernels (the same count on the library built with this PR at the time of writing gives 51511 kernels).
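As a standalone illustration of that multiplication (this is not MXNet source; the dtype enum and dispatcher below are made up), a templated kernel has to be instantiated ahead of time for every input/output type pair so that a runtime switch can select one, even though most pairs are never used:

```cuda
#include <cuda_runtime.h>

// Templated kernel: one instantiation per (IType, OType) combination.
template <typename IType, typename OType>
__global__ void cast_copy(const IType* in, OType* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = static_cast<OType>(in[i]);
}

enum class DType { kFloat32, kInt32 };

// Nested switches over just 2 dtypes already emit 4 kernels; with ~10 supported
// dtypes and an extra switch over the operator (as in the nested
// MSHADOW_TYPE_SWITCH macros), the count multiplies quickly.
void DispatchCastCopy(DType in_t, DType out_t, const void* in, void* out, int n) {
  const int threads = 256, blocks = (n + threads - 1) / threads;
  switch (in_t) {
    case DType::kFloat32:
      switch (out_t) {
        case DType::kFloat32:
          cast_copy<<<blocks, threads>>>(static_cast<const float*>(in),
                                         static_cast<float*>(out), n);
          break;
        case DType::kInt32:
          cast_copy<<<blocks, threads>>>(static_cast<const float*>(in),
                                         static_cast<int*>(out), n);
          break;
      }
      break;
    case DType::kInt32:
      switch (out_t) {
        case DType::kFloat32:
          cast_copy<<<blocks, threads>>>(static_cast<const int*>(in),
                                         static_cast<float*>(out), n);
          break;
        case DType::kInt32:
          cast_copy<<<blocks, threads>>>(static_cast<const int*>(in),
                                         static_cast<int*>(out), n);
          break;
      }
      break;
  }
}

int main() {
  const int n = 1024;
  void *in = nullptr, *out = nullptr;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(int));
  // Uses only 1 of the 4 compiled kernels; the other 3 still cost compile
  // time, binary size, and GPU memory.
  DispatchCastCopy(DType::kFloat32, DType::kInt32, in, out, n);
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```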
The proposed approach is to use RTC (runtime compilation) to generate the needed kernels at runtime. This saves ahead-of-time compilation time and binary size, and reduces GPU memory utilization, since only the kernels that are actually needed are generated rather than all combinations.
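For readers unfamiliar with RTC, below is a minimal self-contained sketch of the underlying mechanism: build the kernel source as a string, compile it with NVRTC, load the resulting PTX with the CUDA driver API, and launch the function. MXNet wraps this pattern in its own helpers and adds caching, so this is only an illustration of the technique, not the PR's implementation (error checks are omitted for brevity).

```cpp
#include <cuda.h>
#include <nvrtc.h>
#include <string>
#include <vector>

int main() {
  // Kernel source assembled at runtime; in a framework the dtypes and the op
  // body would be spliced into this string based on the operator being run.
  const std::string code = R"code(
  extern "C" __global__ void vec_scale(float* x, float a, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= a;
  })code";

  // Compile the string to PTX with NVRTC.
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, code.c_str(), "vec_scale.cu", 0, nullptr, nullptr);
  const char* opts[] = {"--std=c++14"};
  nvrtcCompileProgram(prog, 1, opts);
  size_t ptx_size;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::vector<char> ptx(ptx_size);
  nvrtcGetPTX(prog, ptx.data());
  nvrtcDestroyProgram(&prog);

  // Load the PTX and fetch the kernel through the CUDA driver API.
  cuInit(0);
  CUdevice dev;
  cuDeviceGet(&dev, 0);
  CUcontext ctx;
  cuCtxCreate(&ctx, 0, dev);
  CUmodule mod;
  cuModuleLoadData(&mod, ptx.data());
  CUfunction fn;
  cuModuleGetFunction(&fn, mod, "vec_scale");

  // Launch the runtime-compiled kernel.
  const int n = 1024;
  CUdeviceptr d_x;
  cuMemAlloc(&d_x, n * sizeof(float));
  float a = 2.0f;
  int n_arg = n;
  void* args[] = {&d_x, &a, &n_arg};
  cuLaunchKernel(fn, (n + 255) / 256, 1, 1, 256, 1, 1, 0, nullptr, args, nullptr);
  cuCtxSynchronize();

  cuMemFree(d_x);
  cuModuleUnload(mod);
  cuCtxDestroy(ctx);
  return 0;
}
```

Note how no device kernel exists in this binary ahead of time; only the host-side code is compiled, which is where the binary-size and GPU-memory savings come from.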
To test the impact on binary size and memory consumption, I compiled MXNet for 5 GPU architectures (5.2, 6.0, 6.1, 7.0, 7.5) using CUDA 11, both from the head of this PR and from the latest master commit it includes (f872b43).
Binary size reduction: 292 MB (from 2 GB to 1.7 GB)
Idle GPU memory consumption reduction: 96 MB (from 1442 MB to 1346 MB)
The idle GPU memory consumption reduction was checked by launching a Python interpreter and checking GPU memory consumption after calling:
This PR uses that approach to handle elementwise and broadcast kernels (as well as their backward), which constitute a big portion of the total number of kernels in MXNet.
FYI @leezu @sxjscience @eric-haibin-lin
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments
Things left to do:
- MixedUnaryBackward functions

After this PR, the next step would be to use the same approach for reduce kernels. This PR already contains groundwork for that, since reduction was needed for the backward of broadcast ops, but it does not apply that path to standalone reduction ops. Grepping for reduce_kernel in the symbols visible in libmxnet.so after applying this PR gives 12057 entries. Moving reductions to RTC would also help reduce the amount of code duplication this PR introduces by maintaining both RTC and non-RTC paths.