From 84a489777fc1638aed4bcc82a855af1ed016d8f7 Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Sat, 1 Feb 2025 19:14:14 -0800
Subject: [PATCH] Don't duplicate frozen parameters during predict()

On the JAX backend we were not using donate_argnums during predict. This
works when a model is mostly trainable, but when a model is mostly or
entirely frozen it results in a 2x memory spike (which is why we use
donate_argnums for fit and evaluate).

This change adds donate_argnums to the predict function to avoid the
memory spike. Because JAX will now delete all incoming state (including
the trainable variables), we also need to sync the trainable variables
state back to the model, much like in fit and evaluate.

An alternative would be to change the predict_step signature (so we
could donate only the non-trainable variables), but that would be a
breaking and confusing change.
---
 keras/src/backend/jax/trainer.py | 35 ++++++++++++++++++-----------------
 1 file changed, 18 insertions(+), 17 deletions(-)

diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py
index 7317761658e..7de7b613209 100644
--- a/keras/src/backend/jax/trainer.py
+++ b/keras/src/backend/jax/trainer.py
@@ -308,7 +308,7 @@ def predict_step(state, data):
             return outputs, (state[0], non_trainable_variables)
 
         if not self.run_eagerly and self.jit_compile:
-            predict_step = jax.jit(predict_step)
+            predict_step = jax.jit(predict_step, donate_argnums=0)
 
         _step_function = self._make_function(
             predict_step, concatenate_outputs=True
@@ -316,7 +316,7 @@ def predict_step(state, data):
 
         def step_function(state, iterator):
             outputs, state = _step_function(state, iterator)
-            return outputs, state[1]
+            return outputs, state
 
         self.predict_function = step_function
 
@@ -671,14 +671,20 @@ def append_to_outputs(batch_outputs, outputs):
                 state = self._get_jax_state(
                     trainable_variables=True,
                     non_trainable_variables=True,
+                    purge_model_variables=True,
                 )
-                self._purge_model_variables(non_trainable_variables=True)
                 self._jax_state_synced = False
-            else:
-                state = (state[0], non_trainable_variables)
-            batch_outputs, non_trainable_variables = self.predict_function(
-                state, iterator
-            )
+            batch_outputs, state = self.predict_function(state, iterator)
+            (
+                trainable_variables,
+                non_trainable_variables,
+            ) = state
+            self._jax_state = {
+                "trainable_variables": trainable_variables,
+                # I wouldn't recommend modifying non-trainable model state
+                # during predict(), but it's allowed.
+                "non_trainable_variables": non_trainable_variables,
+            }
             outputs = append_to_outputs(batch_outputs, outputs)
 
             # Dispatch callbacks. This takes care of async dispatch.
@@ -687,11 +693,6 @@ def append_to_outputs(batch_outputs, outputs):
             if self.stop_predicting:
                 break
 
-        self._jax_state = {
-            # I wouldn't recommend modifying non-trainable model state
-            # during predict(), but it's allowed.
-            "non_trainable_variables": non_trainable_variables,
-        }
         self.jax_state_sync()
         callbacks.on_predict_end()
         self._jax_state = None
@@ -819,10 +820,10 @@ def predict_on_batch(self, x):
         def data():
             yield (x,)
 
-        batch_outputs, non_trainable_variables = self.predict_function(
-            state, data()
-        )
+        batch_outputs, state = self.predict_function(state, data())
+        trainable_variables, non_trainable_variables = state
         self._jax_state = {
+            "trainable_variables": trainable_variables,
            "non_trainable_variables": non_trainable_variables,
         }
         self.jax_state_sync()
@@ -929,7 +930,7 @@ def _purge_model_variables(
     ):
         """Remove all the model variable for memory saving.
 
-        During JAX training, since the training function are stateless, we have
+        During JAX training, since the training function is stateless, we have
         to pass in and get the model weights over and over, during which the
         copy of the weights that attached to the Variable are still and
         occupying extra memory. We remove those variable to save memory (for