Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Fix GPU memory leak that came up in RCN example #97

Merged
merged 4 commits into from
Nov 24, 2021

Conversation

StannisZhou
Copy link
Contributor

The underlying issue seems to be known yet unresolved (jax-ml/jax#282).

Disabling jit and being careful about using jax arrays outside the inference function get rid of the leak. Inference is becoming a bit slower (sometimes as much as 2x slower) due to disabling jit for run_bp.

Also fixed a mypy issue (relaxed a Hashable to Any) and reduced precommit autoupdate frequence.

@StannisZhou StannisZhou added the bug Something isn't working label Nov 24, 2021
@StannisZhou StannisZhou added this to the 0.2.1 milestone Nov 24, 2021
@StannisZhou StannisZhou self-assigned this Nov 24, 2021
@StannisZhou StannisZhou mentioned this pull request Nov 24, 2021
@codecov-commenter
Copy link

codecov-commenter commented Nov 24, 2021

Codecov Report

Merging #97 (dec9c70) into master (2f389ab) will not change coverage.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff            @@
##            master       #97   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files            8         8           
  Lines          631       633    +2     
=========================================
+ Hits           631       633    +2     
Impacted Files Coverage Δ
pgmax/fg/graph.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 2f389ab...dec9c70. Read the comment docs.

Copy link
Contributor

@shrinuKushagra shrinuKushagra left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me.
Seems like using numpy arrays instead of jax arrays in init functions fixed the error.

jax.device_put(self.value),
{name: jax.device_put(data)},
self.fg_state,
),
),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not expert on jax but for my own clarification.
Do we not need to clear the variables self.value and data from jax memory memory?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.value and data would stay on CPU. Intermediate objects on GPU will be cleared automatically if you use the platform option.

The leak came from a failure to clear large jit compilation cache (the cache is large because it involves a large constant array, i.e. the wiring). A completely leak-free way is to include wiring as part of the arguments, but in my experiments that leads to slow compiling and hurts performance. Current solution is a bit of a compromise but should suffice for most cases.

@StannisZhou StannisZhou merged commit 6b90c58 into vicariousinc:master Nov 24, 2021
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants