
Allow streaming (datasets.IterableDataset) #1468

Merged: 6 commits merged into huggingface:main on Apr 11, 2024
Conversation

BramVanroy (Contributor) commented on Mar 22, 2024

The motivation for this PR is given in this issue: #1455 (comment)

Currently, using IterableDatasets in the SFTTrainer is not feasible because of two issues:

1. a datasets.IterableDataset is a subclass of a torch dataset, and is therefore returned by the early check below before packing is ever considered:

    if isinstance(dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)):
        return dataset

2. when packing, the given dataset is exhausted and loaded fully into memory with Dataset.from_generator, which defeats the purpose of having an IterableDataset in the first place:

    packed_dataset = Dataset.from_generator(
        data_generator, gen_kwargs={"constant_length_iterator": constant_length_iterator}
    )

This PR remedies both issues by explicitly checking whether the dataset is a datasets.IterableDataset: for point 1 the early return no longer applies to it, and for point 2 the packed ConstantLengthDataset is returned as-is (it is itself an IterableDataset) instead of being materialized into memory.
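For illustration, here is a minimal sketch of the resulting dispatch. It is not taken from the PR; the function and argument names are invented, and only the two checks mirror the behaviour described above.

    import datasets
    import torch


    def prepare_dataset_sketch(dataset, packing, constant_length_iterator=None):
        """Illustrative stand-in for the trainer's dataset preparation, not the trl code."""
        is_streaming = isinstance(dataset, datasets.IterableDataset)

        # Point 1: a datasets.IterableDataset is also a torch IterableDataset,
        # so it must be excluded from this early return explicitly.
        if not is_streaming and isinstance(
            dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset)
        ):
            return dataset

        if packing:
            # `constant_length_iterator` stands in for the ConstantLengthDataset
            # that the trainer builds from `dataset`.
            if is_streaming:
                # Point 2: return the packed iterable as-is; it is itself an
                # IterableDataset, so nothing is loaded into memory.
                return constant_length_iterator

            def data_generator(constant_length_iterator):
                yield from constant_length_iterator

            return datasets.Dataset.from_generator(
                data_generator,
                gen_kwargs={"constant_length_iterator": constant_length_iterator},
            )

        return dataset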

These changes should make the SFTTrainer more compatible with streaming datasets. The motivation is that in the alignment handbook we also use the SFTTrainer for continued pretraining, where massive (streamed) datasets should be supported.

Note: it seems that the epoch calculation is not done correctly when max_steps is given. With a batch size of 2, 8 gradient accumulation steps, 7,153 optimization steps, and 16 GPUs, I get this strange calculation for the number of examples and epochs:

***** Running training *****
Num examples = 1,831,168
Num Epochs = 9,223,372,036,854,775,807
Instantaneous batch size per device = 2
Total train batch size (w. parallel, distributed & accumulation) = 256
Gradient Accumulation steps = 8
Total optimization steps = 7,153
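Side note, not part of the original comment: the huge epoch count appears to be sys.maxsize, the placeholder the transformers Trainer reports when training for a fixed max_steps on a dataset whose length it cannot determine, so it reads as a sentinel rather than a real epoch calculation. A quick check:

    import sys

    # On 64-bit Python, sys.maxsize equals the "Num Epochs" value logged above.
    print(sys.maxsize == 9_223_372_036_854_775_807)  # True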

closes #1455

younesbelkada (Contributor) left a comment:

Thanks a lot @BramVanroy for the detailed work on this! I think there should be no harm in supporting this in SFTTrainer. Can you run the styling checks (make precommit)? Then we can merge, imo.

HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

BramVanroy (Contributor, Author) replied on Apr 10, 2024:

Do you have any thoughts on the incorrect reporting of the number of epochs?

younesbelkada (Contributor) replied:

@BramVanroy hmmm, not sure what we can do here to be honest :/

BramVanroy (Contributor, Author) replied:

Hm yeah, maybe let's keep it like that for now then.

younesbelkada (Contributor) left a comment:

Thanks again!

younesbelkada merged commit e667550 into huggingface:main on Apr 11, 2024 (9 checks passed).
BramVanroy (Contributor, Author) commented:

@younesbelkada I probably should have included it in this PR too, but there is another useful (imo) change of just three lines of code that could make the SFTTrainer even more flexible. I added it in a separate PR: #1520

lapp0 pushed a commit to lapp0/trl that referenced this pull request on May 10, 2024:

* safe-guard iterabledatasets
* import datasets
* reference the correct IterableDataset
* make pre-commit
snow-kartikbagalore commented on Oct 12, 2024:

@BramVanroy @younesbelkada Hello guys, amazing work on this repo. I wanted to know whether it is possible to pass num_train_epochs now, or whether max_steps is still the recommended way to go. I am working with a very large dataset and am streaming it by passing it as a ConstantLengthDataset.
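Not an answer from the maintainers, but a rough sketch of the streaming setup being asked about. The model and dataset names, the text column, and the exact keyword arguments are assumptions (these options have moved between TrainingArguments, SFTConfig, and SFTTrainer across trl versions); the key point is that a streamed dataset has no known length, so the Trainer needs max_steps rather than num_train_epochs to decide when to stop.

    from datasets import load_dataset
    from transformers import AutoTokenizer, TrainingArguments
    from trl import SFTTrainer
    from trl.trainer import ConstantLengthDataset

    # Hypothetical model and dataset names, for illustration only.
    model_name = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    streamed = load_dataset("your_org/your_large_corpus", split="train", streaming=True)

    # Pack the streamed text into fixed-length blocks; this object is itself an
    # IterableDataset and is passed to the trainer as-is.
    train_dataset = ConstantLengthDataset(
        tokenizer,
        streamed,
        dataset_text_field="text",   # assumed column name
        seq_length=1024,
        infinite=True,               # keep cycling; max_steps decides when to stop
    )

    args = TrainingArguments(
        output_dir="out",
        max_steps=7_153,             # required: the stream has no known length
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
    )

    trainer = SFTTrainer(
        model=model_name,
        args=args,
        train_dataset=train_dataset,
    )
    trainer.train()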
