Skip to content

Commit

Permalink
Revert "Pass less state to jax generate function (keras-team#1398)"
Browse files — browse the repository at this point in the history
This reverts commit c49bf9b.
  • Loading branch information
mattdangerw committed Jan 24, 2024
1 parent c41e844 commit e301142
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions keras_nlp/models/generative_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

import tensorflow as tf
import tree

Expand Down Expand Up @@ -80,34 +82,51 @@ def wrapped_generate_function(

@jax.jit
def compiled_generate_function(inputs, end_token_id, state):
    """Run `self.generate_step` inside a stateless scope.

    Args:
        inputs: Inputs forwarded unchanged to `self.generate_step`.
        end_token_id: End token id forwarded unchanged to
            `self.generate_step`.
        state: A 3-tuple `(sampler_variables, trainable_variables,
            non_trainable_variables)` of values, aligned positionally
            with `self._sampler.variables`, `self.trainable_variables`
            and `self.non_trainable_variables` respectively.

    Returns:
        A tuple `(outputs, state)`. `state` keeps the same 3-tuple
        layout as the argument; its sampler entry holds the current
        values read back from the stateless scope, while the trainable
        and non-trainable entries are passed through untouched.
    """
    (
        sampler_variables,
        trainable_variables,
        non_trainable_variables,
    ) = state
    # Pair every variable with the value supplied by the caller so the
    # scope reads from `state` rather than from the variables themselves.
    mapping = itertools.chain(
        zip(self._sampler.variables, sampler_variables),
        zip(self.trainable_variables, trainable_variables),
        zip(self.non_trainable_variables, non_trainable_variables),
    )

    with keras.StatelessScope(state_mapping=mapping) as scope:
        outputs = self.generate_step(inputs, end_token_id)

    # Get updated sampler variables from the stateless scope.
    sampler_variables = []
    for v in self._sampler.variables:
        new_v = scope.get_current_value(v)
        # Fall back to the variable itself when the scope holds no
        # value for it.
        sampler_variables.append(new_v if new_v is not None else v)
    state = (
        sampler_variables,
        trainable_variables,
        non_trainable_variables,
    )
    return outputs, state

def wrapped_generate_function(
    inputs,
    end_token_id=None,
):
    """Call the compiled generate function and write back sampler state.

    Args:
        inputs: A (possibly nested) structure; every leaf is converted
            with `ops.convert_to_tensor` before the compiled call.
        end_token_id: Optional end token id, forwarded unchanged to
            `compiled_generate_function`. Defaults to `None`.

    Returns:
        The `outputs` returned by `compiled_generate_function`.
    """
    # Create an explicit tuple of all variable state.
    state = (
        self._sampler.variables,
        self.trainable_variables,
        self.non_trainable_variables,
    )
    inputs = tree.map_structure(ops.convert_to_tensor, inputs)
    outputs, state = compiled_generate_function(
        inputs,
        end_token_id,
        state,
    )
    # Only assign the sampler variables (random seeds), as other
    # model variables should never be updated in generation.
    # `state[0]` is the sampler entry of the returned 3-tuple.
    for ref_v, v in zip(self._sampler.variables, state[0]):
        ref_v.assign(v)
    return outputs

Expand Down

0 comments on commit e301142

Please sign in to comment.