
XLMRobertaTokenizerFast producing wrong tokenized output #9637

Closed
sstojanoska opened this issue Jan 16, 2021 · 8 comments

Comments

@sstojanoska

Environment info

  • transformers version: 4.2.1
  • Platform: Linux-4.19.112+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.6.9
  • PyTorch version (GPU?): 1.7.0+cu101 (False)
  • Tensorflow version (GPU?): 2.4.0 (False)
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help

@mfuntowicz
@stefan-it

Information

The model I am using is XLM-RoBERTa.
The problem arises when using the XLMRobertaTokenizerFast tokenizer.

The task I am working on is token classification. In order to align the labels with the sub-word units I have used the code snippet provided here: https://huggingface.co/transformers/custom_datasets.html [ Fine-tuning with custom datasets / Token Classification with W-NUT Emerging Entities ].

When trying to align the labels with the encodings, it throws: "ValueError: NumPy boolean array indexing assignment cannot assign X input values to the Y output values where the mask is true."

This behavior is due to how punctuation is tokenized: a comma (',') gets tokenized into '▁' and ',' (the latter having offset values (0, 1)). Similar behavior happens with the dot. However, some other punctuation marks produce only a single token (e.g. ':' -> ':').
In addition, the offset_mapping value for ':' differs between sentences, resulting in either a (0, 0) or a (0, 3) tuple. The problem is that padding tokens also have the offset tuple (0, 0) and are excluded from alignment, but in this case I have to preserve the punctuation since this is a POS tagging problem.

To reproduce

print("Token: {}  Offset_mapping: {}".format(train_encodings[338].tokens[67], train_encodings[338].offsets[67]))
# Token: ▁...  Offset_mapping: (0, 0)
print("Token: {}  Offset_mapping: {}".format(train_encodings[20].tokens[2], train_encodings[20].offsets[2]))
# Token: ▁...  Offset_mapping: (0, 3)
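For context, the alignment logic from that tutorial can be reduced to the following library-free sketch (the function name and the data here are hypothetical, chosen only to reproduce the failure mode; only NumPy is required). A spurious (0, 0) offset on a real token shrinks the boolean mask, so the assignment of word-level labels no longer fits:

```python
import numpy as np

def encode_tags(labels, offsets, pad_value=-100):
    """Expand word-level labels to token-level labels, keeping only
    tokens whose offsets mark the start of a word: (start, end)
    with start == 0 and end != 0."""
    doc_enc_labels = np.ones(len(offsets), dtype=int) * pad_value
    arr_offset = np.array(offsets)
    # tokens that begin a word: first offset is 0, second is not 0
    mask = (arr_offset[:, 0] == 0) & (arr_offset[:, 1] != 0)
    # raises ValueError when the mask selects fewer slots than len(labels)
    doc_enc_labels[mask] = labels
    return doc_enc_labels.tolist()

# 4 word-level labels, but the ':' token was given the offset (0, 0),
# indistinguishable from padding, so only 3 mask positions are True
labels = [0, 1, 2, 3]
offsets = [(0, 0), (0, 2), (0, 5), (0, 1), (0, 0), (0, 0)]
try:
    encode_tags(labels, offsets)
except ValueError as e:
    print("ValueError:", e)
```

With a correct offset such as (0, 3) for ':', the mask has exactly one True position per word and the assignment succeeds.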

Moreover, although I fixed this issue by writing my own masks, I found a new issue: the blank space that denotes the start of a word is tokenized as a separate token instead of being merged with the following sub-token.

To reproduce

tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base")
model = XLMRobertaForTokenClassification.from_pretrained("xlm-roberta-base")
s = "Je često kritizirao vladu ."
print(tokenizer.tokenize(s)) 
# output: ['▁Je', '▁često', '▁krit', 'izira', 'o', '▁', 'vlad', 'u', '▁', '.']

Expected behavior

  1. Punctuation marks should be tokenized consistently, with offset values different from those of padding tokens.
  2. The first sub-word token should include the preceding blank space everywhere.
@n1t0
Member

n1t0 commented Feb 1, 2021

There are two different subjects being discussed here:

  • The tokenization behavior: how punctuation is tokenized, or how the blank spaces are separated from the next token. This is expected behavior and simply describes the way this tokenizer (XLMRoberta) works.
  • The offset mappings, which, as described here, are wrong in some cases. These need to be fixed, and below I describe the problem and how we are going to solve it in more detail.

Cause

This bug in offset mapping actually affects all the fast tokenizers converted from sentencepiece. During the pre-tokenization step, we first split everything on whitespaces (WhitespaceSplit pre-tokenizer), and in a second step, we add the '▁' character in front of each word (Metaspace pre-tokenizer). This process is accurate in terms of tokenization, but it makes offset tracking very difficult:

  • All the whitespaces get removed, so we won't have any token pointing back to them.
  • We add a new '▁' in front of each word, so these tokens can only point back to the beginning of each word: its first character.
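A minimal, library-free sketch of these two steps (the real implementation lives in the tokenizers library; the function names here are illustrative) shows where the information is lost:

```python
def whitespace_split(text):
    """Step 1: split on whitespace, keeping each word's original span."""
    words, start = [], 0
    for word in text.split():
        start = text.index(word, start)
        words.append((word, (start, start + len(word))))
        start += len(word)
    return words

def metaspace(words):
    """Step 2: prepend '▁' to each word. The added character has no
    source position of its own, so '▁word' can only point back to the
    word's first character -- the whitespace offsets are gone."""
    return [("▁" + w, span) for w, span in words]

print(metaspace(whitespace_split("vladu .")))
# [('▁vladu', (0, 5)), ('▁.', (6, 7))]
# the space at index 5 is no longer reachable from any token
```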

How to fix it

The initial idea of using WhitespaceSplit as a first step was simply to deduplicate the whitespaces, but since it leads to a loss of information, we'll replace it with the following process:

  • A normalization step that replaces each group of whitespaces with a single one, effectively mapping that single whitespace back to the whole group in the original input.
  • A pre-tokenization step: we just keep the Metaspace pre-tokenizer.
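In pure Python, the proposed normalization step might look like this (a sketch only; the actual fix is implemented inside the tokenizers library's normalizers, and the function name here is hypothetical):

```python
import re

def normalize_whitespace(text):
    """Collapse each run of whitespace to a single space, and record,
    for every character of the normalized text, the span it maps back
    to in the original input."""
    normalized, mapping = [], []
    for m in re.finditer(r"\s+|\S", text):
        if m.group().isspace():
            normalized.append(" ")
            # the single space maps back to the whole whitespace run
            mapping.append((m.start(), m.end()))
        else:
            normalized.append(m.group())
            mapping.append((m.start(), m.start() + 1))
    return "".join(normalized), mapping

norm, mapping = normalize_whitespace("vladu   .")
print(norm)        # "vladu ."
print(mapping[5])  # (5, 8): the collapsed space still covers the original run
```

Because every normalized character keeps a span into the original input, the subsequent Metaspace step can attach the '▁' to a real position instead of dropping it.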

In order to fix this we need to:

  1. Update all the tokenizer.json files on the hub; this will be compatible with any version of transformers since the introduction of these fast tokenizers (3.5.0+).
  2. Update all the conversion steps in transformers to avoid creating more buggy tokenizers.

@n1t0
Member

n1t0 commented Feb 1, 2021

List of updated tokenizers:

These can't be fixed this way:

The following will need a new version of transformers with a bugfix in tokenizers. We'll need to find a way to rely on the new tokenizer.json version only in versions of transformers that include this bugfix, as it would break all the previous ones.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@LysandreJik
Member

Unstale

@github-actions

github-actions bot commented May 9, 2021

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@yaysummeriscoming

Any update on this one?

@noah-rush

Bump

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

5 participants