
sync the whole Meg-LM fused_kernels sub-tree #260

Merged · 3 commits merged into main · Mar 7, 2022
Conversation

stas00 (Contributor) commented Mar 1, 2022

As flagged by @thomasw21 in #259: in #151 we only synced part of the fused_kernels fixes that had been applied to Megatron-LM.

I tried to track all the changes since then, but there are too many and they are often mixed with other unrelated PRs, so how about we just sync the whole folder and the other related files?

This PR does just that.

I have no idea how to track all the individual contributors across the many PRs, but I think it was primarily @hyunwoongko, so it should be easy to add him as a contributor:

git commit --author "hyunwoongko <[email protected]>" -am "author attribution" --allow-empty

and the attribution will take effect once this is squash-merged.

stas00 mentioned this pull request on Mar 1, 2022
stas00 (Contributor, Author) commented Mar 2, 2022

This is not good. The performance is worse, and then it OOMed after 4 iterations:


before fused kernel fixes:

 iteration        2/   95367 | consumed samples:         4096 | consumed tokens:      8388608 | elapsed time per iteration (s): 152.10 | learning rate: 3.787E-06 | global batch size:  2048 | lm loss: 6.353651E+01 | grad norm: 21.493 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 | samples per second: 13.465 | TFLOPs: 141.12 |

after fused kernel fixes:

 iteration        2/   95367 | consumed samples:         4096 | consumed tokens:      8388608 | elapsed time per iteration (s): 159.85 | learning rate: 3.787E-06 | global batch size:  2048 | lm loss: 6.353651E+01 | grad norm: 21.493 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 | samples per second: 12.812 | TFLOPs: 134.27 |

It's possible I missed some other changes outside of these folders.

We probably need to do this properly: track and replay each change.

     and attn_batches % 4 == 0  # np * b must be divisor of 4
 ):
-    if 0 <= sk <= 2048:
+    if 0 <= sk <= 4096:
Member commented on the diff:

That test becomes useless, no? Unless we need to test sq now.

stas00 (Contributor, Author) replied Mar 2, 2022:

I have just copied the code from Megatron-LM verbatim and only re-added the changes we had made here.

i.e. I haven't added any code of my own. I only changed the outdated comment to match 4096.
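
For context, the bound under discussion sits inside Megatron-LM's check for whether the fused scaled-masked-softmax kernel can be used at all (the is_kernel_available logic in the upstream fused_softmax.py). Below is a minimal sketch of the shape of that check, not a verbatim copy: the standalone signature and the exact set of surrounding conditions are assumptions, simplified for illustration.

# Minimal sketch of the fused-kernel availability check (names follow
# Megatron-LM's fused_softmax.py; the standalone signature and the exact
# set of conditions are assumptions, simplified for illustration).
def fused_kernel_available(fusion_enabled, input_in_float16, b, np, sq, sk):
    attn_batches = b * np             # attention heads * micro-batch size
    if (
        fusion_enabled                # user asked for the fused kernel
        and input_in_float16          # kernel only supports fp16 input
        and 16 < sk <= 4096           # key seq length bound (was 2048 pre-sync)
        and sq % 4 == 0               # sq must be a multiple of 4
        and attn_batches % 4 == 0     # np * b must be a multiple of 4
    ):
        # The test the review comment calls useless: the outer condition
        # already guarantees 16 < sk <= 4096, so this is always true here.
        if 0 <= sk <= 4096:
            return True
    return False

Since the outer condition already caps sk at 4096, the inner test can never fail; keeping it anyway is consistent with copying the upstream code verbatim rather than cleaning it up.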

stas00 (Contributor, Author):

I need to use a smaller model and restart the testing, as it was on the brink of OOM already. So it's very likely that the issues I'm seeing are unrelated.

I will report back when I get new numbers.

stas00 (Contributor, Author) commented Mar 2, 2022

OK, I was testing with a broken merge of the deepspeed branch, which had introduced a memory leak. I have found the issue now and will re-test anew with this and your PRs.

stas00 (Contributor, Author) commented Mar 2, 2022

OK, after fixing the issue elsewhere, this PR works just fine. Except that it makes no difference whatsoever to the outcome. Perhaps we aren't impacted because we never hit the constraints that were handled poorly originally; I haven't investigated.

But the numbers are telling:

before fused kernel fixes:

 iteration        2/   95367 | consumed samples:         4096 | consumed tokens:      8388608 | elapsed time per iteration (s): 135.32 | learning rate: 3.787E-06 | global batch size:  2048 | lm loss: 6.354185E+01 | grad norm: 19.988 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 | samples per second: 15.134 | TFLOPs: 139.05 |

mem: 59GB

after fused kernel fixes (this PR):

 iteration        2/   95367 | consumed samples:         4096 | consumed tokens:      8388608 | elapsed time per iteration (s): 134.96 | learning rate: 3.787E-06 | global batch size:  2048 | lm loss: 6.354185E+01 | grad norm: 19.988 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 | samples per second: 15.175 | TFLOPs: 139.42 |

mem: 59GB

The small fluctuation is fine; the throughputs are effectively identical.

stas00 merged commit 1cb76a6 into main on Mar 7, 2022
stas00 deleted the sync-meg-lm branch on March 7, 2022 at 02:47
stas00 (Contributor, Author) commented Mar 7, 2022

Tested it some more with and without this change, and they track very closely.
