Memory consumption does not seem to reflect the indicated consumption of nvidia-smi #37

Closed
duag opened this issue Feb 6, 2020 · 9 comments

@duag

duag commented Feb 6, 2020

  • MemCNN version: 1.2.0
  • Python version: 3.7.6
  • Operating System: linux-64

Description

Hello
I tried the memcnn.train examples to test the memory consumption of ResNet32 and RevNet38 on the GPU.
The test script printed the following memory allocation estimates:
ResNet32: Model memory allocation: 1966080; ActMem 259687424.000 (259482711.720) (iteration 1100)
RevNet38: Model memory allocation: 2414592; ActMem 106352640.000 (106228648.514) (iteration 1100)

Yet nvidia-smi returned the following GPU Memory Usage stats, which didn't change during the training:
ResNet32: 712MiB
RevNet38: 810MiB

So from this test, I cannot see how this implementation of the RevNet architecture actually saves memory.
Can someone please explain whether they managed to actually save memory (according to nvidia-smi) and what their setup looks like?

Furthermore, is a memory-efficient implementation in PyTorch even possible?
The Google Brain team recently published Trax, which implements reversible layers with the JAX backend.
They argued that they had to use JAX instead of TF2 because they were not able to implement truly memory-efficient reversible layers with TF2. ( https://gitter.im/trax-ml/community?at=5e273d223b40ea043c7a1d8b )

What I Did

I tried the examples with PyTorch 1.1.0 and PyTorch 1.3.1 (with CUDA 10.0 and cuDNN 7.6.5):
python -m memcnn.train resnet32 cifar100
python -m memcnn.train revnet38 cifar100

@duag
Author

duag commented Feb 6, 2020

Small update:
So I found out that PyTorch keeps the memory reserved and does not actually free it on the GPU.
One can force PyTorch to free the memory by calling torch.cuda.empty_cache().
So I added this function call after input_tensor.storage().resize_(0) inside the InvertibleModuleWrapper forward and inverse functions (forward input discard, inverse input discard).
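A rough sketch of that modification (paraphrased; the surrounding line is from memcnn's InvertibleModuleWrapper and is not reproduced in full here):

input_tensor.storage().resize_(0)  # memcnn: discard the stored input activations
torch.cuda.empty_cache()           # added: hand the freed blocks back to the driver so nvidia-smi reflects it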

Now if I call watch -n 0.1 nvidia-smi I believe I can see the actual GPU memory usage of PyTorch with RevNet38.
The indicated GPU memory usage varies between ~530 MiB and ~760 MiB, which is still way higher than I would expect from a memory-efficient implementation of RevNet.

I wish to use RevNet to train with a larger mini-batch size than I can with ResNet (where I am limited by memory).
According to the idea behind reversible layers I should be able to do so, but I cannot find a way to make PyTorch reserve less memory than it would need for a normal forward and backward pass, even when I use the memcnn.InvertibleModuleWrapper.

@silvandeleemput
Owner

@duag Hi, thank you for trying out MemCNN. I see that some things regarding memory usage and the potential memory savings through the InvertibleModuleWrapper are not all that clear. I might actually add some text about this to the documentation, since the question comes up with more users.

As you already figured out, you cannot rely on nvidia-smi for the memory actually used by PyTorch, since PyTorch caches freed memory, so it won't show up there. See the PyTorch docs for a good explanation of the memory management: https://pytorch.org/docs/stable/notes/cuda.html#memory-management.

For the statistics in the experiments I rely on the cuda.memory_allocated() method to get an accurate read on the memory actually allocated on the GPU by PyTorch. However, we need to distinguish two types of memory allocation: that for the network weights and that for the feature map activations during training.

network weights
As you might know, the weights (trainable parameters) of a neural network take up a non-negligible part of memory. The more layers and filters, the higher the required memory. This is the Model memory allocation in the experiment output.

featuremap activations
The most important memory saving we can achieve using the memcnn.InvertibleModuleWrapper concerns the feature map activations. During network training we usually need to keep the activations in memory because they are needed for the gradient computations on the backward pass. Since the InvertibleModuleWrapper can reconstruct the input from its output, the activation storage drops to O(1) complexity in the number of wrapped layers. This is the ActMem in the experiment output.
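For reference, a minimal usage sketch of the wrapper (adapted from the MemCNN readme example; the channel sizes are arbitrary):

import torch.nn as nn
import memcnn

class ExampleOperation(nn.Module):
    def __init__(self, channels):
        super(ExampleOperation, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.seq(x)

# Turn the operation into an invertible additive coupling and wrap it.
# With keep_input=False the wrapper frees the input activations on the forward
# pass and reconstructs them from the output during the backward pass.
coupling = memcnn.AdditiveCoupling(Fm=ExampleOperation(channels=5), Gm=ExampleOperation(channels=5))
wrapped = memcnn.InvertibleModuleWrapper(fn=coupling, keep_input=False, keep_input_inverse=False)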

So with this background, looking at the experiments you first posted:

ResNet32: Model memory allocation: 1966080; ActMem 259687424.000 (259482711.720) (iteration 1100)
RevNet38: Model memory allocation: 2414592; ActMem 106352640.000 (106228648.514) (iteration 1100)

We can see that while RevNet38 uses more memory for the network weights (2414592 vs 1966080 bytes, about 23% more), it saves significantly on memory for the activations (106352640 vs 259687424 bytes, a reduction of about 59%). Also note that the activation memory is a factor of roughly 44 (RevNet38) to 132 (ResNet32) larger than the weight memory, so the activations are where the savings matter; see the quick calculation below.
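Spelled out in plain Python (numbers taken from the output above):

weights_resnet32, weights_revnet38 = 1966080, 2414592
act_resnet32, act_revnet38 = 259687424, 106352640

print(weights_revnet38 / weights_resnet32)  # ~1.23 -> RevNet38 needs ~23% more weight memory
print(1 - act_revnet38 / act_resnet32)      # ~0.59 -> RevNet38 saves ~59% of the activation memory
print(act_resnet32 / weights_resnet32)      # ~132  -> activations dominate the weights for ResNet32
print(act_revnet38 / weights_revnet38)      # ~44   -> and still dominate for RevNet38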

Still, you might have expected a bit more gain. Well, ResNet32 and RevNet38 are relatively small; the gain for the activations becomes more apparent with larger models that have more layers. You should be able to see this when training ResNet110 vs RevNet110 and ResNet164 vs RevNet164, so I would encourage you to try that out!
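That would be, analogous to the commands you already ran:

python -m memcnn.train resnet110 cifar100
python -m memcnn.train revnet110 cifar100
python -m memcnn.train resnet164 cifar100
python -m memcnn.train revnet164 cifar100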

Regarding your experiments with the increased batch size: I would expect that increasing the batch size should allow you to save even more memory overhead, since the feature maps typically get bigger but you only need to store a few of them. Did you already try this out?

One thing I can think of that limits the memory performance is the max-pools in the RevNet, which are not invertible and therefore require storing all of their feature map activations before downsampling. If you're feeling adventurous you might even want to try implementing an "invertible max-pool" as in the i-RevNet paper (see https://arxiv.org/abs/1802.07088). This should be easy to do by creating a torch.nn.Module with a forward and an inverse operation and then wrapping it with the memcnn.InvertibleModuleWrapper.
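If you go that route, the invertible space-to-depth downsampling from the i-RevNet paper is probably the easiest starting point; a rough sketch (my own, untested against the RevNet models in MemCNN):

import torch.nn as nn
import memcnn

class InvertibleDownsampling(nn.Module):
    # Space-to-depth downsampling as in i-RevNet: each non-overlapping 2x2
    # spatial block is rearranged into 4 channels, so nothing is discarded
    # and the operation can be inverted exactly.
    def __init__(self, block_size=2):
        super(InvertibleDownsampling, self).__init__()
        self.bs = block_size

    def forward(self, x):
        b, c, h, w = x.shape
        x = x.view(b, c, h // self.bs, self.bs, w // self.bs, self.bs)
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
        return x.view(b, c * self.bs * self.bs, h // self.bs, w // self.bs)

    def inverse(self, y):
        b, c, h, w = y.shape
        y = y.view(b, c // (self.bs * self.bs), self.bs, self.bs, h, w)
        y = y.permute(0, 1, 4, 2, 5, 3).contiguous()
        return y.view(b, c // (self.bs * self.bs), h * self.bs, w * self.bs)

# Wrapping it lets MemCNN discard its input during the forward pass as well.
downsample = memcnn.InvertibleModuleWrapper(fn=InvertibleDownsampling(), keep_input=False)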

Good luck with your project! Let me know where I could help.

silvandeleemput self-assigned this Feb 6, 2020
@duag
Author

duag commented Feb 7, 2020

@silvandeleemput thank you for your detailed reply.
I already tried using bigger batch sizes, without success.
I believe the biggest problem is that PyTorch wants to reserve the memory it expects to use, even though with invertible layers it does not actually have to use it.
So if one raises the batch size or the input size, PyTorch will try to reserve more memory (in every layer it wants to calculate the gradient for) even if the actually allocated memory will be much lower.
And therefore PyTorch will fail to reserve memory once this exceeds the maximum available memory.

To elaborate a little on PyTorch's memory reservation behaviour:
PyTorch 1.4 comes with a couple of new features for analysing memory usage,
among others a function that reports how much memory PyTorch currently has reserved:
https://pytorch.org/docs/stable/cuda.html#torch.cuda.memory_reserved
So I tried PyTorch 1.4 with memcnn. While cuda.memory_allocated() reported similar results, cuda.memory_reserved() showed how much memory PyTorch reserves during training, which is what ultimately matters: when there is not enough memory left for PyTorch to reserve, PyTorch crashes.
I printed the torch.cuda.memory_reserved(device) - model_mem_allocation output at different steps during the training (inside the classification.py file).
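(For completeness, the printing boils down to something like this; report_reserved is my own helper name, not memcnn's:)

import torch

def report_reserved(tag, device, model_mem_allocation):
    # torch.cuda.memory_reserved() is available from PyTorch 1.4 onwards
    reserved = torch.cuda.memory_reserved(device) - model_mem_allocation
    print('{}: {}'.format(tag, reserved))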
When I started the training without the torch.cuda.empty_cache() modification, this is how the first three iterations looked (for RevNet38):
Step 0:
after model_alloc_estimation: 0
after model.train(): 20971520
after loss estimation: 134217728
after backwards pass: 320864256
Step 1:
after model.train(): 320864256
after loss estimation: 320864256
after backwards pass: 350224384
Step 2:
after model.train(): 350224384
after loss estimation: 350224384
after backwards pass: 379584512
This was the moment the reserved memory stopped rising.

And this was the output with the torch.cuda.empty_cache() modification:
Step 0:
after model_alloc_estimation: 0
after model.train(): 20971520
after loss estimation: 134217728
after backwards pass: 320864256
Step 1:
after model.train(): 320864256
after loss estimation: 127926272
after backwards pass: 360710144
Step 2:
after model.train(): 360710144
after loss estimation: 130023424
after backwards pass: 341835776

Now these are only estimates taken after PyTorch has gone through the network.
I did most of my work with TensorFlow and I am fairly new to PyTorch, so my question is:
Is it possible to stop PyTorch from reserving so much memory, or to free the memory for every layer right after that layer has been used in the backward pass?

@silvandeleemput
Owner

@duag Thanks for looking into this. PyTorch indeed does not always play very nicely with MemCNN, since it does not expect on-the-fly emptying of the memory storages. You might want to play around a bit with the empty_cache method in the InvertibleModuleWrapper backward_hook and during the forward pass:

on forward
You might want to try to call torch.cuda.empty_cache() between the following lines

y.detach_() # Detaches y in-place (inbetween computations can now be discarded)
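# torch.cuda.empty_cache()  <- i.e. try inserting the call here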
y.requires_grad = self.training

on backward
You might want to try to call torch.cuda.empty_cache() between the following lines (line 34):

input_tensor.set_(input_inverted)
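# torch.cuda.empty_cache()  <- i.e. try inserting the call here, before the gradient computation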
# compute gradients

This should at least clear all unused cached information during the forward and backward passes. Let me know if it helps!

@ZixuanJiang

ZixuanJiang commented Feb 12, 2020

Thanks for your great project and detailed answers.

I have measured the allocated memory after each optimizer.step (line 113 as shown below).

optimizer.step()
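Roughly, the measurement looks like this (my own wrapper around the training step, not memcnn's code):

import torch

MiB = 1024 ** 2

def train_step_with_memory_report(model, criterion, optimizer, x, target):
    loss = criterion(model(x), target)
    mem_after_forward = torch.cuda.memory_allocated() / MiB  # first column below
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    mem_after_step = torch.cuda.memory_allocated() / MiB     # second column below
    print(mem_after_forward, mem_after_step)
    return loss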

Similar to duag's results, for the reversible networks the memory after optimizer.step is larger than the memory after the loss computation. This issue does not appear for the non-reversible networks.

My environment is Python 3.7.4, PyTorch 1.4.0, CUDA 10.1, and a batch size of 256.
For resnet164, the allocated memory (MiB) after the forward pass and after the optimizer step is stable across iterations, as shown below.

2491.66650390625 21.1982421875
2498.64501953125 21.19970703125
2498.25439453125 21.19970703125
2498.50439453125 21.19970703125
2498.89501953125 21.19970703125
2498.50439453125 21.19970703125
2498.89501953125 21.19970703125
2498.50439453125 21.19970703125
2498.89501953125 21.19970703125
2498.50439453125 21.19970703125

For revnet164, the allocated memory (MiB) after the forward pass and after the optimizer step varies strongly across iterations, as shown below.

450.22900390625 1528.8671875
457.8759765625 1527.94677734375
709.8759765625 1776.94677734375
1058.1259765625 2130.19677734375
1411.1259765625 2480.19677734375
1759.1259765625 2830.19677734375
2109.8759765625 3178.94677734375
2458.8759765625 3529.94677734375
459.8759765625 1528.94677734375
459.1259765625 1530.19677734375
809.8759765625 1878.94677734375
1058.8759765625 2129.94677734375
1309.8759765625 2378.94677734375
1559.1259765625 2630.19677734375
1809.8759765625 2878.94677734375
2058.8759765625 3129.94677734375
2059.8759765625 3128.94677734375
2309.1259765625 3380.19677734375
2559.8759765625 3628.94677734375
2808.8759765625 3879.94677734375

In short, could you try to fix the following issues for reversible architectures?

  • The allocated memory has a large variance.
  • The peak memory of the reversible architecture is larger than that of its non-reversible counterpart.
  • The memory after optimizer.step is larger than that after loss computation.

Many thanks for your patience and contributions.

@duag
Author

duag commented Feb 17, 2020

@silvandeleemput thank you for your answer.
I tried it out, but I still get the oscillating memory reservation behaviour, with peaks higher than for the ResNet version.

So I tried to find a reason for what is happening, analyzed the behaviour, and compared memcnn with other memory-saving PyTorch implementations.
First of all, this is mostly guesswork, because I only skimmed the C++ code of PyTorch; PyCharm's debugger stops at the Variable._execution_engine.run_backward(...) call inside autograd.

When I stepped through the code with the debugger, RevNet used much less memory than in a normal run, and the memory usage did not increase over further iterations.
So it seems that the memory reservation is a result of PyTorch trying to optimize the graph while taking runtimes into account.

Now I am guessing PyTorch's problem lies somewhere here:

y.requires_grad = self.training

By telling autograd that this tensor requires a grad, but never actually calling backward() on the original output_tensor inside the backward_hook, PyTorch might still reserve memory or try to optimize the graph for this tensor, even though that is not necessary.

I looked up how PyTorch implements gradient checkpointing, since the problem is somewhat similar but shows the expected memory reservation behaviour
(first do a forward pass through most of the model without storing the activations, and then during the backward pass recalculate the activations and do the backward pass through part of the model;
for further reading:
https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9
and the PyTorch implementation:
https://pytorch.org/docs/stable/checkpoint.html
https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html )
They solved the problem this way:
During the forward pass, only the checkpoint tensors are added to the graph by setting requires_grad=True; every other tensor is marked with requires_grad=False.
During the backward pass, they then recalculate the missing activations between two checkpoints, and this time the tensors are marked.
This way there are no tensors inside the graph that are marked for grad but are never visited by autograd during the backward pass.

If one wants to implement a similar solution in memcnn, this would of course mean that the backward_hook can't be used as it is now, because a backward_hook can't be added to a tensor that is not marked for a gradient.
So I do not know whether a similar solution is possible with memcnn, but I think something like the following could work (see the sketch after the list):

  • Use the ctx variable of a torch.autograd.Function inside its forward() function and save the needed variables in it (for example the variables that are currently the backward_hook parameters)
  • Implement a backward function inside the InvertibleModuleWrapper that can access the needed tensors from ctx, reconstruct the input from the old output, recompute the new output with grad enabled, and call backward on the new output
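A very rough sketch of what I mean (all names are made up, and details like multiple backward passes and the keep_input options are ignored):

import torch

class InvertibleCheckpointSketch(torch.autograd.Function):
    # Assumes fn is an invertible module exposing forward(x) and inverse(y).

    @staticmethod
    def forward(ctx, fn, x):
        ctx.fn = fn
        with torch.no_grad():            # no graph is built on the forward pass
            y = fn(x)
        ctx.save_for_backward(y)         # keep only the output; the input can be freed
        return y

    @staticmethod
    def backward(ctx, grad_y):
        y, = ctx.saved_tensors
        with torch.no_grad():
            x = ctx.fn.inverse(y)        # reconstruct the discarded input from the output
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            y_recomputed = ctx.fn(x)     # recompute the forward pass, this time with a graph
        torch.autograd.backward(y_recomputed, grad_y)
        return None, x.grad              # no gradient for fn, gradient w.r.t. the input

# usage: y = InvertibleCheckpointSketch.apply(invertible_module, x)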

@silvandeleemput
Owner

@duag Thanks for looking into this issue thoroughly; your feedback is very helpful for solving it. It seems some kind of workaround for the reported PyTorch memory behavior is indeed needed, and I'll have a look at your suggestion.

It can take a while though since I am currently very occupied with my work. I hope to find some time for this in March.

silvandeleemput added a commit that referenced this issue Mar 1, 2020


* Cleaned two print statements from test_manager.py
* Removed backward hook code from revop.py and replaced it with the ReversibleFunction autograd function
* Removed unnecessary pathlib2 dependency from test_log.py
* Changed test_revop.py to allow for identity inverse modules
silvandeleemput added a commit that referenced this issue Mar 1, 2020
* Renamed ReversibleFunction to InvertibleCheckpointFunction
  * Clearing of output ref in InvertibleCheckpointFunction is crucial for enabling memory saving
  * Added output ref count for InvertibleCheckpointFunction (on multiple backward passes)
* Modified CPU memory limit for test_memory_saving.py
* Changed one of the test_revop.py tests to expect two backward passes instead of one
@silvandeleemput
Owner

silvandeleemput commented Mar 1, 2020

@duag @ZixuanJiang Ok, I have just released MemCNN 1.3.0 which should fix the issue. Please let me know if it works for you.

EDIT: It is better to use MemCNN 1.3.1 since it solves an issue with memory spikes. This has been released as well.

@silvandeleemput
Owner

@duag @ZixuanJiang As far as I can tell, this issue has been resolved since MemCNN 1.3.1. Please reopen this thread if needed.
