Trainers: num_workers > 0 results in pickling error on macOS/Windows #886
Hi @mohscorpion, I think the formatting of your paste is messed up; do you mind trying again? Also, can you post the entire file you are using (including imports) and the entire stack trace of the error?
For now, the EuroSAT dataset doesn't require a geo sampler, so removing that should work.
If you follow this tutorial, but use the EuroSAT dataset, then it should work: https://torchgeo.readthedocs.io/en/latest/tutorials/trainers.html
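To illustrate the sampler point, here's a minimal sketch (assuming the torchgeo 0.3.x API used elsewhere in this thread): EuroSAT is indexed by integer like a regular torchvision dataset, so a plain `DataLoader` is all that's needed, and no GeoSampler is involved.

```python
from torch.utils.data import DataLoader
from torchgeo.datasets import EuroSAT

# EuroSAT returns dict samples indexed by integer, so a plain shuffled
# DataLoader works; GeoSamplers are only for datasets indexed by
# spatial bounding boxes.
eurosat = EuroSAT("./Eurosat/", split="train", download=False)
dataloader = DataLoader(eurosat, batch_size=128, shuffle=True)
```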
I have commented that line out because of the error, but the problem still remains.
My code is not essentially different; the only difference is a classification task instead of a regression task.
Can you post a properly formatted version of your code and error?
Here's a demo script that uses torchgeo version 0.3.1: https://gist.github.com/calebrob6/2e111a61fe8e6b531d9a0844a79e9d30. I used a conda environment I created with:
Fixed the formatting. Code blocks require triple backticks; see here.
I can't reproduce; the above code (minus the dataset/dataloader that are not needed) works fine for me.
This is likely a multiprocessing issue. The reason @calebrob6 can't reproduce this is because multiprocessing uses a different start method on macOS/Windows vs. Linux: https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods

If I'm correct, I should be able to reproduce this on macOS. Let me give it a shot.
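As a quick check of that theory (a sketch, nothing TorchGeo-specific), the platform default is easy to inspect, and `DataLoader`'s `multiprocessing_context` argument can force the macOS/Windows behavior on Linux:

```python
import multiprocessing

# "fork" on Linux; "spawn" on Windows and on macOS since Python 3.8.
# "spawn" requires everything handed to a worker process to be pickleable.
print(multiprocessing.get_start_method())

# To reproduce the macOS/Windows behavior on Linux, force spawned workers:
# DataLoader(dataset, num_workers=4, multiprocessing_context="spawn")
```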
```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.datasets import stack_samples, EuroSAT
from torchgeo.samplers import RandomGeoSampler  # unused: EuroSAT needs no geo sampler
from torchgeo.trainers import ClassificationTask

euro_root = "./Eurosat/"
eurosat = EuroSAT(euro_root, split="train", download=False)
dataloader = DataLoader(eurosat, batch_size=128, collate_fn=stack_samples)

num_classes = 10
channels = 13
num_workers = 4
batch_size = 4
backbone = "resnet50"
weights = "imagenet"
lr = 0.01
lr_schedule_patience = 5
epochs = 50

datamodule = EuroSATDataModule(
    root_dir=euro_root,
    batch_size=batch_size,
    num_workers=num_workers,
)
task = ClassificationTask(
    classification_model=backbone,
    weights=weights,
    num_classes=num_classes,
    in_channels=channels,
    loss="ce",
    learning_rate=lr,
    learning_rate_schedule_patience=lr_schedule_patience,
)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,
    save_last=True,
)
early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=10,
)
trainer = pl.Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=epochs,
)
trainer.fit(model=task, datamodule=datamodule)
test_metrics = trainer.test(model=task, datamodule=datamodule)
```
I will try on Linux and report back.
Yep, I'm seeing the same issue on macOS. For now, a quick workaround is to set `num_workers=0`.
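A sketch of that workaround, using the same 0.3.x data module arguments as the script above:

```python
from torchgeo.datamodules import EuroSATDataModule

# num_workers=0 keeps data loading in the main process, so nothing needs
# to be pickled and the error never triggers (at some speed cost).
datamodule = EuroSATDataModule(
    root_dir="./Eurosat/",
    batch_size=4,
    num_workers=0,
)
```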
Full stack trace for anyone curious:
The thing that's odd to me is that multiprocessing (and therefore pickling) only happens within the data loader, but there is no model in the dataset/data module. It's almost like it's trying to pickle the ResNet inside the dataset for some reason...
I can confirm now, it runs OK on Linux.
Bit obvious, but this error now also happens in the trainers tutorial for folks running macOS (and probably also Windows). This seems to work:

```python
trainer = pl.Trainer(
    accelerator="mps",
    devices=1,
    callbacks=[checkpoint_callback, early_stopping_callback],
    logger=[csv_logger],
    default_root_dir=experiment_dir,
    min_epochs=1,
    max_epochs=10,
    fast_dev_run=in_tests,
)
```

Although there are warnings:

```
/Users/calkoen/miniconda3/envs/torchgeo/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 10 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
```
TorchGeo has two types of tests: unit tests and integration tests. The former run in serial, but the latter run in parallel and include our tutorials. Once we fix this issue, we should probably also start running our integration tests on macOS and Windows as well so we can prevent this issue from coming back.
I get the same error on Windows, but I'm using the xView2 dataset for semantic segmentation. It works on Colab, but not on Windows.
Further minimized the bug reproducer:

```python
from pytorch_lightning import Trainer
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.trainers import ClassificationTask

datamodule = EuroSATDataModule(
    root="tests/data/eurosat",
    num_workers=4,
)
model = ClassificationTask(
    model="resnet18",
    weights="random",
    num_classes=2,
    in_channels=13,
    loss="ce",
    learning_rate=0.01,
    learning_rate_schedule_patience=5,
)
trainer = Trainer(max_epochs=1)
trainer.fit(model=model, datamodule=datamodule)
```

Doesn't get much smaller than that. Interestingly, the following does not exhibit the same issue:

```python
datamodule.setup()
trainer.fit(model=model, train_dataloaders=datamodule.train_dataloader())
```

I'm pretty confident this is a PyTorch Lightning bug. I'm trying to reproduce this outside of TorchGeo, but haven't gotten it working yet. Will let you know if I figure this out.
The following also raises the same error, which confirms that PyTorch Lightning is trying to pickle the model for some reason:

```python
import pickle

from torchgeo.trainers import ClassificationTask

model = ClassificationTask(
    model="resnet18",
    weights="random",
    num_classes=10,
    in_channels=3,
    loss="ce",
    learning_rate=0.01,
    learning_rate_schedule_patience=5,
)
pickle.dumps(model)
```
Interestingly, the following does not raise an error:

```python
import pickle

from torchvision.models import resnet18

model = resnet18()
pickle.dumps(model)
```
@calebrob6 you're going to love this. Remember when I complained the other day about how some of our transforms are done in data module instance methods?

On macOS/Windows, the default multiprocessing start method requires all objects necessary to run the subprocess to be pickleable. The LightningModule isn't pickleable, but normally it isn't necessary for parallel data loading, so that's fine. However, our DataLoaders use transforms that include a reference to an instance method of the LightningDataModule. The LightningDataModule itself is pickleable, but during training it acquires a reference back to the LightningModule, making it no longer pickleable. Chaos ensues.

So the real "bug" is that LightningModules aren't pickleable even though the models they contain are pickleable. However, there's also an obvious workaround for this, which is to not use instance methods during data loading. I'll open a bug report with the PyTorch Lightning folks, but I'll also open a PR here to remove all of these instance-method transforms from our data loaders.

Thanks @mohscorpion @FlorisCalkoen @Seyed-Ali-Ahmadi for reporting this bug, and sorry it took so long to track down!
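A self-contained sketch of that reference chain (all class names here are illustrative stand-ins, not TorchGeo or Lightning classes): pickling a bound method pickles the instance it's bound to, along with everything that instance references.

```python
import pickle


class Unpicklable:
    """Stand-in for the LightningModule."""

    def __reduce__(self):
        raise TypeError("refusing to be pickled")


class FakeDataModule:
    """Stand-in for the LightningDataModule."""

    def __init__(self):
        self.trainer = None

    def preprocess(self, sample):
        return sample


dm = FakeDataModule()
pickle.dumps(dm.preprocess)  # fine: dm itself is pickleable

# During fit, the data module acquires a reference back to the model...
dm.trainer = Unpicklable()

try:
    # ...and now the bound method drags the whole chain along and fails.
    pickle.dumps(dm.preprocess)
except TypeError as e:
    print(e)
```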
Update: LightningModules are pickleable as long as you don't hack your import space. It turns out all of the

```python
BatchNorm2d.__module__ = "nn.BatchNorm2d"
```

stuff we have littered throughout TorchGeo is the reason that our trainers can't be pickled. This stuff was in there to fix the docs. I'm going to try to remove as much of it as I can and see if it's still needed with the latest version of Sphinx. Even if that fixes it, I might still remove the instance-method transforms anyway.

P.S. Apologies to the PyTorch Lightning devs for assuming that the bug was in their code and not ours!
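A sketch of why that hack breaks pickling (with a made-up class, not the actual TorchGeo code): pickle stores classes by reference, looking them up via `__module__` and `__qualname__`, so lying about `__module__` makes the lookup fail.

```python
import pickle


class Widget:
    pass


pickle.dumps(Widget())  # fine

# Mimic the docs hack: point __module__ at a module path that doesn't exist.
Widget.__module__ = "nn.Widget"

try:
    pickle.dumps(Widget())
except pickle.PicklingError as e:
    print(e)  # pickle can no longer locate the class to store it by reference
```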
#976 is sufficient to fix ClassificationTask pickling, but I run into a new pickling issue when training on EuroSAT:

Didn't bother digging into this too much, but replacing
@adamjstewart what about EuroSAT is different from other data modules that work? I'm asking because EuroSAT seems like a very simple data module that shouldn't cause problems (which makes me suspicious that we don't fully understand what is going on).
None of our data modules work in parallel on macOS/Windows. #992 will fix that.
I know, but what is going on with "I run into a new pickling issue when training on EuroSAT" -- this smells suspicious.
Let me do some digging...
So the above error happens during the sanity checks on the validation set. If you add

This didn't turn out to be helpful, but while digging around, I did notice a couple things that look sus:

Could it be possible that one of the components (Trainer, LightningModule, LightningDataModule) becomes unpickleable, but only during fit/validate? All 3 components are pickleable before and after fit/validate. The only reason that this bug surfaces at the moment is because the data loader points back to the LightningDataModule (because of `preprocess`), which points to the Trainer/LightningModule.

I'm not really sure how to dig further without help from the PL devs. We know one possible fix (#992). It's possible that the Trainer/LightningModule is temporarily unpickleable, but I also don't think they ever intended for anyone to try to pickle it in the first place.
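One way to probe that timing question (a hedged sketch, not something from the TorchGeo codebase): a callback that attempts `pickle.dumps` at a few hooks to see exactly when each component stops being pickleable. Pass it as `Trainer(callbacks=[PickleProbe()], ...)`.

```python
import pickle

import pytorch_lightning as pl


class PickleProbe(pl.Callback):
    """Try to pickle the trainer and module at several hooks; report failures."""

    @staticmethod
    def _probe(stage, **objs):
        for name, obj in objs.items():
            try:
                pickle.dumps(obj)
                print(f"{stage}: {name} pickles fine")
            except Exception as e:
                print(f"{stage}: {name} does not pickle: {e!r}")

    def on_fit_start(self, trainer, pl_module):
        self._probe("fit_start", trainer=trainer, module=pl_module)

    def on_sanity_check_start(self, trainer, pl_module):
        self._probe("sanity_check_start", trainer=trainer, module=pl_module)

    def on_train_epoch_start(self, trainer, pl_module):
        self._probe("train_epoch_start", trainer=trainer, module=pl_module)
```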
Is this even unique to TorchGeo, or is it just a PL problem in general? Wonder if this happens when training a PL CIFAR10 example on Windows.
If you use data module instance methods in your parallel data loader (like we do in TorchGeo), then you should be able to reproduce it.
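A minimal Lightning-free sketch of that recipe (all class names invented for illustration): hand a spawned DataLoader worker a bound method of an object that holds an unpicklable reference.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class Squares(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx * idx


class Owner:
    """Stand-in for a data module holding an unpicklable reference."""

    def __init__(self):
        self.model_ref = lambda x: x  # lambdas cannot be pickled

    def collate(self, batch):
        return torch.tensor(batch)


if __name__ == "__main__":
    owner = Owner()
    # "spawn" (the macOS/Windows default) must pickle collate_fn; since it is
    # a bound method, it drags `owner` and its lambda along and fails.
    loader = DataLoader(
        Squares(),
        batch_size=4,
        num_workers=2,
        collate_fn=owner.collate,
        multiprocessing_context="spawn",
    )
    for batch in loader:
        print(batch)
```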
These might be related to the bug we were experiencing:
I have a problem running torchgeo on EuroSAT MS TIFF images. I wrote a simple piece of code:

but I get this error on the fit line:

Am I missing a preparation step or something? Help would be appreciated. By the way, the tutorials are still lacking essential sample code and documentation, such as for this issue with MS images.

Thanks.