Allow streaming (datasets.IterableDataset) #1468
Conversation
Thanks a lot @BramVanroy for the detailed work on this! I think there should be no harm in supporting this in SFTTrainer. Can you run the styling checks (make precommit)? Then we can merge, imo.
Do you have any thoughts on the incorrect reporting of the number of epochs?
@BramVanroy hmmm not sure what we can do here to be honest :/
Hm yeah, maybe let's keep it like that for now then.
Thanks again!
@younesbelkada I probably should have included it in this PR too, but there is another useful (imo) change, just three lines of code, that could make the SFTTrainer even more flexible. I added it in a separate PR: #1520
* safe-guard iterabledatasets
* import datasets
* reference the correct IterableDataset
* make pre-commit
@BramVanroy @younesbelkada Hello guys, amazing work on this repo.
The motivation for this PR is given in this issue: #1455 (comment)

Currently, using IterableDatasets in the SFTTrainer is not possible because of two issues:

1. trl/trl/trainer/sft_trainer.py, lines 404 to 405 in 1705aeb
2. trl/trl/trainer/sft_trainer.py, lines 520 to 522 in 1705aeb

This PR remedies both issues by explicitly checking whether the dataset is a datasets.IterableDataset: in point 1 to skip the incompatible preprocessing path, and in point 2 so that the packed ConstantLengthDataset is returned as-is (it is itself an IterableDataset).
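The guard described above can be sketched roughly as follows. This is an illustrative stand-in, not the actual trl source: the class names mirror the real `datasets.IterableDataset` and trl's `ConstantLengthDataset`, but the `prepare_dataset` helper and its structure are hypothetical.

```python
# Illustrative sketch (NOT the actual trl code) of the kind of type guard
# this PR adds: skip eager, map-based preprocessing for iterable (streamed)
# datasets, and return the packed dataset as-is since it is already iterable.

class IterableDataset:
    """Stand-in for datasets.IterableDataset (no __len__, no random access)."""

class ConstantLengthDataset(IterableDataset):
    """Stand-in for trl's packed dataset, itself an IterableDataset."""
    def __init__(self, dataset):
        self.dataset = dataset

def prepare_dataset(dataset, packing=False):
    if packing:
        # Point 2: the packed dataset is an IterableDataset, so return it
        # directly instead of trying to materialize or re-wrap it.
        return ConstantLengthDataset(dataset)
    if isinstance(dataset, IterableDataset):
        # Point 1: iterable datasets cannot go through the map-based
        # tokenization path (no length, no multiprocessing), so pass through.
        return dataset
    # Regular map-style datasets would be tokenized with .map(...) here.
    return dataset
```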
These changes should make the SFTTrainer compatible with streaming datasets. The motivation for improving the situation is that in the alignment handbook we also use the SFTTrainer for continued pretraining, where massive (streamed) datasets should be supported.
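The reason packing works on a stream at all is that constant-length packing only needs forward iteration, never random access or a dataset length. A minimal sketch of that idea (the `pack_stream` function is hypothetical, not trl's implementation):

```python
# Hypothetical minimal sketch of constant-length packing over a stream:
# tokens from a (possibly unbounded) iterator are buffered and emitted in
# fixed-size chunks, so no __len__ or indexing is ever required.

def pack_stream(token_stream, seq_length):
    """Yield lists of exactly seq_length token ids from an iterator."""
    buffer = []
    for token in token_stream:
        buffer.append(token)
        if len(buffer) == seq_length:
            yield buffer
            buffer = []
    # A trailing partial chunk is dropped, as is common in packed setups.

chunks = list(pack_stream(iter(range(10)), seq_length=4))
# chunks == [[0, 1, 2, 3], [4, 5, 6, 7]]
```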
Note: the epoch calculation does not seem to happen correctly when max_steps is given, though. With batch size 2, 8 gradient accumulation steps, 7153 optimization steps, and 16 GPUs, I get a strange calculation for the number of examples and number of epochs.
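A plausible explanation (an assumption, not confirmed here): with a streamed dataset the Trainer has no dataset length, so it can only derive an example count from max_steps and the effective batch size, which makes the reported "num examples" and epoch count synthetic rather than the true dataset size. Back-of-the-envelope arithmetic for the configuration above:

```python
# Rough arithmetic for the configuration in the note above. Assumes the
# example count is estimated as max_steps * effective batch size when the
# dataset has no known length (an assumption about Trainer internals).
per_device_batch_size = 2
grad_accum_steps = 8
num_gpus = 16
max_steps = 7153

effective_batch_size = per_device_batch_size * grad_accum_steps * num_gpus
estimated_examples = effective_batch_size * max_steps
print(effective_batch_size, estimated_examples)  # 256 1831168
```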
closes #1455