Stop on multiple end tokens (keras-team#1518)
* Add multitoken stopping

* Update gemma_causal_lm.py

* Add further multitoken support

* Formatting

* Revert tokenizer changes

* Move multi token stop to generative task

* None check

* None check

* Error message

* Add stop_token_ids

* Util testing

* Fix sampler tests

* Add multitoken stop to all models

* Sampler multi token

* Formatting

* Tuple required

* Tuple docstring

* Pytorch GPU fix

* Numpy fix
grasskin authored and abuelnasr0 committed Apr 2, 2024
1 parent 6ea1e63 commit 6e946e2
Showing 21 changed files with 219 additions and 95 deletions.
18 changes: 10 additions & 8 deletions keras_nlp/models/bart/bart_seq_2_seq_lm.py
@@ -24,6 +24,7 @@
)
from keras_nlp.models.generative_task import GenerativeTask
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.tensor_utils import any_equal


@keras_nlp_export("keras_nlp.models.BartSeq2SeqLM")
@@ -398,7 +399,7 @@ def _build_cache(
def generate_step(
self,
inputs,
end_token_id=None,
stop_token_ids=None,
):
"""A compilable generation function for a batch of inputs.
@@ -412,8 +413,8 @@ def generate_step(
inputs: A dictionary with four keys - `"encoder_token_ids"`,
`"encoder_padding_mask"`, `"decoder_token_ids"` and
`"decoder_padding_mask"`, with batched tensor values.
end_token_id: The id of the end token to stop on. If all
sequences have produced a new `end_token_id`, generation
stop_token_ids: Tuple of ids of end tokens to stop on. If all
sequences have produced a new stop token, generation
will stop.
"""
(
Expand Down Expand Up @@ -477,17 +478,18 @@ def repeat_tensor(x):
cache=self_attention_cache,
index=index,
mask=decoder_padding_mask,
end_token_id=end_token_id,
stop_token_ids=stop_token_ids,
hidden_states=hidden_states,
model=self,
)

# Compute an output padding mask with the token ids we updated.
if end_token_id is not None:
# Build a mask of `end_token_id` locations not in the original
if stop_token_ids is not None:
# Build a mask of `stop_token_ids` locations not in the original
# prompt (not in locations where `decoder_padding_mask` is True).
end_locations = ops.logical_and(
ops.equal(decoder_token_ids, end_token_id),
end_locations = any_equal(
decoder_token_ids,
stop_token_ids,
ops.logical_not(decoder_padding_mask),
)
end_locations = ops.cast(end_locations, "int32")
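
Every model in this commit swaps the single-token `ops.equal` check for the shared `any_equal` helper imported from `keras_nlp.utils.tensor_utils`. The helper itself is not shown in the rendered diff; the following is only a minimal sketch, assuming it returns the positions where the token ids match any of the given stop ids, restricted by the mask passed as the last argument:

from keras import ops  # assumes Keras 3 style ops, as used elsewhere in keras_nlp

def any_equal(inputs, values, padding_mask):
    # Accumulate a boolean mask of positions equal to any id in `values`.
    output = ops.equal(inputs, values[0])
    for value in values[1:]:
        output = ops.logical_or(output, ops.equal(inputs, value))
    # Keep only positions allowed by `padding_mask` (here, the non-prompt ones).
    return ops.logical_and(output, padding_mask)

At the call sites above, the mask argument is `ops.logical_not(decoder_padding_mask)`, so only newly generated stop tokens count toward ending a sequence.
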
19 changes: 10 additions & 9 deletions keras_nlp/models/bloom/bloom_causal_lm.py
@@ -24,6 +24,7 @@
from keras_nlp.models.bloom.bloom_presets import backbone_presets
from keras_nlp.models.generative_task import GenerativeTask
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.tensor_utils import any_equal


@keras_nlp_export("keras_nlp.models.BloomCausalLM")
@@ -245,7 +246,7 @@ def _build_cache(self, token_ids):
def generate_step(
self,
inputs,
end_token_id=None,
stop_token_ids=None,
):
"""A compilable generation function for a single batch of inputs.
@@ -256,8 +257,8 @@ def generate_step(
Args:
inputs: A dictionary with two keys `"token_ids"` and
`"padding_mask"` and batched tensor values.
end_token_id: The id of the end token to stop on. If all
sequences have produced a new `end_token_id`, generation
stop_token_ids: Tuple of ids of end tokens to stop on. If all
sequences have produced a new stop token, generation
will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
@@ -290,19 +291,19 @@ def next(prompt, cache, index):
cache=cache,
index=index,
mask=padding_mask,
end_token_id=end_token_id,
stop_token_ids=stop_token_ids,
hidden_states=hidden_states,
model=self,
)

# Compute an output padding mask with the token ids we updated.
if end_token_id is not None:
# Build a mask of `end_token_id` locations not in the original
if stop_token_ids is not None:
# Build a mask of stop token locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = ops.logical_and(
ops.equal(token_ids, end_token_id),
ops.logical_not(padding_mask),
end_locations = any_equal(
token_ids, stop_token_ids, ops.logical_not(padding_mask)
)

end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
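
The `end_locations` / `cumsum` code that follows in each model turns the first newly generated stop token into a padding mask that drops everything after it. A small NumPy illustration of the idea, with made-up token ids (3 standing in for a stop id):

import numpy as np

token_ids = np.array([5, 8, 3, 7, 9])                      # hypothetical output ids
prompt_mask = np.array([True, True, False, False, False])  # True = original prompt

end_locations = ((token_ids == 3) & ~prompt_mask).astype("int32")  # [0 0 1 0 0]
cumsum = np.cumsum(end_locations)                                  # [0 0 1 1 1]
overflow = cumsum - end_locations                                  # [0 0 0 1 1]
padding_mask = overflow == 0                                       # keep up to the stop token
print(padding_mask)  # [ True  True  True False False]

The stop token itself stays inside the mask; positions generated after it are discarded when the outputs are post-processed.
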
19 changes: 10 additions & 9 deletions keras_nlp/models/gemma/gemma_causal_lm.py
@@ -24,6 +24,7 @@
from keras_nlp.models.gemma.gemma_presets import backbone_presets
from keras_nlp.models.generative_task import GenerativeTask
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.tensor_utils import any_equal


@keras_nlp_export("keras_nlp.models.GemmaCausalLM")
@@ -238,7 +239,7 @@ def _build_cache(self, token_ids):
def generate_step(
self,
inputs,
end_token_id=None,
stop_token_ids=None,
):
"""A compilable generation function for a single batch of inputs.
@@ -249,8 +250,8 @@ def generate_step(
Args:
inputs: A dictionary with two keys `"token_ids"` and
`"padding_mask"` and batched tensor values.
end_token_id: The id of the end token to stop on. If all
sequences have produced a new `end_token_id`, generation
stop_token_ids: Tuple of ids of end tokens to stop on. If all
sequences have produced a new stop token, generation
will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
@@ -283,19 +284,19 @@ def next(prompt, cache, index):
cache=cache,
index=index,
mask=padding_mask,
end_token_id=end_token_id,
stop_token_ids=stop_token_ids,
hidden_states=hidden_states,
model=self,
)

# Compute an output padding mask with the token ids we updated.
if end_token_id is not None:
# Build a mask of `end_token_id` locations not in the original
if stop_token_ids is not None:
# Build a mask of `stop_token_ids` locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = ops.logical_and(
ops.equal(token_ids, end_token_id),
ops.logical_not(padding_mask),
end_locations = any_equal(
token_ids, stop_token_ids, ops.logical_not(padding_mask)
)

end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
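
The `stop_token_ids=stop_token_ids` keyword in the hunks above is forwarded to the model's sampler, and the commit log mentions matching sampler changes ("Sampler multi token"), but the sampler files are not among the rendered diffs. The snippet below is only a sketch of the stop condition such a sampler loop could use, reusing the assumed `any_equal` helper from earlier; the function name and signature are illustrative, not the actual sampler API:

def all_sequences_finished(token_ids, stop_token_ids, prompt_mask):
    # With no stop ids, only `max_length` ends the loop.
    if stop_token_ids is None:
        return ops.convert_to_tensor(False)
    # A sequence is finished once any stop id shows up at a newly
    # generated position (i.e. outside the original prompt).
    newly_generated = ops.logical_not(prompt_mask)
    end_locations = any_equal(token_ids, stop_token_ids, newly_generated)
    per_sequence_done = ops.any(end_locations, axis=-1)
    return ops.all(per_sequence_done)
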
5 changes: 1 addition & 4 deletions keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py
@@ -148,10 +148,7 @@ def generate_preprocess(
"padding_mask": padding_mask,
}

def generate_postprocess(
self,
x,
):
def generate_postprocess(self, x):
"""Convert integer token output to strings for generation.
This method reverses `generate_preprocess()`, by first removing all
20 changes: 20 additions & 0 deletions keras_nlp/models/gemma/gemma_causal_lm_test.py
@@ -130,6 +130,26 @@ def wrapper(*args, **kwargs):
# We should immediately abort and output the prompt.
self.assertEqual(prompt, output)

def test_multitoken_stopping(self):
causal_lm = GemmaCausalLM(**self.init_kwargs)
call_with_cache = causal_lm.call_with_cache

def wrapper(*args, **kwargs):
"""Modify output logits to always favor end_token_id"""
logits, hidden_states, cache = call_with_cache(*args, **kwargs)
index = self.preprocessor.tokenizer.end_token_id
update = ops.ones_like(logits)[:, :, index] * 1.0e9
update = ops.expand_dims(update, axis=-1)
logits = ops.slice_update(logits, (0, 0, index), update)
return logits, hidden_states, cache

with patch.object(causal_lm, "call_with_cache", wraps=wrapper):
prompt = ["the quick brown fox", "the quick"]

output = causal_lm.generate(prompt, stop_token_ids=(3,))
# We should immediately abort and output the prompt.
self.assertEqual(prompt, output)

def test_generate_compilation(self):
causal_lm = GemmaCausalLM(**self.init_kwargs)
# Assert we do not recompile with successive calls.
42 changes: 31 additions & 11 deletions keras_nlp/models/generative_task.py
@@ -13,6 +13,7 @@
# limitations under the License.

import itertools
from functools import partial

import tensorflow as tf
import tree
@@ -64,10 +65,10 @@ def make_generate_function(self):

def wrapped_generate_function(
inputs,
end_token_id=None,
stop_token_ids=None,
):
with torch.no_grad():
return self.generate_step(inputs, end_token_id)
return self.generate_step(inputs, stop_token_ids)

self.generate_function = wrapped_generate_function
elif config.backend() == "tensorflow" and not self.run_eagerly:
@@ -80,8 +81,8 @@ def wrapped_generate_function(
elif config.backend() == "jax" and not self.run_eagerly:
import jax

@jax.jit
def compiled_generate_function(inputs, end_token_id, state):
@partial(jax.jit, static_argnames=["stop_token_ids"])
def compiled_generate_function(inputs, stop_token_ids, state):
(
sampler_variables,
trainable_variables,
@@ -94,7 +95,7 @@ def compiled_generate_function(inputs, end_token_id, state):
)

with keras.StatelessScope(state_mapping=mapping) as scope:
outputs = self.generate_step(inputs, end_token_id)
outputs = self.generate_step(inputs, stop_token_ids)

# Get updated sampler variables from the stateless scope.
sampler_variables = []
@@ -105,8 +106,11 @@

def wrapped_generate_function(
inputs,
end_token_id=None,
stop_token_ids=None,
):
if isinstance(stop_token_ids, list):
stop_token_ids = tuple(stop_token_ids)

# Create an explicit tuple of all variable state.
state = (
self._sampler.variables,
@@ -118,7 +122,7 @@ def wrapped_generate_function(
inputs = tree.map_structure(ops.convert_to_tensor, inputs)
outputs, sampler_variables = compiled_generate_function(
inputs,
end_token_id,
stop_token_ids,
state,
)
# Only assign the sampler variables (random seeds), as other
@@ -206,6 +210,7 @@ def generate(
self,
inputs,
max_length=None,
stop_token_ids="auto",
):
"""Generate text given prompt `inputs`.
@@ -234,23 +239,38 @@
`preprocessor`. If `preprocessor` is `None`, `inputs` should be
padded to the desired maximum length and this argument
will be ignored.
stop_token_ids: Optional. `None`, `"auto"`, or tuple of token ids. Defaults
to `"auto"`, which uses the `preprocessor.tokenizer.end_token_id`.
Using `"auto"` without an attached `preprocessor` will produce an
error. `None` stops generation only after `max_length` tokens have
been generated. You may also pass a tuple (or list) of token ids the
model should stop on. Note that each id is treated as an independent
stop token; multi-token stop sequences are not supported.
"""
# Setup our three main passes.
# 1. Optionally preprocessing strings to dense integer tensors.
# 2. Generate new tokens via a compiled function on dense tensors.
# 3. Optionally postprocess dense integer tensors back to string.
generate_function = self.make_generate_function()
end_token_id = None
if self.preprocessor is not None:
end_token_id = self.preprocessor.tokenizer.end_token_id

if self.preprocessor is None and stop_token_ids == "auto":
raise ValueError(
'A `preprocessor` must be attached to the model if `stop_token_ids="auto"`. '
"Currently `preprocessor=None`. To call `generate()` with preprocessing "
"detached, either pass `stop_token_ids=None` to always generate until "
"`max_length` or pass a tuple of token ids that should terminate generation "
"as `stop_token_ids`."
)
elif stop_token_ids == "auto":
stop_token_ids = [self.preprocessor.tokenizer.end_token_id]

def preprocess(x):
return self.preprocessor.generate_preprocess(
x, sequence_length=max_length
)

def generate(x):
return generate_function(x, end_token_id=end_token_id)
return generate_function(x, stop_token_ids=stop_token_ids)

def postprocess(x):
return self.preprocessor.generate_postprocess(x)
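
On the JAX backend the compiled function now declares `stop_token_ids` as a static argument, which is why the wrapper above converts lists to tuples: `jax.jit` requires static arguments to be hashable and retraces when their value changes. A standalone toy example of the same pattern (not code from this repository):

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["stop_token_ids"])
def count_stop_tokens(token_ids, stop_token_ids):
    # Because `stop_token_ids` is static, we can loop over it in Python
    # while tracing; each distinct tuple compiles exactly once.
    mask = jnp.zeros_like(token_ids, dtype=bool)
    for stop_id in stop_token_ids:
        mask = mask | (token_ids == stop_id)
    return mask.sum()

tokens = jnp.array([5, 8, 3, 7, 2])
print(count_stop_tokens(tokens, stop_token_ids=(2, 3)))  # 2
# Passing stop_token_ids=[2, 3] here would fail: lists are not hashable.
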
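From the caller's side, the new argument looks roughly like the sketch below. The preset name and token ids are illustrative only; the semantics follow the docstring above:

import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Default ("auto"): stop on the preprocessor's end token.
causal_lm.generate("The quick brown fox", max_length=64)

# Never stop early; always generate up to `max_length` tokens.
causal_lm.generate("The quick brown fox", max_length=64, stop_token_ids=None)

# Stop as soon as any of these ids is produced. Each id is an independent
# stop token; multi-token stop sequences are not supported.
causal_lm.generate("The quick brown fox", max_length=64, stop_token_ids=(13, 50256))
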
19 changes: 10 additions & 9 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -24,6 +24,7 @@
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.tensor_utils import any_equal


@keras_nlp_export("keras_nlp.models.GPT2CausalLM")
@@ -251,7 +252,7 @@ def _build_cache(self, token_ids):
def generate_step(
self,
inputs,
end_token_id=None,
stop_token_ids=None,
):
"""A compilable generation function for a single batch of inputs.
@@ -262,8 +263,8 @@ def generate_step(
Args:
inputs: A dictionary with two keys `"token_ids"` and
`"padding_mask"` and batched tensor values.
end_token_id: The id of the end token to stop on. If all
sequences have produced a new `end_token_id`, generation
stop_token_ids: Tuple of ids of end tokens to stop on. If all
sequences have produced a new stop token, generation
will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
@@ -296,19 +297,19 @@ def next(prompt, cache, index):
cache=cache,
index=index,
mask=padding_mask,
end_token_id=end_token_id,
stop_token_ids=stop_token_ids,
hidden_states=hidden_states,
model=self,
)

# Compute an output padding mask with the token ids we updated.
if end_token_id is not None:
# Build a mask of `end_token_id` locations not in the original
if stop_token_ids is not None:
# Build a mask of stop token locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = ops.logical_and(
ops.equal(token_ids, end_token_id),
ops.logical_not(padding_mask),
end_locations = any_equal(
token_ids, stop_token_ids, ops.logical_not(padding_mask)
)

end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
(The remaining 14 changed files are not rendered on this page.)