Distributed optimizer infrastructure for FP8 parameters #1723

Merged: 2 commits into NVIDIA:master on Sep 29, 2023

Conversation

timmoon10 (Contributor) commented:

This PR does some refactoring to enable distributed optimizer support for FP8 parameters in NeMo. It adds the option to perform parameter all-gathers in integer dtypes, along with two member functions, _check_params_shard_dtypes and _param_copy_fragments, that handle casting into and out of the all-gather buffer. For now, these functions either do a direct cast for floating-point dtypes or copy the most significant bytes for other dtypes. I plan to override them in the NeMo derived class so that it casts parameters to FP8, performs the all-gather in UINT8, and unpacks the result into a custom FP8 tensor class.

This PR depends on #1719 and #1721.
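
For illustration only (not the actual apex implementation, which lives in apex/contrib/optimizers/distributed_fused_adam.py), here is a minimal PyTorch sketch of the "copy the most significant bytes" behavior described above, e.g. packing a BF16 parameter shard into a UINT8 all-gather buffer. It assumes a little-endian platform, and the helper name copy_most_significant_bytes is made up for this example.

```python
import torch

def copy_most_significant_bytes(src: torch.Tensor, dst: torch.Tensor) -> None:
    """Illustrative sketch, not the apex implementation.

    Copies the most significant bytes of each element of ``src`` into a
    smaller-itemsize integer buffer ``dst`` (e.g. BF16 params into a UINT8
    all-gather buffer). Assumes a little-endian platform, where the most
    significant byte of each element is the last byte in memory.
    """
    assert src.numel() == dst.numel()
    # Reinterpret both tensors as per-element byte rows.
    src_bytes = src.contiguous().view(torch.uint8).view(src.numel(), src.element_size())
    dst_bytes = dst.view(torch.uint8).view(dst.numel(), dst.element_size())
    # Keep only the high-order bytes of each source element.
    dst_bytes.copy_(src_bytes[:, -dst.element_size():])

# Example: pack BF16 values into a UINT8 buffer (keeps the high byte only).
params = torch.randn(8, dtype=torch.bfloat16)
buf = torch.empty(8, dtype=torch.uint8)
copy_most_significant_bytes(params, buf)
```

In the FP8 use case sketched in the PR description, the derived class would instead cast parameters to FP8 (one byte per element), so the UINT8 all-gather carries the full FP8 payload rather than a truncated one.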

crcrpar merged commit 2386a91 into NVIDIA:master on Sep 29, 2023.
minitu pushed a commit to minitu/apex that referenced this pull request on Sep 29, 2023:
* Add distopt support for param syncs with non-floating-point dtypes

Signed-off-by: Tim Moon <[email protected]>

* Update apex/contrib/optimizers/distributed_fused_adam.py

Co-authored-by: Masaki Kozuki <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Masaki Kozuki <[email protected]>
crcrpar added a commit that referenced this pull request on Sep 30, 2023:
* Add update_scale_hysteresis

* Fix compile errors

* Massively reduce LayerNorm/RMSNorm GPU memory usage in modern networks by tricking torch autograd (#1715)

* input grad checks out

* adding clamp gamma

* Both old and proposed implementation checks out

* 2 tests not yet passed due to numerical issues

* mem_eff works

* fast-layer-norm done

* Moving mem-eff to templates

* Relax tolerance for memory efficient backward

* Fix backward api of python

* Distributed optimizer infrastructure for FP8 parameters (#1723)

* Add distopt support for param syncs with non-floating-point dtypes

Signed-off-by: Tim Moon <[email protected]>

* Update apex/contrib/optimizers/distributed_fused_adam.py

Co-authored-by: Masaki Kozuki <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Masaki Kozuki <[email protected]>

* Add unit test

* Fix comment in unit test

* Remove unnecessary bits

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
Co-authored-by: Rui Wang <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Masaki Kozuki <[email protected]>