Add early support for torchdata.stateful_dataloader.StatefulDataLoader within the Accelerator (#2895)
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
It's tricky. My understanding of the problem is this: we want to do something like […]. However, based on how classes like […] work, […]. I have one more idea for a solution, which is to have those classes create a […].
The more I've been working on this, the more I actually think this is the best solution we can get. Thanks a bunch for doing this; even though it's annoying with the patches, there's no other clear way to get there.
@byi8220 can you resolve the PR's open conversations, and then I think we're okay to merge this.
As a final step, we likely want to update […].
Thanks a lot for the PR and the good discussion of possible designs. The end result is still something that I'm afraid will one day cause a hard-to-debug issue, but I can't say what exactly would be a better solution.
I added a couple of comments, which I think can help clean up this PR a bit, but don't consider them to be blockers. I have to admit I only skimmed the tests but they look very well done, so together with the existing ones they should hopefully avoid regressions.
One thing I would like to see is an addition to the docs to explain what stateful data loaders are, why users may want to use them, and how they can use them.
base_cls = self.__class__
base_cls_name = self.__class__.__name__
parent_cls_name = self.base_dataloader.__class__
self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})
Let me just bring up (again) that another solution could be monkey-patching `__instancecheck__` on `DataLoader`. Not saying that it's less hacky, just wanted to raise awareness :)
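The dynamic re-parenting being reviewed here can be sketched in isolation (hypothetical `Base`/`Adapter` names, not accelerate's actual classes): reassigning `self.__class__` to a freshly built subclass makes `isinstance` checks against the wrapped class pass while keeping the adapter's own behavior.

```python
class Base:
    def ping(self):
        return "base"

class Adapter:
    def __init__(self, wrapped):
        self.wrapped = wrapped
        # Re-parent this instance: build a new class inheriting from both the
        # adapter class and the wrapped object's class, as in the diff above.
        base_cls = self.__class__
        self.__class__ = type(base_cls.__name__, (base_cls, wrapped.__class__), {})

adapter = Adapter(Base())
assert isinstance(adapter, Base)   # passes thanks to the re-parenting
assert adapter.ping() == "base"    # methods resolve through the new MRO
```

Because only the instance's `__class__` is swapped, the `Adapter` class itself is untouched, and each instance can wrap a different loader type.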
src/accelerate/data_loader.py
Outdated
for attr in self.base_dataloader.__dict__.keys():
    setattr(self, attr, getattr(self.base_dataloader, attr))
Kinda looks dangerous. For example, this skips `@property` attributes, is that intended? We could instead use `__getattr__` to dispatch to `self.base_dataloader`.
If we want to stick with this, more succinct code could be: `self.__dict__.update(self.base_loader.__dict__)` or `vars(self).update(self.base_loader.__dict__)`
> Kinda looks dangerous.

Kinda agree with you, but all dynamic reflection looks dangerous to me.
I did write up an alternative which avoids the wizardry and just duplicates all the required code, over here: byi8220/accelerate@stateful-dataloader...byi8220:accelerate:stateful-dataloader-2
That code is messier and involves way more duplication, but it is much more explicit about what it does. If enough people feel the reflection approach is way too hacky and this feature doesn't justify it, I'm fine with doing that instead.

> We could instead use `__getattr__` to dispatch to `self.base_dataloader`.

I updated the PR to do that instead.
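The difference between the two strategies discussed above can be sketched with toy classes (hypothetical names, not accelerate's code): copying `__dict__` only snapshots instance attributes, while `__getattr__` delegation also reaches properties and methods on the wrapped object.

```python
class Base:
    def __init__(self):
        self.batch_size = 4

    @property
    def total(self):
        return self.batch_size * 2

class CopyingWrapper:
    def __init__(self, base):
        self.base = base
        # Copies instance attributes like batch_size, but NOT the `total` property
        vars(self).update(vars(base))

class DelegatingWrapper:
    def __init__(self, base):
        self.base = base

    def __getattr__(self, name):
        # Only called when normal lookup fails, so it also reaches properties
        return getattr(self.base, name)

copying = CopyingWrapper(Base())
delegating = DelegatingWrapper(Base())
assert copying.batch_size == 4 and delegating.batch_size == 4
assert delegating.total == 8
assert not hasattr(copying, "total")  # the @property was skipped by the copy
```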
src/accelerate/data_loader.py
Outdated
super().load_state_dict(state_dict)
self.dl_state_dict = self.state_dict

def _save_state_dict(self):
IMO, the name is not quite fitting, isn't it more like `update_state_dict` or so? Also, maybe we can avoid this all by not having a static `self.dl_state_dict` attribute, but instead having the `state_dict` method just return `self.base_dataloader.state_dict()`.
> IMO, the name is not quite fitting, isn't it more like update_state_dict or so?

Changed to `_update_state_dict`.

> Also, maybe we can avoid this all by not having a static self.dl_state_dict attribute but instead the state_dict method just returns self.base_dataloader.state_dict().

I'm not sure if we can. The base dataloader's state dict is one ahead of what we're yielding, so we couldn't do a passthrough. Some additional context is in the comments of a6e192c#r1704736815.
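The "one ahead" problem can be illustrated with a toy prefetching loader (hypothetical names, not the PR's actual code): because the wrapper fetches the next batch before yielding the current one, a passthrough `state_dict()` would describe a position past the batch the caller just received, so the state must be snapshotted before the underlying iterator advances.

```python
class ToyStatefulLoader:
    """Toy loader whose state_dict reports the index of the next item to produce."""
    def __init__(self, data):
        self.data = data
        self.next_index = 0

    def __iter__(self):
        while self.next_index < len(self.data):
            item = self.data[self.next_index]
            self.next_index += 1
            yield item

    def state_dict(self):
        return {"next_index": self.next_index}

class StatefulWrapper:
    def __init__(self, base):
        self.base = base
        self._cached_state = base.state_dict()

    def __iter__(self):
        it = iter(self.base)
        try:
            current = next(it)  # prefetch one item ahead
        except StopIteration:
            return
        while True:
            # Snapshot now: base state reflects `current` having been produced,
            # but not whatever gets prefetched next.
            self._cached_state = self.base.state_dict()
            try:
                nxt = next(it)
            except StopIteration:
                yield current
                return
            yield current
            current = nxt

    def state_dict(self):
        return self._cached_state

loader = ToyStatefulLoader([10, 20, 30])
wrapper = StatefulWrapper(loader)
it = iter(wrapper)
assert next(it) == 10
assert wrapper.state_dict() == {"next_index": 1}  # resume would produce 20
assert loader.state_dict() == {"next_index": 2}   # raw passthrough: one ahead
```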
@BenjaminBossan Thanks for the review. Just addressed the comments on the PR.
The ultimate intent of this code is something like "Sometimes I want a DataLoaderDispatcher that inherits from DataLoader, but other times I want a DataLoaderDispatcher that inherits from StatefulDataLoader." IMO, the less magical alternative would be to explicitly duplicate each DataLoader derivative that accelerate defines into a stateful version, i.e. manually create the classes StatefulDataLoaderDispatcher, StatefulDataLoaderShard, StatefulSkipDataLoader. I wrote up this alternative in a separate branch (diffed by byi8220/accelerate@stateful-dataloader...byi8220:accelerate:stateful-dataloader-2), but it leads to quite a lot of code duplication and also looks messy.
I've tested this locally on my 1 GPU home workstation + a 2xGPU cloud instance (which costs me a few dollars every time I want to run the test suite 😞 ...). This is my first real PR into accelerate. The tests highlighted one small thing though: the fact that to fully stop using a dataloader in the middle, you have to call […].
IMO this might be better in a separate PR, once the code is checked in?
I believe this has been a known "issue" in accelerate (I've seen it pop up in other issues sparingly). Agreed that it's less of an issue here, since this is pretty much just called once at the start of training. As long as we track the state properly (which your tests check!), it's a different bug/issue to solve.
We tend to like full FC PRs that include doc updates. It's less likely to be forgotten about, and it's done all at once, so users who want the bleeding edge can read about it immediately :)
Well, I have no idea how to solve such a problem in Python. In the C++ world, this is what destructors and RAII are for, I guess.
Sure, added a footnote in the docs about this feature. Also, since this feature is now stable in torchdata, I added a requirement for `torchdata>=0.8.0` in […].
Thanks for the updates.

> I'm not sure if we can. The base dataloader's state dict is one ahead of what we're yielding, so we couldn't do a passthrough. Some additional context in the comments of a6e192c#r1704736815

I see. If Zach is fine with the proposed solution, then we're good.

> Well, I have no idea how to solve such a problem in python. In the C++ world this is what destructors and RAII are for, I guess.

There is the `__del__` magic method in Python, but let's not touch it.
def __getattr__(self, name):
    # Delegate attribute access to the internal dataloader
    return getattr(self.base_dataloader, name)
A bit of an edge case: let's also check that the name is not `"base_dataloader"`, and if it is, raise an `AttributeError` to avoid infinite recursion.
Could you give a code example of how infinite recursion would happen here?
If I'm reading the Python 3 docs for `__getattr__()` correctly, it states "Note that if the attribute is found through the normal mechanism, `__getattr__()` is not called." IIUC, `base_dataloader` should always be retrievable through the normal mechanism.
If I add the following block into `test_dataloader_inheritance()` in `test_data_loader.py` (without making any other changes), the tests pass without causing an infinite recursion:

    assert isinstance(skip_dl.base_dataloader, DataLoader)
    assert isinstance(dl_shard.base_dataloader, DataLoader)
    assert isinstance(dl_dispatcher.base_dataloader, DataLoader)
> Could you give a code example of how infinite recursion would happen here?

Yes, that would be for the edge case of an attribute getting looked up on the object before it is fully instantiated. In that case, the `base_dataloader` attribute does not exist. Now you could say "who would do such a pernicious thing?", but it's a bug that actually happened in another project, and for some reason DeepSpeed would do this (on a module, not a data loader, but let's rather be safe than sorry).
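The guard being requested can be sketched with a toy adapter (hypothetical names, not accelerate's code). Without the check, any attribute lookup on an instance whose `__init__` never ran would bounce between `__getattr__` and the missing `base_dataloader` forever; the check turns that into a plain `AttributeError`.

```python
from types import SimpleNamespace

class Adapter:
    def __init__(self, base_dataloader):
        self.base_dataloader = base_dataloader

    def __getattr__(self, name):
        # If base_dataloader itself is missing (e.g. __getattr__ fires before
        # __init__ has set it), delegating would call __getattr__ again and
        # recurse forever; bail out with AttributeError instead.
        if name == "base_dataloader":
            raise AttributeError(name)
        return getattr(self.base_dataloader, name)

adapter = Adapter(SimpleNamespace(batch_size=8))
assert adapter.batch_size == 8  # normal delegation still works

uninitialized = Adapter.__new__(Adapter)  # the DeepSpeed-style edge case
try:
    uninitialized.batch_size
except AttributeError:
    pass  # guard converts would-be infinite recursion into AttributeError
```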
I am still not entirely sure how this could happen, but I added a check in `__getattr__`.

sgtm
I see. Destructors in Python don't seem very reliable, but my knowledge of the Python memory model isn't great.
Did a final review and also more carefully reviewed the tests. I didn't find anything big, but a few minor things that could be improved. After that, this can be merged from my POV.
return _is_package_available("torchdata")


# TODO: Remove this function once stateful_dataloader is a stable feature in torchdata.
Can this now be adjusted to use a version check?
Yes, modeled after the other version checks in the file.
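A version gate of the kind discussed could look roughly like this (function name and structure are assumptions; accelerate's real helper uses its own version-comparison utilities). The check reflects the `torchdata>=0.8.0` requirement mentioned earlier in the thread.

```python
import importlib.metadata

def is_stateful_dataloader_available():
    """Return True if an installed torchdata is >= 0.8.0, the release in which
    StatefulDataLoader became a stable feature (sketch, not accelerate's code)."""
    try:
        version = importlib.metadata.version("torchdata")
    except importlib.metadata.PackageNotFoundError:
        return False
    parts = version.split(".")
    major, minor = int(parts[0]), int(parts[1])
    return (major, minor) >= (0, 8)
```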
self.dl_state_dict = self.state_dict

def _update_state_dict(self):
    if hasattr(self.base_dataloader, "state_dict"):
Let's add a comment here about when this needs to be called, with context on why it's required.
Added a comment here, kinda clunky though.
Thanks!
I fixed the nits above, but I also made one, maybe important, change in 74e2f53. Basically, I realized just now that I have been delegating the work of iteration to the superclass instead of the backing dataloader. That felt wrong, so I did the commit above. To confirm, replacing […].
Very nicely done! 🤩
Thanks so much for all your hard work on this, let's get it merged in ✅
(And yes, I think it is a sensible thing to do rather than the getattr spaghetti)
As a next step, I'll work on getting this working with […].
What does this PR do?
Fixes #2859
This PR does the following:
- Adds a new flag `use_stateful_dataloader` to `DataLoaderConfiguration`. Passing this into the config makes it so that all `DataLoader`s prepared and returned by the Accelerator are `StatefulDataLoader` objects from the torchdata library.
- Adds a new class `DataLoaderAdapter` which can wrap around and act as either PyTorch's `DataLoader` or other variants of it, such as `StatefulDataLoader`.
- Updates `DataLoaderShard`, `DataLoaderDispatcher`, and `SkipDataLoader` to inherit from `DataLoaderAdapter` instead of `DataLoader`.
Testing
Added new unit tests verifying that `StatefulDataLoader` can be dropped in, and that its state can be saved and loaded.
Caveats
- The `torchdata` package may have conflicts with `accelerate`; see "Importing `torchdata.stateful_dataloader` causes the test `check_seedable_sampler` to fail" (#2894).
- When `torchdata` is not installed, all existing tests pass, suggesting this is not regressive.
- `torchdata.stateful_dataloader.StatefulDataLoader` is only available in the beta build of `torchdata`; this is not a stable feature.
- `DataLoaderAdapter` is somewhat invasive and uses some questionable reflection.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@muellerzr