[BUG] padding tokens are also masked in DataCollatorForLanguageModeling #11155

Closed
1 of 4 tasks
ldong87 opened this issue Apr 9, 2021 · 2 comments · Fixed by #11163

Comments

@ldong87

ldong87 commented Apr 9, 2021

Environment info

  • transformers version: 4.3.2
  • Platform: Linux
  • Python version: 3.6
  • PyTorch version (GPU?): 1.7.1 GPU
  • Tensorflow version (GPU?): N/A
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Sagemaker distributed data parallel

Who can help

@sgugger

Information

Model I am using (Bert, XLNet ...): All models that use DataCollatorForLanguageModeling.

The bug was introduced in this PR.

Three lines (241-243) were removed by mistake from this line.

Now padding tokens are also masked in MLM.
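
For context, the dropped lines zeroed out the masking probability at padding positions before the mask was sampled. A rough sketch of that safeguard (reconstructed from memory of the pre-refactor collator, so variable names are approximate):

# Inside the collator's mask_tokens: never sample padding positions for masking.
if self.tokenizer._pad_token is not None:
    padding_mask = labels.eq(self.tokenizer.pad_token_id)
    # Set the masking probability to 0 wherever the input is a padding token.
    probability_matrix.masked_fill_(padding_mask, value=0.0)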

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The tasks I am working on are:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained('albert-base-v2')
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
# Pad the single example to max_length so the batch contains padding tokens.
tok = tokenizer('hello how are you!', add_special_tokens=True, truncation=True, max_length=256, padding='max_length')
data_collator([tok['input_ids']])

From the output you can easily see that the padding tokens are masked. Adding back the three removed lines fixes this bug.
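
To make "padding tokens are masked" concrete, one can count how many selected positions fall on padding (a small verification sketch reusing the variables from the snippet above; torch is assumed to be installed):

import torch

batch = data_collator([tok['input_ids']])
labels = batch['labels'][0]
pad_positions = torch.tensor(tok['input_ids']) == tokenizer.pad_token_id
# Labels other than -100 mark the positions selected for MLM; with the bug,
# some of them land on padding tokens, so this count is typically > 0.
print(((labels != -100) & pad_positions).sum().item())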

Expected behavior

Padding tokens are not supposed to be maskable in MLM.

@ldong87 ldong87 changed the title padding tokens are also masked in DataCollatorForLanguageModeling [BUG] padding tokens are also masked in DataCollatorForLanguageModeling Apr 9, 2021
@woong97

woong97 commented Apr 9, 2021

I have a similar issue.

The "pad" token is not masked when I run the bert-base-uncased model, but it can be masked when I run albert-base-v2.

In examples/language-modeling/run_mlm.py, I tried calling tokenizer.get_special_tokens_mask:

print(tokenizer.get_special_tokens_mask([0, 100, 101, 102, 2, 3, 4], already_has_special_tokens=True))

Interestingly, the get_special_tokens_mask function is called from PreTrainedTokenizerBase when I run bert-base-uncased, but from AlbertTokenizerFast when I run albert-base-v2.

In the PreTrainedTokenizerBase class:

def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
    all_special_ids = self.all_special_ids  # cache the property
    special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]

    return special_tokens_mask

However, in the AlbertTokenizerFast class:

def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
    if already_has_special_tokens:
        if token_ids_1 is not None:
            raise ValueError(
                "You should not supply a second sequence if the provided sequence of "
                "ids is already formatted with special tokens for the model."
            )
        return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

    if token_ids_1 is not None:
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
    return [1] + ([0] * len(token_ids_0)) + [1]

These two functions are different. So when I use BERT, all_special_ids (which contains the cls, sep, and pad ids) are the ids that cannot be masked. But when I use ALBERT, only the cls and sep ids cannot be masked, so the pad token can be masked.
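
To see the difference above directly, one can compare the two tokenizers on a padded input (a minimal sketch; the exact dispatch may depend on the transformers version):

from transformers import AutoTokenizer

bert_tok = AutoTokenizer.from_pretrained('bert-base-uncased')
albert_tok = AutoTokenizer.from_pretrained('albert-base-v2')

bert_ids = bert_tok('hello', padding='max_length', max_length=8)['input_ids']
albert_ids = albert_tok('hello', padding='max_length', max_length=8)['input_ids']

# The base-class implementation flags every id in all_special_ids (including
# the pad id) with 1; the ALBERT override only flags sep/cls, so the trailing
# padding positions come back as 0.
print(bert_tok.get_special_tokens_mask(bert_ids, already_has_special_tokens=True))
print(albert_tok.get_special_tokens_mask(albert_ids, already_has_special_tokens=True))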

I don't know why the function is called from a different class when I run bert-base-uncased versus albert. Do you know why?

And is it correct that the pad token can be masked in the ALBERT model?

[bert command]

%  python run_mlm.py --model_name_or_path bert-base-uncased --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --do_eval --output_dir ./tmp/test-mlm --line_by_line

[albert command]

%  python run_mlm.py --model_name_or_path albert-base-v2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --do_eval --output_dir ./tmp/test-mlm --line_by_line

@sgugger
Collaborator

sgugger commented Apr 9, 2021

Thanks for reporting! This is actually a bug in the get_special_tokens_mask method of most tokenizers. I will push a fix soon. In the meantime, you can work around the problem by passing the special_tokens_mask the tokenizer returns to the data collator (which will actually be faster, since it avoids recomputing the mask):

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained('albert-base-v2')
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
# Pass the whole encoding (not just input_ids) so the collator receives the
# special_tokens_mask returned by the tokenizer.
tok = tokenizer('hello how are you!', return_special_tokens_mask=True, truncation=True, max_length=256, padding='max_length')
data_collator([tok])
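
As a quick sanity check under the same setup (a sketch; torch is assumed to be installed), the labels at padding positions should now stay at -100:

import torch

batch = data_collator([tok])
labels = batch['labels'][0]
pad_positions = torch.tensor(tok['input_ids']) == tokenizer.pad_token_id
# With the special_tokens_mask supplied, padding can never be selected for
# masking, so every padding position keeps the ignore index -100.
assert ((labels != -100) & pad_positions).sum().item() == 0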
