Breaking-change behavior in BERT tokenizer when stripping accents #2917

Closed
bryant1410 opened this issue Feb 19, 2020 · 9 comments
Labels: Core: Tokenization (Internals of the library; Tokenization.), wontfix

Comments

@bryant1410
Contributor

🐛 Bug

Information

Model I am using (Bert, XLNet ...): Bert (it could also happen with other models; I don't know)

Language I am using the model on (English, Chinese ...): English

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

from transformers import AutoTokenizer

pretrained_model_name = "bert-base-cased"

fast_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
slow_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name, use_fast=False)

text = "naïve"

assert fast_tokenizer.encode(text) == slow_tokenizer.encode(text)
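
To see exactly where the two disagree, a quick check is to compare the token strings (a diagnostic sketch; the exact splits depend on the bert-base-cased vocab, so no particular output is claimed here):

# Compare the token strings produced by each tokenizer.
fast_tokens = fast_tokenizer.convert_ids_to_tokens(fast_tokenizer.encode(text))
slow_tokens = slow_tokenizer.convert_ids_to_tokens(slow_tokenizer.encode(text))
print("fast:", fast_tokens)
print("slow:", slow_tokens)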

With the slow tokenizer, accents are only stripped when lowercasing is enabled (maybe a bug?):

token = self._run_strip_accents(token)
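
For context, the accent stripping in the slow BasicTokenizer is Unicode-based; here is a minimal self-contained sketch of the idea (roughly what _run_strip_accents does, not necessarily the library's exact code):

import unicodedata

def strip_accents(text):
    # NFD-decompose, then drop combining marks (Unicode category "Mn").
    text = unicodedata.normalize("NFD", text)
    return "".join(ch for ch in text if unicodedata.category(ch) != "Mn")

strip_accents("naïve")  # -> "naive"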

With the fast one, it would never strip accents:

https://github.com/huggingface/tokenizers/blob/python-v0.5.0/bindings/python/tokenizers/implementations/bert_wordpiece.py#L23

BertWordPieceTokenizer(
    vocab_file=vocab_file,
    add_special_tokens=add_special_tokens,
    unk_token=unk_token,
    sep_token=sep_token,
    cls_token=cls_token,
    handle_chinese_chars=tokenize_chinese_chars,
    lowercase=do_lower_case,
),

It'd be cool to have that flag available in both tokenizers.

Finally, this warning seems odd for the simple code above:

>>> assert fast_tokenizer.encode(text) == slow_tokenizer.encode(text)
Disabled padding because no padding token set (pad_token: [PAD], pad_token_id: 0).
To remove this error, you can add a new pad token and then resize model embedding:
	tokenizer.pad_token = '<PAD>'
	model.resize_token_embeddings(len(tokenizer))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AssertionError

Maybe the pad_to_max_length check here should nest the rest of the if?

if pad_to_max_length and (pad_token and pad_token_id >= 0):
    tokenizer.enable_padding(
        max_length=None,
        direction=padding_side,
        pad_id=pad_token_id,
        pad_type_id=pad_token_type_id,
        pad_token=pad_token,
    )
else:
    logger.warning(
        "Disabled padding because no padding token set (pad_token: {}, pad_token_id: {}).\n"
        "To remove this error, you can add a new pad token and then resize model embedding:\n"
        "\ttokenizer.pad_token = '<PAD>'\n\tmodel.resize_token_embeddings(len(tokenizer))".format(
            pad_token, pad_token_id
        )
    )
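
Something like this, i.e. only check for a pad token when padding was actually requested (just a sketch of the idea, not tested):

if pad_to_max_length:
    if pad_token and pad_token_id >= 0:
        tokenizer.enable_padding(
            max_length=None,
            direction=padding_side,
            pad_id=pad_token_id,
            pad_type_id=pad_token_type_id,
            pad_token=pad_token,
        )
    else:
        logger.warning(
            "Disabled padding because no padding token set (pad_token: %s, pad_token_id: %s).",
            pad_token,
            pad_token_id,
        )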

I didn't check the other transformer models.

Expected behavior

  1. The two tokenizer outputs (slow and fast) should be the same.
  2. The tokenizers should let you choose whether or not to strip accents.
  3. That warning shouldn't appear, IMHO.

Environment info

  • transformers version: 2.5.0
  • Platform: Linux-4.15.0-76-generic-x86_64-with-debian-buster-sid
  • Python version: 3.7.4
  • PyTorch version (GPU?): 1.4.0 (True)
  • Tensorflow version (GPU?): 2.0.0 (False)
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No
@ZG2017

ZG2017 commented Feb 20, 2020

Yeah, I found the same problem in my code. "encode" won't add padding even when pad_to_max_length=True.

@mfuntowicz mfuntowicz self-assigned this Feb 20, 2020
@BramVanroy BramVanroy added the Core: Tokenization Internals of the library; Tokenization. label Feb 20, 2020
@mfuntowicz
Member

Hi @bryant1410,

Thanks for reporting the issue. The parameter strip_accents was indeed enabled on BertTokenizerFast.

I have a PR exposing the missing parameters (#2921); it will land on master soon and will be included in the first maintenance release of 2.5.
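
Once it lands, usage should look roughly like this (a sketch; this assumes the exposed keyword is named strip_accents and is forwarded by from_pretrained):

from transformers import AutoTokenizer

# Hypothetical usage after the PR: explicitly disable accent stripping on the fast tokenizer.
fast_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", strip_accents=False)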

@bryant1410
Contributor Author

I see, thanks! There's still an incompatibility, though: you can choose whether to strip accents in the fast tokenizers, but you can't control that in the previous (slow) tokenizers. I believe this should be fixed as well.

And be aware that, IIRC, this is still a breaking change, because the previous tokenizers stripped accents by default in one way, but now the default behavior seems to be different.

I don't know if this is also the case for the other params added in #2921, or for models other than BERT.

@stale

stale bot commented Apr 22, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Apr 22, 2020
@bryant1410
Contributor Author

Please don't close it as this is an important issue.

@stale stale bot removed the wontfix label Apr 22, 2020
@julien-c
Member

Same one as reported by @stefan-it, @n1t0?

@n1t0
Member

n1t0 commented Apr 23, 2020

Yes, same one. Accent stripping happens only when do_lower_case=True for slow tokenizers, and there is currently no way to change this behavior.

We can probably add an explicit option for this on slow tokenizers, and specify the default values in the configs.
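
As a standalone sketch of the idea (the function name and defaults here are hypothetical; strip_accents=None would mean "follow do_lower_case" to keep today's default behavior):

import unicodedata

def basic_normalize(token, do_lower_case=True, strip_accents=None):
    # Decouple accent stripping from lowercasing.
    if strip_accents is None:
        strip_accents = do_lower_case  # preserve the current default behavior
    if do_lower_case:
        token = token.lower()
    if strip_accents:
        # Same NFD-based stripping the slow tokenizer uses.
        token = unicodedata.normalize("NFD", token)
        token = "".join(ch for ch in token if unicodedata.category(ch) != "Mn")
    return token

basic_normalize("naïve", do_lower_case=False, strip_accents=True)  # -> "naive"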

@stale

stale bot commented Jun 22, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Jun 22, 2020
@stale stale bot closed this as completed Jul 1, 2020
@pauli31

pauli31 commented Jul 13, 2020

Don't close it!! I want to have control over stripping accents when tokenizing.
