
Wrong offsets_mapping in T5TokenizerFast #9633

Closed
zorikg opened this issue Jan 16, 2021 · 8 comments


zorikg commented Jan 16, 2021

Environment info

  • transformers version: 4.2.1
  • Platform: Linux-4.9.0-14-amd64-x86_64-with-debian-9.13
  • Python version: 3.6.10
  • PyTorch version (GPU?): 1.7.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help: @patrickvonplaten, @mfuntowicz

Information

Model I am using: T5

To reproduce

See comments in the code snippet.

from transformers import T5TokenizerFast


def test_offset_mapping():
    """This test fails and therefore we know that there is a bug in offset_mapping mechanism.
        We try to tokenize the sentence 'This is a test sentence' and notice to issues:

        1. The tokenizer tokenizes it to ['This', 'is', '', 'a', 'test', 'sentence']
            which means that it has redundant empty string in position 2.
        2. The offset mapping maps to ['This', 'is', 'a', 'a', 'test', 'sentence']
            replacing the empty string with redundant 'a'.

    """
    tokenizer = T5TokenizerFast.from_pretrained('google/t5-v1_1-base')

    s = "This is a test sentence"
    tokenized = tokenizer(s, return_offsets_mapping=True)
    
    decoded_tokens, tokens_from_offset_mapping = [], []
    for token_index, offset_mapping in enumerate(tokenized['offset_mapping']):
        decoded_token = tokenizer.decode(tokenized['input_ids'][token_index])
        if decoded_token != tokenizer.eos_token:
            decoded_tokens.append(decoded_token)
            tokens_from_offset_mapping.append(s[offset_mapping[0]:offset_mapping[1]])

    error_msg = f"Wrong offset mapping for '{s}'! \n" \
                f"Maps to:          {tokens_from_offset_mapping}\n" \
                f"Instead of:       {decoded_tokens}"
    assert decoded_tokens == tokens_from_offset_mapping, error_msg


if __name__ == "__main__":
    test_offset_mapping()

Expected behavior

AssertionError: Wrong offset mapping for 'This is a test sentence'! 
Maps to:          ['This', 'is', 'a', 'a', 'test', 'sentence']
Instead of:       ['This', 'is', '', 'a', 'test', 'sentence']

LysandreJik (Member) commented:

@patrickvonplaten @n1t0 do you have any advice on this? The T5 tokenizer tokenizes the sentence as follows:

['▁This', '▁is', '▁', 'a', '▁test', '▁sentence']

Unfortunately, the offset mappings point to both '▁' and 'a' being at (8, 9), as the following shows:

'offset_mapping': [(0, 4), (5, 7), (8, 9), (8, 9), (10, 14), (15, 23), (0, 0)]
                                    ^---- & ^---- here 

How should one map this encoding back to the initial sequence?
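
For reference, here is a minimal sketch showing how the tokens and offsets above can be printed side by side (assuming the same 'google/t5-v1_1-base' checkpoint as in the report):

from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained('google/t5-v1_1-base')
s = "This is a test sentence"
encoded = tokenizer(s, return_offsets_mapping=True)

# Print each token piece next to the span of `s` its offset mapping points to.
# With the buggy fast tokenizer, both '▁' and 'a' report the span (8, 9).
tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'])
for token, (start, end) in zip(tokens, encoded['offset_mapping']):
    print(f"{token!r:>12} ({start}, {end}) -> {s[start:end]!r}")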

zorikg (Author) commented Feb 1, 2021

@patrickvonplaten @n1t0 - did you have a chance to look at this?
Thanks!

n1t0 (Member) commented Feb 3, 2021

Hi @zorikg! Thank you for reporting this issue. This is related to #9637, which concerns the offset-mapping bug.

The fix for this bug is tricky to deploy, but we are working on it, and I expect it to be available in the coming weeks.

zorikg (Author) commented Mar 6, 2021

Thanks @n1t0, I was wondering whether there has been any progress on this? Any estimate for when the fix will be available? Thanks!

github-actions commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

n1t0 (Member) commented Apr 15, 2021

@zorikg With recent versions of transformers, you can instantiate your tokenizer as follows:

tokenizer = T5TokenizerFast.from_pretrained('google/t5-v1_1-base', from_slow=True)

This will force the conversion from the slow tokenizer, thus using the fixed version.
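
A quick way to verify the workaround against the original repro (a sketch, assuming a transformers version recent enough to support from_slow):

from transformers import T5TokenizerFast

# Force conversion from the slow (SentencePiece-based) tokenizer.
tokenizer = T5TokenizerFast.from_pretrained('google/t5-v1_1-base', from_slow=True)
s = "This is a test sentence"
encoded = tokenizer(s, return_offsets_mapping=True)

# With the converted tokenizer, each token should map to its own span of `s`,
# rather than two tokens sharing (8, 9).
print(encoded['offset_mapping'])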

github-actions commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.


Oxi84 commented Nov 29, 2021

I am getting some differences between these two tokenizers. Has this been solved?
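
In case it helps, a small sketch for comparing the two tokenizers directly (this assumes the comparison is between the slow T5Tokenizer and the fast T5TokenizerFast for the same checkpoint):

from transformers import T5Tokenizer, T5TokenizerFast

slow = T5Tokenizer.from_pretrained('google/t5-v1_1-base')
fast = T5TokenizerFast.from_pretrained('google/t5-v1_1-base')

s = "This is a test sentence"
# Compare the token pieces produced by each tokenizer.
print(slow.tokenize(s))
print(fast.tokenize(s))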
