Recurrent Attention: standalone machine translation example #11421

Closed

Conversation

@andhus (Contributor) commented on Oct 17, 2018:

Summary

Standalone example of recurrent attention as per @farizrahman4u's suggestion. There is thorough documentation in the script itself.

The script contains a base class for recurrent attention mechanisms. Its purpose is to make it simple to write custom attention mechanisms. The following is the main logic needed to implement a specific mechanism (by extending the base class):

def attention_call(self,
                   inputs,
                   cell_states,
                   attended,
                   attention_states,
                   attended_mask,
                   training=None):
    # only one attended sequence (verified in build)
    assert len(attended) == 1
    attended = attended[0]
    attended_mask = attended_mask[0]
    h_cell_tm1 = cell_states[0]

    # compute attention weights
    w = K.repeat(K.dot(h_cell_tm1, self.W_a), K.shape(attended)[1])
    u = K.dot(attended, self.U_a)
    e = K.exp(K.dot(K.tanh(w + u), self.v_a))

    if attended_mask is not None:
        e = e * K.cast(K.expand_dims(attended_mask, -1), K.dtype(e))

    # weighted average of attended
    a = e / K.sum(e, axis=1, keepdims=True)
    c = K.sum(a * attended, axis=1, keepdims=False)

    return c, [c]

The lines below summarize how the attention mechanism is used: an RNNCell is wrapped by the attention mechanism and the attended constants are provided to the RNN:

decoder = RNN(
    cell=DenseAnnotationAttention(
        cell=GRUCell(RECURRENT_UNITS),
        units=DENSE_ATTENTION_UNITS),
    return_sequences=True)
h1 = decoder(y_emb, constants=x_enc)

Related Issues

#11172 (+multiple previous issues and PRs linked from there)

PR Overview

The PR contains a single example script. It is under discussion which parts might make it into the core API.

  • [y] This PR requires new unit tests [y/n] (make sure tests are included)
    TODO: definitely needed if RNNAttentionCell is added to the core API. Tests should also be added to validate the implementation in this example.
  • [?] This PR requires to update the documentation [y/n] (make sure the docs are up-to-date)
  • [y] This PR is backwards compatible [y/n]
  • [n] This PR changes the current API [y/n] (all API changes need to be approved by fchollet)

batch_size=BATCH_SIZE,
epochs=EPOCHS,
validation_data=(
[target_seqs_train[:, :-1], input_seqs_train],
@andhus (Contributor, Author):

use validation data!

@farizrahman4u (Contributor):

@fchollet This is a very neat and thorough PR. Please review and discuss what parts need to be moved into the Keras API and what should stay in the example.

return K.max(K.stack([x_1, x_2], axis=-1), axis=-1, keepdims=False)

h2 = TimeDistributed(Lambda(dense_maxout))(concatenate([h1, y_emb]))
y_pred = TimeDistributed(Dense(target_tokenizer.num_words))(h2)
@andhus (Contributor, Author):

Softmax missing!

@andhus (Contributor, Author) commented on Oct 18, 2018:

Regarding TODO(4) in the docs: this diff clarifies the changes needed to improve the efficiency of the attention mechanism. It is a little less intuitive, which is why I left it out of this PR. It boosts training speed by about 50% (on a MacBook Pro CPU, TensorFlow backend).
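The linked diff is not reproduced here; a common way to obtain this kind of speedup (also hinted at by later review comments) is to precompute the dense projection of the attended sequence once per batch rather than inside every decoder step. A rough, hypothetical sketch using the names from the usage snippet above (the attention cell would have to be adapted to accept the precomputed projection as a second constant):

from keras.layers import Dense, TimeDistributed

# Hypothetical: project the encoder annotations once, outside the recurrent loop,
# instead of computing K.dot(attended, self.U_a) at every decoder step.
u_precomputed = TimeDistributed(
    Dense(DENSE_ATTENTION_UNITS, use_bias=False))(x_enc)
h1 = decoder(y_emb, constants=[x_enc, u_precomputed])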

@fchollet (Collaborator) left a comment:

Thanks for the PR. I think this is a useful example and we can include it. However, it seems quite long. Is there anything you could afford to leave out?

initializers,
regularizers,
constraints)
from keras.engine import (
@fchollet (Collaborator):

Style: don't import from engine (it's an internal factoring module). Instead do:

import keras
from keras import layers

Then use e.g. layers.Dense

[target_seqs_val[:, :-1], input_seqs_val],
target_seqs_val[:, 1:, None]))

# TODO add logic for greedy/beam search generation
@fchollet (Collaborator):

I think stopping the example at model.fit is too restrictive, this should be an end-to-end example showing how to do inference as well (like we do in the other translation example).

@andhus (Contributor, Author):

Totally agree. Beam search will make the example significantly longer and more complex (but it is the most relevant approach). I will add both greedy and beam-search inference to the example and then we can decide.

@farizrahman4u (Contributor):

@fchollet What about moving the RNNAttentionCell class to Keras?

@andhus (Contributor, Author) commented on Dec 27, 2018:

Hi @farizrahman4u, @gabrieldemarmiesse, @lvapeab, @fchollet! I found some time to properly validate this implementation (there were some subtle bugs) and fix all the remaining TODOs. It achieves "decent" performance on the given dataset in an hour on a K80 (the original paper used a 1000x larger dataset and trained for several days).

As discussed, the example is long - but it is also a complete replication of the (quite old but prominent) paper on recurrent attention, including beam-search readout. I think it serves as a good reference. As pointed out before, we can continue the discussion regarding whether some parts should be added to the core API (if/when there is time) and simplify this example accordingly.

@gabrieldemarmiesse (Contributor):

Thanks a lot @andhus for your work. This must have taken a lot of time. I'll take a look at it tomorrow for a first review, and I think @fchollet will also read it when he has more time.

@gabrieldemarmiesse (Contributor) left a comment:

Thanks for the PR. I'm not an expert in RNNs, but I hope I can help make this PR better.

return self.score < other.score

def __gt__(self, other):
return other.score > other.score
@gabrieldemarmiesse (Contributor):

I think there is a mistake here. Maybe return self.score > other.score ?

@andhus (Contributor, Author):

Good catch!

@andhus (Contributor, Author):

...the heap push/pop only uses __lt__, so it won't have affected the results.

elif len(beams_updated) < search_width:
# not full search width
heapq.heappush(beams_updated, new_beam)
elif new_beam.score > beams_updated[0].score:
@gabrieldemarmiesse (Contributor):

Maybe elif new_beam > beams_updated[0] ? Otherwise __gt__ and __lt__ are never used in your example.

@andhus (Contributor, Author):

Good point, I can do this - but they are used anyway, by heapq - that was the main reason for implementing the comparison methods on Beam.

@gabrieldemarmiesse (Contributor):

I see. Thanks for the explanation!
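A minimal standalone sketch of the pruning pattern discussed in this thread (the Beam class here is a simplified stand-in for the one in the example script; prune_beams is a hypothetical helper, not PR code):

import heapq


class Beam(object):
    """Simplified stand-in for the Beam class in the example script."""

    def __init__(self, score, tokens):
        self.score = score
        self.tokens = tokens

    def __lt__(self, other):
        return self.score < other.score

    def __gt__(self, other):
        return self.score > other.score


def prune_beams(candidate_beams, search_width):
    # `beams` is a min-heap ordered by Beam.__lt__, so beams[0] is always the
    # worst kept beam and is the one replaced when a better candidate arrives.
    beams = []
    for new_beam in candidate_beams:
        if len(beams) < search_width:
            heapq.heappush(beams, new_beam)
        elif new_beam > beams[0]:
            heapq.heapreplace(beams, new_beam)
    return sorted(beams, key=lambda b: b.score, reverse=True)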



if __name__ == '__main__':
DATA_DIR = 'data/wmt16_mmt'
@gabrieldemarmiesse (Contributor):

Can I suggest the following:

    from keras.utils.data_utils import get_file
    base_name = 'wmt16_mmt_'
    origin = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/'
    get_file(base_name + 'train', origin=origin + 'training.tar.gz', untar=True)
    get_file(base_name + 'val', origin=origin + 'validation.tar.gz', untar=True)
    tar_file = get_file(base_name + 'test',
                        origin=origin + 'mmt16_task1_test.tar.gz',
                        untar=True)

    DATA_DIR = os.path.dirname(tar_file)

Taking "standalone" a step further.

@andhus (Contributor, Author):

Sure :D

@andhus (Contributor, Author) commented on Dec 28, 2018:

Since the extracted files won't have the base_name, I create a new cache_subdir instead - otherwise files with very generic names (train.en) end up in the datasets dir.

@gabrieldemarmiesse (Contributor):

Good idea, I didn't find a way to make a subdir (I didn't look much into it).
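For illustration, a hypothetical sketch of the cache_subdir approach described above (the subdirectory name is just an example, not necessarily what the PR uses):

import os
from keras.utils.data_utils import get_file

origin = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/'
# Download and extract into a dedicated cache subdirectory so that generically
# named files (train.en, val.de, ...) don't land directly in ~/.keras/datasets.
train_path = get_file('train', origin=origin + 'training.tar.gz',
                      untar=True, cache_subdir='datasets/wmt16_mmt')
DATA_DIR = os.path.dirname(train_path)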

- NOTE that a different dataset (wmt14) is used in [1], which is _orders of
magnitude_ larger than the dataset used here (348M vs 0.35M words). The model
in [1] was trained for 252 hours (!) on a Tesla Quadro K6000, whereas for the
data in this example the model starts to overfit after < 1 hour (15 epochs)
@gabrieldemarmiesse (Contributor):

is < a typo?

@andhus (Contributor, Author):

No, I meant "less than", but it's better to type it out.

@andhus (Contributor, Author) commented on Dec 28, 2018:

@gabrieldemarmiesse This diff: https://github.com/andhus/keras/pull/6/files shows how we can skip 300+ lines by removing the base class AttentionCellWrapper (100 lines is just docs of the base class).

For the standalone example this probably makes sense. The drawback is that it adds more overhead to the implementation of the specific attention mechanism, DenseAnnotationAttention.

The most "pressing" need for (something like) the base class in the core API is that we need to use the private keras.engine.base_layer._collect_previous_mask method to extract the masks of the attended tensor(s). In the AttentionCellWrapper, the masks are extracted and explicitly passed to the attention_call (abstract) method. Or is there another way using only API functionality to extract the masks?
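For concreteness, the private call in question (as named above) is used roughly like this; `attended` stands for the list of attended tensors passed as constants:

# What the wrapper currently has to do to obtain the mask(s) of the attended
# tensor(s); relying on a private helper is the main argument for core-API support.
from keras.engine.base_layer import _collect_previous_mask

attended_mask = _collect_previous_mask(attended)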

@gabrieldemarmiesse (Contributor):

Good question. I don't know much about it. Maybe someone else can give some insight? @farizrahman4u do you know how we can avoid calling the private function?

@andhus (Contributor, Author) commented on Mar 27, 2019:

@gabrieldemarmiesse @farizrahman4u @fchollet I'd love to wrap up this one. It has been reviewed, trained until convergence, and sanity-checked. It adds clear value, as there are currently no attention examples.

The remaining question was whether to get rid of the base class or not. I vote for removing it, i.e. applying this diff https://github.com/andhus/keras/pull/6/files and getting rid of 300 lines. It was never intended for the example (but to standardize and remove boilerplate for attention cell wrappers in general). Given the current speed of progress :) I don't think it is reasonable to expect that this will be added to the core API in the near future (it can always be found in the history of this PR).

For reference (@gabrieldemarmiesse 4 Nov 2018):

On an organisation note (because I can't say I understand very well what is going on), I would suggest to

  1. Add this example to the examples directory, that is, merging this PR since there seems to be a consensus about the quality of this example.
  2. Discuss later, in another PR, what should move out of this example and into the codebase. This is because this step will surely include a rework of the documentation + tests.

I propose doing this in two steps because the time to process a PR is usually an exponential function of the changes.

@rbturnbull commented:

Hi @andhus - thanks so much for your work on this. I'm excited to be able to use this. I tried the standalone example in Keras 2.2.4 with TensorFlow 1.14.0. It died in the K.rnn call at line 2974 of tensorflow_backend.py, where it does:
output = tf.where(tiled_mask_t, output, states[0])
tf.where needs the x and y tensors to be the same shape, but output in the demo is cell_output concatenated with attention_h and states[0] (i.e. [?, 3000]), while states[0] is the cell state from the GRU, which is [?, 1000].

This is only a problem in K.rnn if there is masking, so when I turned off mask_zero in the target sequence embedding the code started to run.

I'm not sure where the breakdown in the logic is happening. Do you have any idea how this could be fixed?

In regards to the code generally, I have a few thoughts:

  • I really like that you have the AttentionCellWrapper class and I recommend that you keep it. It will make it easier to add in other types of attention such as the Multiplicative Attention from Luong's 2015 paper.
  • I think that it would be better to wrap the building of the u tensor into the DenseAnnotationAttention layer. Perhaps you could call the existing class DenseAnnotationAttentionCell and then make a new class called just DenseAnnotationAttention where in the call function you build the u tensor and then return RNN(cell=cell, return_sequences=True). Does that make sense?
  • I'm not sure that the name DenseAnnotationAttention will be clear for users. Maybe BahdanauAttention or AdditiveAttention would be clearer and match how the mechanism is talked about in the literature.
  • Finally, I think it's really important to be able to output the attention weights somehow, because these are often used as a kind of soft alignment (and are often visualized in papers). Do you know how this could be made an optional output?

Again, I'm very impressed by this and it would be great to see this part of Keras proper in the near future!

@JoyceCoder commented:

> [quotes @rbturnbull's comment above in full]

Hi @rbturnbull, I ran into the same problem. I pulled this repo locally and used the tensorflow_backend from this repo, which solved the problem. Maybe it can help you.

return output_texts, output_scores

# Translate first 3 samples from validation data
for input_text, target_text in zip(input_texts_val, target_texts_val)[:3]:
@todd-cook commented on Oct 30, 2019:

This is the only line that breaks when running under PY3; change line 926 to:

for input_text, target_text in list(zip(input_texts_val, target_texts_val))[:3]:

zip is eager in PY2, lazy in PY3
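A tiny illustration of the difference (not from the PR):

pairs = zip(['a', 'b', 'c'], [1, 2, 3])
# PY2: zip returns a list, so pairs[:3] works directly.
# PY3: zip returns a lazy iterator, so pairs[:3] raises TypeError.
first_three = list(pairs)[:3]  # works under both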

@todd-cook left a comment:

This is a great PR @andhus and I hope to see it merged soon so that I can evangelize how one can easily use attention in Keras.

I pulled the branch and ran it successfully using:
Keras==2.3.1
tensorflow_gpu == 2.0.0
Python 3.7

with one small modification: I changed line 926 so that it runs under PY3, where zip is lazily evaluated.

@bertsky commented on Feb 19, 2020:

Thanks @andhus for this outstanding PR! I hope this gets merged soon – it's been over 2 years (counting the earlier issues and PRs leading up to this one).

I am also in favour of keeping the base class in the example, as this would allow making follow-up PRs both for incorporating it into the base API and for adding other attention mechanisms (or features like alignment output/visualization, local attention etc) independently.

There is but one issue which I think should be addressed/fixed: With the current implementation, one cannot make use of Keras' layer sharing with DenseAnnotationAttention. (This is necessary to share the decoder weights when defining separate learning and inference models for the NMT encoder-decoder example.)

The reason is that the constructor of AttentionCellWrapper will assign the given cell directly to the instance, which causes the (inherited) attribute tracker to add it to _layers, and in turn the RNN.trainable_weights property will get to see the cell's weights as well. But AttentionCellWrapper's default implementation of that property already adds them. Hence there will be double references! As a fix, one can use the same trick as in RNN's constructor:

def __init__(self, cell, ...):
    # self.cell = cell  # <- direct assignment triggers attribute tracking (double weight refs)
    self._set_cell(cell)
...
@disable_tracking
def _set_cell(self, cell):
    # assign the wrapped cell without attribute tracking, same trick as in RNN's constructor
    self.cell = cell

@fchollet closed this on Dec 8, 2020.