Recurrent Attention: standalone machine translation example #11421
Conversation
```python
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(
        [target_seqs_train[:, :-1], input_seqs_train],
```
use validation data!
@fchollet This is a very neat and thorough PR. Please review and discuss what parts need to be moved into the Keras API and what should stay in the example.
```python
    return K.max(K.stack([x_1, x_2], axis=-1), axis=-1, keepdims=False)


h2 = TimeDistributed(Lambda(dense_maxout))(concatenate([h1, y_emb]))
y_pred = TimeDistributed(Dense(target_tokenizer.num_words))(h2)
```
Softmax missing!
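A minimal sketch of the fix being suggested here (assuming the logits should simply be normalized over the target vocabulary at each output step):

```python
# add a softmax activation so y_pred is a per-timestep probability
# distribution over the target vocabulary
y_pred = TimeDistributed(
    Dense(target_tokenizer.num_words, activation='softmax'))(h2)
```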
Regarding TODO(4) in the docs: this diff clarifies the changes needed to improve the efficiency of the attention mechanism. It is a little less intuitive, which is why I left it out of this PR. It boosts training speed by about 50% (on a MacBook Pro CPU, TensorFlow backend).
Thanks for the PR. I think this is a useful example and we can include it. However, it seems quite long. Is there anything you could afford to leave out?
```python
                   initializers,
                   regularizers,
                   constraints)
from keras.engine import (
```
Style: don't import from engine (it's an internal factoring module). Instead do:

```python
import keras
from keras import layers
```

Then use e.g. `layers.Dense`.
```python
        [target_seqs_val[:, :-1], input_seqs_val],
        target_seqs_val[:, 1:, None]))

# TODO add logic for greedy/beam search generation
```
I think stopping the example at `model.fit` is too restrictive; this should be an end-to-end example showing how to do inference as well (like we do in the other translation example).
Totally agree. Beam search will make the example significantly longer and more complex (but it is the most relevant part). Will add both greedy and beam search inference to the example and then we can decide.
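For reference, a greedy readout could look roughly like the sketch below. It assumes the trained `model` maps `[decoder_tokens, encoder_tokens]` to per-step softmax outputs (as in the `fit` call above) and that start/end token ids exist; it is not the exact code added to the example.

```python
import numpy as np

def greedy_decode(model, input_seq, start_token, end_token, max_len=30):
    """Repeatedly feed the tokens decoded so far and pick the argmax.

    input_seq: (1, source_timesteps) array of encoder input token ids.
    """
    decoded = [start_token]
    for _ in range(max_len):
        y_in = np.array([decoded])                 # (1, t) decoder input so far
        probs = model.predict([y_in, input_seq])   # (1, t, vocab_size)
        next_token = int(np.argmax(probs[0, -1]))  # most probable next token
        decoded.append(next_token)
        if next_token == end_token:
            break
    return decoded[1:]  # strip the start token
```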
@fchollet What about moving the
Hi @farizrahman4u, @gabrieldemarmiesse @lvapeab @fchollet! I found some time to properly validate this implementation (there were some subtle bugs) and fix all the remaining TODOs. It achieves "decent" performance for the given dataset in an hour on a K80 (the original paper used a 1000x larger dataset and trained for several days). As discussed, the example is long - but it is also a complete replication of the (quite old but prominent) paper on recurrent attention, including beam-search readout. I think it serves as a good reference. As pointed out before, we can continue the discussion regarding whether some parts should be added to the core API (if/when there is time) and simplify this example accordingly.
Thanks for the PR. I'm not an expert in RNNs, but I hope I can help make this PR better.
```python
        return self.score < other.score

    def __gt__(self, other):
        return other.score > other.score
```
I think there is a mistake here. Maybe `return self.score > other.score`?
Good catch!
...the heap push/pop only uses `__lt__`, so it won't have affected the results.
```python
elif len(beams_updated) < search_width:
    # not full search width
    heapq.heappush(beams_updated, new_beam)
elif new_beam.score > beams_updated[0].score:
```
Maybe `elif new_beam > beams_updated[0]`? Otherwise `__gt__` and `__lt__` are never used in your example.
Good point, I can do this - but they are used anyway, by the heapq - this was the main reason for implementing the comparison methods on `Beam`.
I see. Thanks for the explanation!
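For context, a minimal sketch of how the heap relies on `__lt__` (the `Beam` fields below are placeholders, not the example's actual class):

```python
import heapq

class Beam(object):
    def __init__(self, score, sequence):
        self.score = score
        self.sequence = sequence

    def __lt__(self, other):
        # heapq compares items with "<", so the lowest-scoring beam ends up
        # at index 0 - the natural candidate to drop when the beam width
        # is exceeded
        return self.score < other.score

beams = []
for b in [Beam(-1.2, ['a']), Beam(-0.3, ['b']), Beam(-2.5, ['c'])]:
    heapq.heappush(beams, b)

assert beams[0].score == -2.5  # worst beam sits at the heap root
```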
```python
if __name__ == '__main__':
    DATA_DIR = 'data/wmt16_mmt'
```
Can I suggest the following:

```python
import os

from keras.utils.data_utils import get_file

base_name = 'wmt16_mmt_'
origin = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/'
get_file(base_name + 'train', origin=origin + 'training.tar.gz', untar=True)
get_file(base_name + 'val', origin=origin + 'validation.tar.gz', untar=True)
tar_file = get_file(base_name + 'test',
                    origin=origin + 'mmt16_task1_test.tar.gz',
                    untar=True)
DATA_DIR = os.path.dirname(tar_file)
```
Taking "standalone" a step further.
Sure :D
Since the extracted files won't have the `base_name`, I create a new `cache_subdir` instead - otherwise files with very generic names (`train.en`) end up in the `datasets` dir.
Good idea, I didn't find a way to make a subdir (I didn't look much into it).
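For reference, a sketch of that variant using `get_file`'s `cache_subdir` argument (the subdir name here is an assumption):

```python
import os

from keras.utils.data_utils import get_file

origin = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/'
# extracting into a dedicated cache_subdir keeps generically named files
# (train.en, val.de, ...) out of the shared ~/.keras/datasets directory
tar_file = get_file('wmt16_mmt_test',
                    origin=origin + 'mmt16_task1_test.tar.gz',
                    untar=True,
                    cache_subdir='datasets/wmt16_mmt')
DATA_DIR = os.path.dirname(tar_file)
```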
```
- NOTE that a different dataset (wmt14) is used in [1], which is _orders of
  magnitude_ larger than the dataset used here (348M vs 0.35M words). The model
  in [1] was trained for 252 hours (!) on a Tesla Quadro K6000, whereas for the
  data in this example the model starts to overfit after < 1 hour (15 epochs)
```
Is `<` a typo?
No, I meant "less than", but better to type it out.
@gabrieldemarmiesse This diff: https://github.com/andhus/keras/pull/6/files shows how we can skip 300+ lines by removing the base class. For the standalone example, this makes sense, I guess. The drawback is that more overhead is added to the implementation of the specific attention mechanism. The most "pressing" need for (something like) the base class in the core API is that we need to use the private
Good question. I don't know much about it. Maybe someone else can give some insight? @farizrahman4u do you know how we can avoid calling the private function?
@gabrieldemarmiesse @farizrahman4u @fchollet I'd love to wrap up this one. It has been reviewed, trained until convergence and sanity checked. It adds clear value as there are no attention examples currently. The remaining question was whether to get rid of the base class or not. I vote for removing it, i.e. applying this diff https://github.com/andhus/keras/pull/6/files and getting rid of 300 lines. It was never intended for the example (but to standardize and remove boilerplate for attention cell wrappers in general). Given the current speed of progress :) I don't think it is reasonable to expect that this will be added to the core API in the near future (it can always be found in the history of this PR). For reference (@gabrieldemarmiesse, 4 Nov 2018):
Hi @andhus - thanks so much for your work on this. I'm excited to be able to use this. I tried the standalone example in Keras 2.2.4 using TensorFlow 1.14.0. It died in the K.rnn call at line 2974 of tensorflow_backend.py. This is only a problem in K.rnn if there is masking, so when I turned off mask_zero in the target sequence embedding (see the sketch after this comment) the code started to run. I'm not sure where the breakdown in the logic is happening. Do you have any idea how this could be fixed? Regarding the code generally, I have a few thoughts.
Again, I'm very impressed by this and it would be great to see this become part of Keras proper in the near future!
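For reference, a sketch of the mask_zero workaround described above (the variable names are placeholders; only the `mask_zero=False` argument matters):

```python
from keras.layers import Embedding

# disabling masking on the target-sequence embedding avoids the K.rnn
# failure described above, at the cost of no longer masking padded
# target positions
y_emb = Embedding(target_vocab_size, embedding_dim, mask_zero=False)(y_in)
```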
Hi @rbturnbull
```python
    return output_texts, output_scores


# Translate first 3 samples from validation data
for input_text, target_text in zip(input_texts_val, target_texts_val)[:3]:
```
This is the only line that breaks running under PY3; change line 926 to:

```python
for input_text, target_text in list(zip(input_texts_val, target_texts_val))[:3]:
```

zip is eager in PY2, lazy in PY3.
This is a great PR @andhus and I hope to see it merged soon so that I can evangelize how one can easily use attention in Keras.

I pulled the branch and ran it successfully using:

Keras==2.3.1
tensorflow_gpu==2.0.0
Python 3.7

with one small modification: changed line 926 so that it will run under PY3, where zip is lazily evaluated.
Thanks @andhus for this outstanding PR! I hope this gets merged soon – it's been over 2 years (taking into account the ones leading up here). I am also in favour of keeping the base class in the example, as this would allow making follow-up PRs both for incorporating it into the core API and for adding other attention mechanisms (or features like alignment output/visualization, local attention, etc.) independently.

There is but one issue which I think should be addressed/fixed: with the current implementation, one cannot make use of Keras' layer sharing. The reason is the constructor's direct assignment of the cell; something along these lines is needed instead:

```python
def __init__(self, cell, ...):
    # self.cell = cell
    self._set_cell(cell)
    ...

@disable_tracking
def _set_cell(self, cell):
    self.cell = cell
```
Summary
Standalone example of recurrent attention as per @farizrahman4u's suggestion. There is thorough documentation in the script itself.
The script contains a base class for recurrent attention mechanisms. The purpose of this is to make it simple to write custom attention mechanisms. This is the main logic needed to implement a specific mechanism (by extending the base class):
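A rough sketch of the kind of logic a concrete mechanism has to supply - Bahdanau-style additive scoring of the encoder annotations against the current decoder state (the function and variable names here are illustrative assumptions, not the script's exact API):

```python
from keras import backend as K

def dense_annotation_attention(h_cell, attended, W_a, U_a, v_a):
    """Additive (Bahdanau-style) attention step.

    h_cell:   (batch, units_dec)             current decoder cell state
    attended: (batch, timesteps, units_enc)  encoder annotations
    W_a: (units_dec, units_att), U_a: (units_enc, units_att), v_a: (units_att, 1)
    """
    # project decoder state and encoder annotations into a shared space
    e = K.tanh(K.expand_dims(K.dot(h_cell, W_a), 1) + K.dot(attended, U_a))
    # one scalar score per encoder timestep
    e = K.squeeze(K.dot(e, v_a), -1)                  # (batch, timesteps)
    a = K.softmax(e)                                  # attention weights
    # context vector: attention-weighted sum of the annotations
    context = K.sum(K.expand_dims(a, -1) * attended, axis=1)
    return context, a
```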
The lines below summarize how the attention mechanism is used: an RNNCell is wrapped by the attention mechanism and the attended constants are provided to the RNN:
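A hedged sketch of that wiring (layer sizes and the constructor of the concrete attention cell are assumptions; the essential point is that the encoder annotations reach the decoder cell via the RNN layer's `constants` argument):

```python
from keras import layers

units, vocab_size = 256, 10000

# encoder: one annotation vector per source timestep
x = layers.Input(shape=(None,))
x_emb = layers.Embedding(vocab_size, units)(x)
x_enc = layers.Bidirectional(layers.GRU(units, return_sequences=True))(x_emb)

# decoder: a plain GRUCell wrapped by the attention mechanism defined in the
# example script (assumed here to be called DenseAnnotationAttention)
y = layers.Input(shape=(None,))
y_emb = layers.Embedding(vocab_size, units)(y)
attention_cell = DenseAnnotationAttention(cell=layers.GRUCell(units))
# the attended constants (encoder annotations) are available to the cell
# at every decoding step
h1 = layers.RNN(attention_cell, return_sequences=True)(y_emb, constants=x_enc)
```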
Related Issues
#11172 (+multiple previous issues and PRs linked from there)
PR Overview
The PR contains a single example script. It is under review/discussion which parts might make it into the core API.
TODO: definitely needed if `RNNAttentionCell` is added to the core API. Tests should also be added to validate the implementation in this example.