
Updated Reformer to use caching during generation #8252

Closed

Conversation

guillaume-be
Contributor

What does this PR do?

The current Reformer implementation supports caching of buckets and states, but this is not used during generation. Running a generation example in debug mode, such as:

from transformers import ReformerModelWithLMHead, ReformerTokenizer

model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").cuda()
tok = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
output = tok.decode(
    model.generate(
        tok.encode("Notwithstanding", return_tensors="pt").cuda(),
        do_sample=True,
        temperature=0.7,
        max_length=100,
        use_cache=True,  # requests caching, but is currently ignored
    )[0]
)

one can see that the past_buckets_states passed to the attention is always None (at the call site where the attention layer is invoked with past_buckets_states=None).

This is because the name of the past states for the Reformer is neither past_key_values nor mems.
This PR adds the Reformer's past-state name, past_buckets_states, to the names recognized when allocating the past during generation.
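As a minimal sketch of the idea (the helper name _extract_past and the attribute-based lookup are illustrative assumptions, not the actual generation code), the lookup would now cover all three cache names:

# Hypothetical sketch of the cache lookup during generation: before this PR,
# only past_key_values and mems were checked, so the Reformer's
# past_buckets_states was never picked up and the cache stayed None.
def _extract_past(outputs):
    for name in ("past_key_values", "mems", "past_buckets_states"):
        value = getattr(outputs, name, None)
        if value is not None:
            return value
    return None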

Generally, it may make sense to harmonize the name of the past value across all models, so that the generate function generalizes better.
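For illustration only (this wrapper is hypothetical and not part of this PR or of transformers), such a harmonization could map every model-specific cache onto one canonical name:

# Hypothetical wrapper: expose whichever model-specific cache exists under a
# single canonical name, so generation code needs no per-model checks.
class HarmonizedOutput:
    def __init__(self, model_output):
        self.logits = model_output.logits
        self.past_key_values = (
            getattr(model_output, "past_key_values", None)
            or getattr(model_output, "mems", None)
            or getattr(model_output, "past_buckets_states", None)
        )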

Who can review?

Text Generation: @patrickvonplaten, @TevenLeScao
Reformer: @patrickvonplaten

@guillaume-be guillaume-be changed the title Added past_buckets_states to possible output cached states Updated Reformer to use caching during generation Nov 3, 2020
@patrickvonplaten
Contributor

Great catch!

@patrickvonplaten
Contributor

patrickvonplaten commented Nov 3, 2020

Actually, we would have to add it in two spots of this version of generate. Considering that we will merge the big generate refactor today, I just added your fix quickly here: 12b54ec

I mentioned your PR in the fix; I hope it's OK for you to close this PR to avoid any more merge conflicts.

Thanks a lot!
