It seems like the DPRReaderTokenizer does not generate the attention_mask properly.
Steps to reproduce on the master branch
(venv) sergey_mkrtchyan test (master) $ python
Python 3.8.6 (v3.8.6:db455296be, Sep 23 2020, 13:31:39)
[Clang 6.0 (clang-600.0.57)] on darwin
Type "help", "copyright", "credits" or "license"for more information.
>>> from transformers import DPRReaderTokenizer, DPRReader
>>> tokenizer = DPRReaderTokenizer.from_pretrained('facebook/dpr-reader-single-nq-base')
>>> model = DPRReader.from_pretrained('facebook/dpr-reader-single-nq-base')
>>> encoded_inputs = tokenizer(questions="What is love ?",
... titles="Haddaway",
... texts="What Is Love is a song recorded by the artist Haddaway",
... padding=True,
... return_tensors='pt')
>>> encoded_inputs
{'input_ids': tensor([[ 101, 2054, 2003, 2293, 1029, 102, 2018, 2850, 4576, 102, 2054, 2003,
2293, 2003, 1037, 2299, 2680, 2011, 1996, 3063, 2018, 2850, 4576]]), 'attention_mask': tensor([True])}
Notice the attention_mask above is incorrect. It should have the same shape as the input_ids tensor.
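For reference, nothing is padded in this single example, so I would expect a mask of all ones with the same shape as input_ids (the shape below is inferred from the 23 ids printed above):
>>> encoded_inputs['input_ids'].shape
torch.Size([1, 23])
>>> # expected: an attention_mask of the same (1, 23) shape, all ones, rather than tensor([True])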
Environment info
transformers version: 4.2.0dev0
Platform: macOS-10.15.7-x86_64-i386-64bit
Python version: 3.8.6
PyTorch version (GPU?): 1.7.1 (False)
Tensorflow version (GPU?): not installed (NA)
Using GPU in script?: No
Using distributed or parallel set-up in script?: No
Who can help
Git blame says @lhoestq and @LysandreJik might be able to help :)
I believe the issue is in this part of the code:
transformers/src/transformers/models/dpr/tokenization_dpr.py, line 254 at commit 5f67210
(same thing for the fast tokenizer)
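My guess at why this happens: at that point each sequence is still a plain Python list, and comparing a list to pad_token_id is evaluated on the list object as a whole rather than per token, so the mask collapses to a single boolean. A minimal illustration (the variable names are just for the example):
>>> sequence = [101, 2054, 2003, 2293]  # a plain Python list, not a tensor
>>> pad_token_id = 0
>>> sequence != pad_token_id  # list-vs-int comparison yields one bool, not a per-token mask
True
>>> [sequence != pad_token_id]  # which is how the mask ends up as [True]
[True]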
I fixed it locally by replacing the above line so that the mask is built token by token.
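Roughly along these lines (a sketch rather than the exact line from my local change; encoded_inputs and self.pad_token_id refer to the surrounding method):
attention_mask = []
for input_ids in encoded_inputs["input_ids"]:
    # one entry per token: 1 for real tokens, 0 for padding
    attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids])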
I am happy to submit a PR if that looks reasonable to you.