-
Notifications
You must be signed in to change notification settings - Fork 9
Fix GPU memory leak that came up in RCN example #97
Conversation
Codecov Report
@@ Coverage Diff @@
## master #97 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 8 8
Lines 631 633 +2
=========================================
+ Hits 631 633 +2
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me.
Seems like using numpy arrays instead of jax arrays in init functions fixed the error.
jax.device_put(self.value), | ||
{name: jax.device_put(data)}, | ||
self.fg_state, | ||
), | ||
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not expert on jax but for my own clarification.
Do we not need to clear the variables self.value
and data
from jax memory memory?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.value
and data
would stay on CPU. Intermediate objects on GPU will be cleared automatically if you use the platform
option.
The leak came from a failure to clear large jit compilation cache (the cache is large because it involves a large constant array, i.e. the wiring). A completely leak-free way is to include wiring as part of the arguments, but in my experiments that leads to slow compiling and hurts performance. Current solution is a bit of a compromise but should suffice for most cases.
The underlying issue seems to be known yet unresolved (jax-ml/jax#282).
Disabling
jit
and being careful about using jax arrays outside the inference function get rid of the leak. Inference is becoming a bit slower (sometimes as much as 2x slower) due to disablingjit
forrun_bp
.Also fixed a mypy issue (relaxed a
Hashable
toAny
) and reduced precommit autoupdate frequence.