Memory consumption does not seem to reflect the indicated consumption of nvidia-smi #37
Comments
Small update: now if I call [...].

I wish to use RevNet to train with a larger mini-batch size than I am able to with ResNet (due to memory limitations).
@duag Hi, thank you for trying out MemCNN. I see that some things regarding memory usage and the potential memory savings through the InvertibleModuleWrapper are not all that clear. I might actually add some text to the documentation somewhere, since this comes up with more users.

As you already figured out, you cannot rely on nvidia-smi for this. For the statistics in the experiments I rely on estimates of two components: the memory for the network weights and the memory for the featuremap activations. With this background, looking at the experiments you first posted:
We can see that while RevNet38 uses more memory for the network weights, 2414592 vs. 1966080 bytes (23% more), it significantly saves memory for the activations, 106352640 vs. 259687424 bytes (it saves about 60%). Also note that the memory overhead for the activations is roughly a factor of 50 larger than that of the network weights.

Still, you might have expected a little more performance gain. Well, ResNet32 and RevNet38 are relatively small. The gain for the activations becomes more apparent with larger models that have more layers. You should be able to see this when training ResNet110 vs RevNet110 and ResNet164 vs RevNet164, so I would encourage you to try that out!

Regarding your experiments with the increased batch size: I would expect that increasing the batch size lets you save even more memory overhead, since the featuremaps typically get bigger but you only need to store a few of them. Did you already try this out?

One thing I can think of that limits the memory performance are the max-pools in the RevNet, which require storing all of their featuremap activations before downsampling (these are not invertible). If you're feeling adventurous you might even want to try to implement an "invertible max-pool" as in the i-RevNet paper (see https://arxiv.org/abs/1802.07088). This should be easy to do by creating a custom invertible module; see the sketch below.

Good luck with your project! Let me know where I could help.
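For reference, a rough sketch of what such an invertible downsampling could look like, using the space-to-depth reshuffling from the i-RevNet paper. The class name InvertibleDownsampling is made up for illustration, and it assumes that InvertibleModuleWrapper accepts any nn.Module exposing matching forward/inverse methods:

```python
import torch
import torch.nn as nn
import memcnn


class InvertibleDownsampling(nn.Module):
    """Space-to-depth reshuffling (i-RevNet, arXiv:1802.07088): every 2x2
    spatial block is moved into the channel dimension, so the operation is
    lossless and exactly invertible."""

    def __init__(self, block_size=2):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        n, c, h, w = x.shape
        bs = self.bs
        x = x.reshape(n, c, h // bs, bs, w // bs, bs)
        x = x.permute(0, 1, 3, 5, 2, 4)
        return x.reshape(n, c * bs * bs, h // bs, w // bs)

    def inverse(self, y):
        n, c, h, w = y.shape
        bs = self.bs
        y = y.reshape(n, c // (bs * bs), bs, bs, h, w)
        y = y.permute(0, 1, 4, 2, 5, 3)
        return y.reshape(n, c // (bs * bs), h * bs, w * bs)


# Wrapping it means the input featuremap does not need to be stored:
downsample = memcnn.InvertibleModuleWrapper(fn=InvertibleDownsampling(), keep_input=False)
```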
@silvandeleemput thank you for your detailed reply. To elaborate a little bit on PyTorch's memory-reserve behaviour: [...]. And this was the output with the [...]. Note that these are only estimates taken after PyTorch has gone through the network.
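To make the distinction explicit, here is a minimal sketch of how one might log both numbers PyTorch exposes; the helper name report_cuda_memory is made up, and torch.cuda.memory_reserved() requires PyTorch >= 1.4 (older versions call it memory_cached()):

```python
import torch

def report_cuda_memory(tag=""):
    # Memory actually occupied by live tensors.
    allocated = torch.cuda.memory_allocated()
    # Memory held by PyTorch's caching allocator; this is roughly what
    # nvidia-smi shows, and it is not given back to the driver when tensors
    # are freed unless torch.cuda.empty_cache() is called.
    reserved = torch.cuda.memory_reserved()
    print(f"{tag}: allocated {allocated / 2**20:.1f} MiB, "
          f"reserved {reserved / 2**20:.1f} MiB")
```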
@duag Thanks for looking into this. PyTorch indeed does not always play very nice with MemCNN, since it does not expect on-the-fly emptying of the memory storages. You might want to play around a bit with the code on forward (lines 127 to 128 in e0a0288) and on backward (lines 33 to 35 in e0a0288). This should at least clear all unused cached information during the forward and backward passes. Let me know if it helps!
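For context, a rough sketch of the general mechanism (not the exact lines linked above), assuming a CUDA device is available: free a tensor's storage on the fly and ask the caching allocator to release the unused blocks.

```python
import torch

x = torch.randn(256, 64, 32, 32, device="cuda")
y = x * 2  # stand-in for an invertible operation that can reconstruct x from y

# Drop the input's underlying storage once it can be recomputed from y.
# This is the kind of on-the-fly emptying the autograd engine does not expect.
x.storage().resize_(0)

# Release unused cached blocks back to the driver so the freed memory also
# becomes visible in nvidia-smi.
torch.cuda.empty_cache()
```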
Thanks for your great project and detailed answers. I have measured the allocated memory after each step of the training loop (memcnn/memcnn/trainers/classification.py, lines 112 to 113 in e0a0288).
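For completeness, a minimal, self-contained sketch of this kind of per-phase measurement; the toy model and optimizer below are placeholders, not the actual memcnn trainer code:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 100)).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(256, 3, 32, 32, device="cuda")
targets = torch.randint(0, 100, (256,), device="cuda")

def mib():
    return torch.cuda.memory_allocated() / 2 ** 20

outputs = model(inputs)
print(f"after forward  : {mib():.1f} MiB")
loss = criterion(outputs, targets)
print(f"after loss     : {mib():.1f} MiB")
loss.backward()
print(f"after backward : {mib():.1f} MiB")
optimizer.step()
print(f"after optimizer: {mib():.1f} MiB")
```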
Similar to duag's results, the memory after optimizer.step is larger than that after the loss computation for reversible networks. This issue does not appear for the non-reversible networks. My environment is Python 3.7.4, PyTorch 1.4.0, CUDA 10.1, and a batch size of 256.
For revnet164, the allocated memory (MiB) after the forward pass and after the optimizer step also varies a lot across iterations.
In short, could you try to fix the following issues for reversible architectures?

* The allocated memory after optimizer.step is larger than after the loss computation.
* The allocated memory after the forward pass and the optimizer step varies considerably across iterations.
Many thanks for your patience and contributions.
@silvandeleemput thank you for your answer. So I tried to find a reason for what is happening, analyzed the behaviour, and compared memcnn with other memory-saving PyTorch implementations. When I stepped through the code with the debugger, RevNet used much less memory than during a normal run and the memory usage did not increase over further iterations. Now I am guessing PyTorch's problem lies somewhere here (line 128 in e0a0288):
By telling autograd that this tensor requires a grad, but never actually calling backward() on the original output_tensor inside the backward_hook, PyTorch might still try to reserve memory or optimize the graph for this tensor even though it is not necessary. I looked up how PyTorch implements gradient checkpointing, since the problem is somewhat similar but checkpointing shows the expected memory-reservation behaviour. If one wants to implement a similar solution in memcnn, this would of course mean that the backward_hook can't be used as it is now, because a backward_hook can't be added to a tensor that is not marked as requiring a gradient.
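To illustrate the pattern I mean, here is a simplified sketch of the idea behind torch.utils.checkpoint (not its actual implementation; it ignores RNG state and handles only a single tensor input):

```python
import torch

class CheckpointLike(torch.autograd.Function):
    """Store only the input during forward; recompute the sub-graph with
    gradients enabled during backward, instead of relying on backward hooks."""

    @staticmethod
    def forward(ctx, fn, x):
        ctx.fn = fn
        ctx.save_for_backward(x)
        with torch.no_grad():           # no activation graph is kept here
            return fn(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():       # recompute, this time building a graph
            y = ctx.fn(x)
        torch.autograd.backward(y, grad_output)
        return None, x.grad
```

Calling it as CheckpointLike.apply(block, x) then avoids having to mark tensors for gradients just to attach hooks to them.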
@duag Thanks for looking into this issue thoroughly; your feedback is very helpful for solving it. It seems some kind of workaround is indeed needed to avoid the memory behaviour reported here for PyTorch, and I'll have a look at your suggestion. It may take a while though, since I am currently very occupied with my work. I hope to find some time for this in March.
* Renamed ReversibleFunction to InvertibleCheckpointFunction
* Clearing of the output ref in InvertibleCheckpointFunction is crucial for enabling memory saving
* Added an output ref count to InvertibleCheckpointFunction (for multiple backward passes)
* Modified the CPU memory limit for test_memory_saving.py
* Changed one of the test_revop.py tests to expect two backward passes instead of one
@duag @ZixuanJiang Ok, I have just released MemCNN 1.3.0, which should fix the issue. Please let me know if it works for you. EDIT: It is better to use MemCNN 1.3.1, since it solves an issue with memory spikes. This has been released as well.
@duag @ZixuanJiang As far as I can tell, this issue has been resolved since MemCNN 1.3.1. Please reopen this thread if needed.
Description
Hello
I tried the memcnn.train examples to test the memory consumption of ResNet32 and RevNet38 on the GPU.
The test script printed the following memory allocation estimates:
ResNet32: Model memory allocation: 1966080; ActMem 259687424.000 (259482711.720) (iteration 1100)
RevNet38: Model memory allocation: 2414592; ActMem 106352640.000 (106228648.514) (iteration 1100)
Yet nvidia-smi reported the following GPU memory usage, which didn't change during training:
ResNet32: 712MiB
RevNet38: 810MiB
So from this test, I cannot see how this implementation of the RevNet architecture actually saves memory.
Can someone please tell me whether they managed to actually save memory (according to nvidia-smi), and what their setup looks like?
Furthermore, is a memory-efficient implementation in PyTorch even possible?
The Google Brain Team recently published Trax which implements reversible layers with the Jax backend.
They argued that they had to use Jax instead of TF2 because they were not able to implement actually memory-efficient reversible layers with TF2. ( https://gitter.im/trax-ml/community?at=5e273d223b40ea043c7a1d8b )
What I Did
I tried the examples with PyTorch 1.1.0 and PyTorch 1.3.1 (with Cuda 10.0 and Cudnn 7.6.5):
python -m memcnn.train resnet32 cifar100
python -m memcnn.train revnet38 cifar100