Does the 'bad_words_ids' argument in the "generate function" works? #14206

Closed
alvinwatner opened this issue Oct 29, 2021 · 13 comments
@alvinwatner

Environment info

  • transformers version: 4.12.0
  • Platform: Linux-5.4.104+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.7.12
  • PyTorch version (GPU?): 1.9.0+cu111 (False)
  • Tensorflow version (GPU?): 2.6.0 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

Information

I attempted to evaluate whether the bad_words_ids argument available in the generate() function works. However, based on the steps described in the section below, it does not work.

To reproduce

Below are the steps I used to evaluate it:

  1. Run the script without bad_words_ids specified, calling set_seed to get a deterministic output.
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, set_seed

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

set_seed(0)

input_context = "My cute dog"
input_ids = tokenizer(input_context, return_tensors="pt").input_ids
outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True)
print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

Output:
Generated: My cute dog, when it died, had taken my entire life to save the life that had been

  2. Re-run the script, this time with bad_words_ids specified. I selected the words "entire" and "save" from the previously generated sequence. However, both words still appear in the output, which is identical to the previous one. Below is the script and its output.
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, set_seed

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

set_seed(0)

input_context = "My cute dog"
# get tokens of words that should not be generated
bad_words_ids = [tokenizer(bad_word).input_ids for bad_word in ["entire", "save"]]
# encode input context
input_ids = tokenizer(input_context, return_tensors="pt").input_ids
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

Output:
Generated: My cute dog, when it died, had taken my entire life to save the life that had been

To reproduce in Google Colab:

https://colab.research.google.com/drive/1P4ruLhFstbal1qqXbjuv-kM7yMYY-S1E?usp=sharing

Expected behavior

I expect the words "entire" and "save" not to appear in the output sequence after running step (2) in the section above.

@alvinwatner alvinwatner changed the title Does the 'bad_words_ids' argument in generate **function** works? Does the 'bad_words_ids' argument in the "generate function" works? Oct 30, 2021
@qqaatw
Contributor

qqaatw commented Oct 30, 2021

Hey @alvinwatner,

To prevent bad words from occurring in the middle of generated texts, you'll need to add a prefix space to every bad word so that a tokenized bad word, e.g. save, becomes ['Ġsave'] instead of ['save'], which matches GPT2's outputs.

This can be done by setting add_prefix_space=True in the kwargs of from_pretrained.

from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

model = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)

set_seed(0)

input_context = "My cute dog"
# get tokens of words that should not be generated
bad_words_ids = tokenizer(["entire", "save"]).input_ids
# encode input context
input_ids = tokenizer(input_context, return_tensors="pt").input_ids
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
print("Generated:", tokenizer.decode(outputs["sequences"][0], skip_special_tokens=True))

Output:

Generated:  My cute dog, when it died, had taken my hand out of my pants and said "I
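
For reference, here's a quick way to see the tokenization difference described above (a minimal check; the printed pieces are whatever GPT-2's BPE actually produces, so the comments are only indicative):

from transformers import AutoTokenizer

# Default GPT-2 tokenizer: a word at the start of a string has no leading-space marker.
plain = AutoTokenizer.from_pretrained("gpt2")
print(plain.tokenize("save"))      # e.g. ['save']

# With add_prefix_space=True the word is tokenized the way it appears mid-sentence.
prefixed = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)
print(prefixed.tokenize("save"))   # e.g. ['Ġsave']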

@alvinwatner
Author

alvinwatner commented Oct 30, 2021

Thank you @qqaatw for pointing that out. Just to note that this example script doesn't work and is outdated.

@giladpn

giladpn commented Nov 15, 2021

Hi @qqaatw

Thanks in advance: I am trying to do something very similar but with T5 (either t5-base or t5-large) as the model instead of GPT2. My "bad words" are simply being ignored so it's a very similar problem. Can you advise? Am I missing some configuration that would be relevant for T5?

I am running code similar to the above but using T5ForConditionalGeneration with no luck. Any help appreciated!

@alvinwatner
Author

Hi @giladpn and @qqaatw. I found something with this bad_words functionality, and I'm not sure whether this is normal behaviour or not.

For a word that is tokenized into multiple tokens, the generate function only replaces the final token, while the earlier tokens still remain in the output sequence.

For example, the word " tester", with a prefix space, is tokenized into ["Ġt", "ester"] with ids [256, 7834]; the output sequence keeps the earlier token (256) and only replaces the final token (7834). In another instance, the word " traceroute" with a prefix space is tokenized into ['Ġtr', 'acer', 'oute'] with ids [491, 11736, 13192]; the output sequence keeps the earlier tokens (491, 11736) and only replaces the final token (13192).
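
As a quick way to check which words split into multiple pieces before passing them as bad words (a minimal sketch; the printed pieces and ids depend on the tokenizer):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)

# Words that split into several BPE pieces are the problematic case:
# only the last piece ends up being blocked during generation.
for word in ["entire", "tester", "traceroute"]:
    ids = tokenizer(word, add_special_tokens=False).input_ids
    print(word, "->", tokenizer.convert_ids_to_tokens(ids), ids)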

@alvinwatner alvinwatner reopened this Nov 16, 2021
@qqaatw
Contributor

qqaatw commented Nov 19, 2021

Hi @giladpn,

Can you provide a minimal but reproducible code so that I can see where the problem is?

Thanks.

@qqaatw
Contributor

qqaatw commented Nov 19, 2021

Edited: Indeed, if a word is tokenized into multiple tokens, the first token will still be present in the generated sequence. I'll take some time to deal with it.

@alvinwatner, what's the input text that you supply to the model?
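
For anyone curious why only the final piece disappears: generate() applies a bad-words logits processor that, for a multi-token bad word, masks just the last piece, and only once the earlier pieces have already been generated. A minimal sketch of that behaviour (the class name and import below reflect recent transformers versions, so treat them as an assumption):

import torch
from transformers import AutoTokenizer, NoBadWordsLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)

# " tester" typically splits into two BPE pieces (e.g. ['Ġt', 'ester']).
bad_word = tokenizer("tester", add_special_tokens=False).input_ids
processor = NoBadWordsLogitsProcessor([bad_word], eos_token_id=tokenizer.eos_token_id)

scores = torch.zeros(1, len(tokenizer))  # dummy next-token scores

# Prefix does not end with the first piece: nothing is masked, so the
# first piece of the bad word can still be generated.
prefix = torch.tensor([[tokenizer.bos_token_id]])
print(processor(prefix, scores.clone())[0, bad_word].tolist())

# Prefix already ends with the first piece: only the *last* piece is
# masked to -inf, which matches the behaviour described above.
prefix = torch.tensor([[tokenizer.bos_token_id, bad_word[0]]])
print(processor(prefix, scores.clone())[0, bad_word].tolist())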

@giladpn

giladpn commented Nov 19, 2021

Hi @qqaatw

I am trying to use T5 instead of GPT-2 in your example. Here is the code I am using, which is copy-pasted from your code example above with a few minimal changes:

  • changed gpt2 to t5-base
  • changed AutoModelForCausalLM to T5ForConditionalGeneration

The code now generates a sentence successfully but ignores the "bad word" I put in ("dude"). The generated sentence is:

"My cute cat is the sweetest little dude in the world. My cute dog is"

Here is the code, what am I doing wrong? Thank you!

from transformers import AutoTokenizer, AutoModelForCausalLM, T5ForConditionalGeneration, set_seed
model = T5ForConditionalGeneration.from_pretrained("t5-base", return_dict_in_generate=True)
tokenizer = AutoTokenizer.from_pretrained("t5-base", add_prefix_space=True)

set_seed(0)

input_context = "My cute dog"

# get tokens of words that should not be generated
bad_words_ids = tokenizer(["dude"]).input_ids

# encode input context
input_ids = tokenizer(input_context, return_tensors="pt").input_ids
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
print("Generated:", tokenizer.decode(outputs["sequences"][0], skip_special_tokens=True))

@qqaatw
Contributor

qqaatw commented Nov 19, 2021

@giladpn, thanks for providing the code. Can you add add_special_tokens=False to tokenizer.__call__() and see if the problem is solved? Like so:

bad_words_ids = tokenizer(["dude"], add_special_tokens=False).input_ids
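
As far as I can tell, the reason is that T5's tokenizer appends an end-of-sequence token to every encoded string, so the encoded "bad word" carries a trailing </s> id and never lines up with what the model emits mid-sentence. A quick way to see the difference (the exact ids depend on the vocab):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")

# With special tokens the banned "word" ends with the </s> id...
print(tokenizer(["dude"]).input_ids)
# ...without them it is just the bare sentencepiece ids of the word.
print(tokenizer(["dude"], add_special_tokens=False).input_ids)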

@giladpn

giladpn commented Nov 19, 2021

@qqaatw Yes! It works now. Many thanks! Much appreciated.

@alvinwatner
Author

alvinwatner commented Nov 22, 2021

Edited: Indeed, if a word is tokenized into multiple tokens, the first token will still be present in the generated sequence. I'll take some time to deal with it.

@alvinwatner, what's the input text that you supply to the model?

Hi, sorry for the late reply. I have been busy working on my paper lately. I eventually created my own script; it's not well optimized, but it seems able to deal with these issues.

  • Here is generation_banned_words.py if you want to take a look: link.

  • Unfortunately, due to time constraints I only managed to attach it to greedy_search. Here is how it looks: link. Also, since my script only requires 'input_ids' and 'next_tokens' (which exist in every sampling method) and 'sorted_next_token_indices' (which is just the topk of the next_tokens_scores), I assume it should not be too difficult to embed this into other sampling methods. Why do we need 'sorted_next_token_indices'? I could explain further, but in short: at every timestep, if the chosen token (initially the argmax) matches the banned_words ids, it is replaced by the token with the next highest probability (e.g., with sorted_next_token_indices = [5, 9, ..., vocab_size] and banned_words_ids = [5], we choose the next highest after 5, which is 9). A stripped-down sketch of this idea appears at the end of this comment.

  • Here is a glimpse of the usage I made in Colab: link

PS: sorry for the spaghetti code :(
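
For anyone who just wants the gist without reading the script, here is a stripped-down sketch of the single-token case (illustrative only; the names below are made up for the example and this is not the actual generation_banned_words.py):

import torch

def pick_allowed_token(next_token_scores: torch.Tensor, banned_ids: set) -> int:
    """Greedy step that skips banned (single) token ids.

    next_token_scores: 1-D tensor of scores over the vocabulary.
    Returns the highest-scoring token id that is not banned.
    """
    # Candidates from highest to lowest score
    # (this plays the role of 'sorted_next_token_indices' above).
    sorted_next_token_indices = torch.argsort(next_token_scores, descending=True)
    for token_id in sorted_next_token_indices.tolist():
        if token_id not in banned_ids:
            return token_id
    # Only reachable if every id in the vocabulary is banned.
    return int(sorted_next_token_indices[0])

Handling multi-token banned words would additionally require comparing the tail of input_ids against each banned sequence before accepting a candidate, which is the part the full script tries to address.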

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@musitafa0032

It seems like this function does not work for Chinese BART. Chinese BART uses the BERT tokenizer rather than the BART tokenizer; I don't know whether that affects it. Does anyone know how to make it work with Chinese BART? Thank you.

@musitafa0032

Interesting, I just figured it out. For Chinese BART you only need the single token id to make it work, because Chinese characters have no subword suffix. If you use the tokenizer to get the bad word ids, it will return something like [[101, 704, 102]], but 101 and 102 represent [CLS] and [SEP]; you only need the 704 id.
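
A small illustration of that (using bert-base-chinese purely as an example of a BERT-style tokenizer; the exact ids depend on the vocabulary):

from transformers import AutoTokenizer

# Any BERT-style Chinese tokenizer shows the same pattern; the checkpoint
# name here is only an example.
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

bad_word = "中"
print(tokenizer(bad_word).input_ids)                            # includes the [CLS]/[SEP] ids
print(tokenizer(bad_word, add_special_tokens=False).input_ids)  # just the character's own id

# Pass only the bare ids to generate():
# bad_words_ids = [tokenizer(bad_word, add_special_tokens=False).input_ids]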
