
DPRReaderTokenizer does not generate the attention_mask properly #9555

Closed

mkserge opened this issue Jan 13, 2021 · 2 comments

Comments

@mkserge (Contributor) commented Jan 13, 2021

Hello,

It seems like the DPRReaderTokenizer does not generate the attention_mask properly.

Steps to reproduce on the master branch

(venv) sergey_mkrtchyan test (master) $ python
Python 3.8.6 (v3.8.6:db455296be, Sep 23 2020, 13:31:39)
[Clang 6.0 (clang-600.0.57)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> from transformers import DPRReaderTokenizer, DPRReader
>>> tokenizer = DPRReaderTokenizer.from_pretrained('facebook/dpr-reader-single-nq-base')
>>> model = DPRReader.from_pretrained('facebook/dpr-reader-single-nq-base')
>>> encoded_inputs = tokenizer(questions="What is love ?",
...                            titles="Haddaway",
...                            texts="What Is Love is a song recorded by the artist Haddaway",
...                            padding=True,
...                            return_tensors='pt')
>>> encoded_inputs
{'input_ids': tensor([[ 101, 2054, 2003, 2293, 1029,  102, 2018, 2850, 4576,  102, 2054, 2003,
         2293, 2003, 1037, 2299, 2680, 2011, 1996, 3063, 2018, 2850, 4576]]), 'attention_mask': tensor([True])}

Notice that the attention_mask above is incorrect: it is a single boolean, whereas it should have the same shape as the input_ids tensor.
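To make the mismatch concrete (shapes inferred from the 23 input tokens shown above):

>>> encoded_inputs["input_ids"].shape
torch.Size([1, 23])
>>> encoded_inputs["attention_mask"].shape  # a single boolean instead of a (1, 23) mask
torch.Size([1])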

Environment info

  • transformers version: 4.2.0dev0
  • Platform: macOS-10.15.7-x86_64-i386-64bit
  • Python version: 3.8.6
  • PyTorch version (GPU?): 1.7.1 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

Git blame says @lhoestq and @LysandreJik might be able to help :)

I believe the issue is in this part of the code

attention_mask = [input_ids != self.pad_token_id for input_ids in encoded_inputs["input_ids"]]

(same thing for the fast tokenizer)
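If I'm reading the code right, the root cause is that at this point encoded_inputs["input_ids"] holds plain Python lists, so input_ids != self.pad_token_id compares a list against an int and returns a single bool rather than a per-token mask. A minimal illustration (with a hypothetical pad_token_id of 0):

>>> input_ids = [101, 2054, 102, 0, 0]
>>> pad_token_id = 0
>>> input_ids != pad_token_id  # list vs. int comparison yields one bool, not an element-wise mask
True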

I fixed it locally by replacing the above line with

attention_mask = []
for input_ids in encoded_inputs["input_ids"]:
    attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids])
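For example, applied to a hypothetical padded batch (assuming pad_token_id = 0), this loop produces a per-token 0/1 mask matching the shape of input_ids:

encoded_inputs = {"input_ids": [[101, 2054, 102, 0, 0],
                                [101, 2293, 2003, 1037, 102]]}
pad_token_id = 0

attention_mask = []
for input_ids in encoded_inputs["input_ids"]:
    attention_mask.append([int(input_id != pad_token_id) for input_id in input_ids])

print(attention_mask)
# [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]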

I am happy to submit a PR if that looks reasonable to you.

@LysandreJik (Member) commented:

Indeed, it doesn't! We would gladly welcome a PR!

@LysandreJik (Member) commented:

Closed by #9663 :)
