Transformer model based on TensorLayer #1027

Open
wants to merge 31 commits into base: master
Changes from 1 commit
add examples
Lingjun Liu committed Sep 13, 2019
commit 90d536e5959ed38364c0a343eb682d8385a8d8be
5 changes: 5 additions & 0 deletions docs/modules/models.rst
@@ -58,3 +58,8 @@ Seq2seq Luong Attention
------------------------

.. autoclass:: Seq2seqLuongAttention

Transformer
------------------------

.. autoclass:: Transformer
168 changes: 168 additions & 0 deletions examples/translation_task/tutorial_transformer.py
@@ -0,0 +1,168 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow_datasets as tfds
import tensorflow as tf
import time
import numpy as np
import matplotlib.pyplot as plt
from tensorlayer.models.transformer import Transformer
from tensorlayer.models.transformer.utils import metrics
from tensorlayer.models.transformer.utils import attention_visualisation
import tensorlayer as tl


""" Translation from Portugese to English by Transformer model
This tutorial provides basic instructions on how to define and train Transformer model on Tensorlayer for
Translation task. You can also learn how to visualize the attention block via this tutorial.
"""

def set_up_dataset():
    # Set up the Portuguese-English translation dataset from the TED Talks Open Translation Project.
    # It contains approximately 50000 training examples, 1100 validation examples, and 2000 test examples.
    # https://www.ted.com/participate/translate
    examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True, as_supervised=True)
    train_examples, val_examples = examples['train'], examples['validation']

    # Build a single subword tokenizer over both languages (the same vocabulary is later used to
    # encode the Portuguese inputs and the English targets), then save and reload it.
    tokenizer = tfds.features.text.SubwordTextEncoder.build_from_corpus(
        (s.numpy() for pt, en in train_examples for s in (pt, en)), target_vocab_size=2**14
    )

    tokenizer.save_to_file("tokenizer")
    tokenizer = tfds.features.text.SubwordTextEncoder.load_from_file("tokenizer")

    return tokenizer, train_examples


def test_tokenizer_success(tokenizer):
    sample_string = 'TensorLayer is awesome.'

    tokenized_string = tokenizer.encode(sample_string)
    print('Tokenized string is {}'.format(tokenized_string))

    original_string = tokenizer.decode(tokenized_string)
    print('The original string: {}'.format(original_string))
    assert original_string == sample_string



def generate_training_dataset(train_examples, tokenizer):
    def encode(lang1, lang2):
        # Append the end-of-sequence id (tokenizer.vocab_size + 1) to both source and target sentences.
        lang1 = tokenizer.encode(lang1.numpy()) + [tokenizer.vocab_size + 1]
        lang2 = tokenizer.encode(lang2.numpy()) + [tokenizer.vocab_size + 1]
        return lang1, lang2

    MAX_LENGTH = 50

    def filter_max_length(x, y, max_length=MAX_LENGTH):
        return tf.logical_and(tf.size(x) <= max_length, tf.size(y) <= max_length)

    def tf_encode(pt, en):
        return tf.py_function(encode, [pt, en], [tf.int64, tf.int64])

    train_dataset = train_examples.map(tf_encode)
    train_dataset = train_dataset.filter(filter_max_length)
    # Cache the dataset in memory to get a speedup while reading from it.
    train_dataset = train_dataset.cache()
    BUFFER_SIZE = 20000
    BATCH_SIZE = 64
    train_dataset = train_dataset.shuffle(BUFFER_SIZE).padded_batch(BATCH_SIZE, padded_shapes=([-1], [-1]))
    train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return train_dataset




def model_setup(tokenizer):
    # Define the hyper-parameters for the Transformer model.
    class HYPER_PARAMS(object):
        vocab_size = tokenizer.vocab_size + 10
        encoder_num_layers = 4
        decoder_num_layers = 4
        hidden_size = 128
        ff_size = 512
        num_heads = 8
        keep_prob = 0.9

        # Default prediction params
        extra_decode_length = 50
        beam_size = 5
        alpha = 0.6  # used to calculate length normalization in beam search

        # Default training params
        label_smoothing = 0.1
        learning_rate = 2.0
        learning_rate_decay_rate = 1.0
        learning_rate_warmup_steps = 4000

        # Start- and end-of-sequence token ids used by the decoder
        sos_id = 0
        eos_id = tokenizer.vocab_size + 1

    model = Transformer(HYPER_PARAMS)

    # Set the optimizer: LazyAdam with the warmup learning rate schedule defined below.
    learning_rate = CustomSchedule(HYPER_PARAMS.hidden_size, warmup_steps=HYPER_PARAMS.learning_rate_warmup_steps)
    optimizer = tl.optimizers.LazyAdamOptimizer(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    return model, optimizer, HYPER_PARAMS


# Custom learning rate schedule for the Adam optimizer, following the formula in the paper
# "Attention Is All You Need": lrate = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

    def __init__(self, d_model, warmup_steps=5):
        super(CustomSchedule, self).__init__()

        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
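

# An illustrative sketch (not part of the original script): a small helper to inspect the
# schedule above on its own. The helper name and the sample step values are assumptions for
# demonstration; the printed rate rises linearly until `warmup_steps` and then decays with
# the inverse square root of the step number.
def _preview_custom_schedule(d_model=128, warmup_steps=4000):
    schedule = CustomSchedule(d_model, warmup_steps=warmup_steps)
    for step in [1.0, 100.0, 4000.0, 20000.0]:
        print("step {:>7.0f}: learning rate {:.6f}".format(step, float(schedule(tf.constant(step)))))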



def tutorial_transformer():
    tokenizer, train_examples = set_up_dataset()
    train_dataset = generate_training_dataset(train_examples, tokenizer)
    model, optimizer, HYPER_PARAMS = model_setup(tokenizer)

    num_epochs = 10
    for epoch in range(num_epochs):
        model.train()
        for (batch, (inp, tar)) in enumerate(train_dataset):
            with tf.GradientTape() as tape:
                logits, weights_encoder, weights_decoder = model(inputs=inp, targets=tar)
                logits = metrics.MetricLayer(HYPER_PARAMS.vocab_size)([logits, tar])
                logits, loss = metrics.LossLayer(HYPER_PARAMS.vocab_size, 0.1)([logits, tar])
            grad = tape.gradient(loss, model.all_weights)
            optimizer.apply_gradients(zip(grad, model.all_weights))
            if batch % 50 == 0:
                print('Batch ID {} at Epoch [{}/{}]: loss {:.4f}'.format(batch, epoch + 1, num_epochs, float(loss)))

    # Run inference on a sample sentence after training.
    model.eval()
    sentence_en = tokenizer.encode('TensorLayer is awesome.')
    [prediction, weights_decoder], weights_encoder = model(inputs=[sentence_en])

    predicted_sentence = tokenizer.decode([i for i in prediction["outputs"][0] if i < tokenizer.vocab_size])
    print("Translated: ", predicted_sentence)

    # Visualize the self-attention weights of the first encoder layer.
    tokenizer_str = [tokenizer.decode([ts]) for ts in sentence_en]
    attention_visualisation.plot_attention_weights(weights_encoder["layer_0"], tokenizer_str, tokenizer_str)


if __name__ == "__main__":
    tutorial_transformer()
10 changes: 5 additions & 5 deletions tensorlayer/models/transformer/beamsearchHelper/beam_search.py
@@ -39,11 +39,11 @@ def search(self, initial_ids, initial_cache):
finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]

# Account for corner case where there are no finished sequences for a
# particular batch item. In that case, return alive sequences for that batch
# item.
finished_seq = tf.where(tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
finished_scores = tf.where(tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
# # Account for corner case where there are no finished sequences for a
# # particular batch item. In that case, return alive sequences for that batch
# # item.
# finished_seq = tf.where(tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
# finished_scores = tf.where(tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
return finished_seq, finished_scores


4 changes: 2 additions & 2 deletions tensorlayer/models/transformer/feedforward_layer.py
@@ -22,7 +22,7 @@
import tensorlayer as tl


class FeedForwardLayer(tl.layers.Layer):
class TransformerFeedForwardLayer(tl.layers.Layer):
"""Fully connected feedforward network."""

def __init__(self, hidden_size, filter_size, keep_prob):
@@ -33,7 +33,7 @@ def __init__(self, hidden_size, filter_size, keep_prob):
filter_size: int, filter size for the inner (first) dense layer.
relu_dropout: float, dropout rate for training.
"""
super(FeedForwardLayer, self).__init__()
super(TransformerFeedForwardLayer, self).__init__()
self.hidden_size = hidden_size
self.filter_size = filter_size
self.relu_dropout = 1 - keep_prob
14 changes: 8 additions & 6 deletions tensorlayer/models/transformer/transformer.py
@@ -26,7 +26,7 @@
from tensorlayer.models import Model
import tensorlayer.models.transformer.embedding_layer as embedding_layer
from tensorlayer.models.transformer.attention_layer import SelfAttentionLayer, MultiHeadAttentionLayer
from tensorlayer.models.transformer.feedforward_layer import FeedForwardLayer
from tensorlayer.models.transformer.feedforward_layer import TransformerFeedForwardLayer
from tensorlayer.models.transformer.utils.model_utils import positional_encoding
from tensorlayer.models.transformer.utils.model_utils import get_decoder_self_attention_bias as get_target_mask
from tensorlayer.models.transformer.utils.model_utils import get_padding_bias as get_input_mask
@@ -56,6 +56,8 @@ class Transformer(Model):
>>> extra_decode_length = 5
>>> beam_size = 1
>>> alpha = 0.6
>>> eos_id = 1
>>> sos_id = 0
>>> model = Transformer(TINY_PARAMS)

Returns
@@ -224,7 +226,7 @@ def decode(self, targets, encoder_outputs, attention_bias):
decoder_inputs = self.embedding_softmax_layer(targets)
with tf.name_scope("shift_targets"):
# Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]], constant_values=self.params.sos_id)[:, :-1, :]
with tf.name_scope("add_pos_encoding"):
length = tf.shape(decoder_inputs)[1]
decoder_inputs += positional_encoding(length, self.params.hidden_size)
@@ -294,7 +296,7 @@ def predict(self, encoder_outputs, encoder_decoder_attention_bias):
symbols_to_logits_fn, weights = self._get_symbols_to_logits_fn(max_decode_length)

# Create initial set of IDs that will be passed into symbols_to_logits_fn.
initial_ids = tf.zeros([batch_size], dtype=tf.int32)
initial_ids = tf.ones([batch_size], dtype=tf.int32)*self.params.sos_id

# Create cache storing decoder attention values for each layer.
# pylint: disable=g-complex-comprehension
@@ -314,7 +316,7 @@ def predict(self, encoder_outputs, encoder_decoder_attention_bias):
decoded_ids, scores = beam_search.sequence_beam_search(
symbols_to_logits_fn=symbols_to_logits_fn, initial_ids=initial_ids, initial_cache=cache,
vocab_size=self.params.vocab_size, beam_size=self.params.beam_size, alpha=self.params.alpha,
max_decode_length=max_decode_length, eos_id=1
max_decode_length=max_decode_length, eos_id=self.params.eos_id
)

# Get the top sequence for each batch element
@@ -421,7 +423,7 @@ def __init__(self, params):
for _ in range(params.encoder_num_layers):
# Create sublayers for each layer.
self_attention_layer = SelfAttentionLayer(params.num_heads, params.hidden_size, params.keep_prob)
feed_forward_network = FeedForwardLayer(params.hidden_size, params.ff_size, params.keep_prob)
feed_forward_network = TransformerFeedForwardLayer(params.hidden_size, params.ff_size, params.keep_prob)

self.layers.append(
[
@@ -488,7 +490,7 @@ def __init__(self, params):
for _ in range(params.decoder_num_layers):
self_attention_layer = SelfAttentionLayer(params.num_heads, params.hidden_size, params.keep_prob)
enc_dec_attention_layer = MultiHeadAttentionLayer(params.num_heads, params.hidden_size, params.keep_prob)
feed_forward_network = FeedForwardLayer(params.hidden_size, params.ff_size, params.keep_prob)
feed_forward_network = TransformerFeedForwardLayer(params.hidden_size, params.ff_size, params.keep_prob)

self.layers.append(
[
tensorlayer/models/transformer/utils/attention_visualisation.py
@@ -19,7 +19,6 @@ def plot_attention_weights(attention, key, query):
'''

fig = plt.figure(figsize=(16, 8))

attention = tf.squeeze(attention, axis=0)

for head in range(attention.shape[0]):
1 change: 1 addition & 0 deletions tensorlayer/optimizers/__init__.py
@@ -10,3 +10,4 @@
"""

from .amsgrad import AMSGrad
from .lazy_adam import LazyAdamOptimizer
76 changes: 76 additions & 0 deletions tensorlayer/optimizers/lazy_adam.py
@@ -0,0 +1,76 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer from addons and learning rate scheduler."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


class LazyAdamOptimizer(tf.optimizers.Adam):
    """Variant of the Adam optimizer that handles sparse updates more efficiently.

    The original Adam algorithm maintains two moving-average accumulators for
    each trainable variable; the accumulators are updated at every step.
    This class provides lazier handling of gradient updates for sparse
    variables. It only updates moving-average accumulators for sparse variable
    indices that appear in the current batch, rather than updating the
    accumulators for all indices. Compared with the original Adam optimizer,
    it can provide large improvements in model training throughput for some
    applications. However, it provides slightly different semantics than the
    original Adam algorithm, and may lead to different empirical results.
    Note, amsgrad is currently not supported and the argument can only be
    False.

    This class is borrowed from:
    https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/lazy_adam.py
    """

    def _resource_apply_sparse(self, grad, var, indices):
        """Applies grad for one step."""
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        beta_1_t = self._get_hyper('beta_1', var_dtype)
        beta_2_t = self._get_hyper('beta_2', var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_power = tf.math.pow(beta_1_t, local_step)
        beta_2_power = tf.math.pow(beta_2_t, local_step)
        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
        lr = (lr_t * tf.math.sqrt(1 - beta_2_power) / (1 - beta_1_power))

        # \\(m := beta1 * m + (1 - beta1) * g_t\\)
        m = self.get_slot(var, 'm')
        m_t_slice = beta_1_t * tf.gather(m, indices) + (1 - beta_1_t) * grad

        m_update_kwargs = {'resource': m.handle, 'indices': indices, 'updates': m_t_slice}
        m_update_op = tf.raw_ops.ResourceScatterUpdate(**m_update_kwargs)

        # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
        v = self.get_slot(var, 'v')
        v_t_slice = (beta_2_t * tf.gather(v, indices) + (1 - beta_2_t) * tf.math.square(grad))

        v_update_kwargs = {'resource': v.handle, 'indices': indices, 'updates': v_t_slice}
        v_update_op = tf.raw_ops.ResourceScatterUpdate(**v_update_kwargs)

        # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
        var_slice = lr * m_t_slice / (tf.math.sqrt(v_t_slice) + epsilon_t)

        var_update_kwargs = {'resource': var.handle, 'indices': indices, 'updates': var_slice}
        var_update_op = tf.raw_ops.ResourceScatterSub(**var_update_kwargs)

        return tf.group(*[var_update_op, m_update_op, v_update_op])
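
A minimal usage sketch (not part of this diff) of the optimizer above, assuming it is exposed as tl.optimizers.LazyAdamOptimizer via the __init__.py change earlier in this PR; the embedding table, ids, and targets below are illustrative only. Sparse gradients, such as those produced by an embedding lookup, are exactly the case where the lazy update path applies:

import tensorflow as tf
import tensorlayer as tl

# A toy setup: an embedding lookup yields sparse (IndexedSlices) gradients for the embedding table.
embedding = tf.Variable(tf.random.normal([1000, 64]), name="embedding")
ids = tf.constant([[3, 17, 256]])
targets = tf.random.normal([1, 3, 64])

optimizer = tl.optimizers.LazyAdamOptimizer(learning_rate=1e-3, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

with tf.GradientTape() as tape:
    embedded = tf.nn.embedding_lookup(embedding, ids)  # gradient w.r.t. `embedding` is sparse
    loss = tf.reduce_mean(tf.square(embedded - targets))

grads = tape.gradient(loss, [embedding])
optimizer.apply_gradients(zip(grads, [embedding]))  # only rows 3, 17 and 256 of the m/v slots are updated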