From e12e41b54dd1b939cf7a45b9add542b29207c97a Mon Sep 17 00:00:00 2001
From: Zhang Jun
Date: Thu, 3 Sep 2020 14:09:02 +0800
Subject: [PATCH] seq2seq with attention added

---
 .../seq2seq_with_attention.ipynb | 583 ++++++++++++++++++
 1 file changed, 583 insertions(+)
 create mode 100644 paddle2.0_docs/seq2seq_with_attention/seq2seq_with_attention.ipynb

diff --git a/paddle2.0_docs/seq2seq_with_attention/seq2seq_with_attention.ipynb b/paddle2.0_docs/seq2seq_with_attention/seq2seq_with_attention.ipynb
new file mode 100644
index 00000000..f92048bc
--- /dev/null
+++ b/paddle2.0_docs/seq2seq_with_attention/seq2seq_with_attention.ipynb
@@ -0,0 +1,583 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.0.0\n",
+      "7f2aa2db3c69cb9ebb8bae9e19280e75f964e1d0\n"
+     ]
+    }
+   ],
+   "source": [
+    "import paddle\n",
+    "import paddle.nn.functional as F\n",
+    "import string\n",
+    "import re\n",
+    "import numpy as np\n",
+    "\n",
+    "paddle.disable_static()\n",
+    "print(paddle.__version__)\n",
+    "print(paddle.__git_commit__)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "--2020-09-03 14:14:55-- https://www.manythings.org/anki/cmn-eng.zip\n",
+      "Resolving www.manythings.org (www.manythings.org)... 104.24.108.196, 104.24.109.196, 172.67.173.198, ...\n",
+      "Connecting to www.manythings.org (www.manythings.org)|104.24.108.196|:443... connected.\n",
+      "HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable\n",
+      "\n",
+      " The file is already fully retrieved; nothing to do.\n",
+      "\n",
+      "Archive: cmn-eng.zip\n"
+     ]
+    }
+   ],
+   "source": [
+    "!wget -c https://www.manythings.org/anki/cmn-eng.zip && unzip -f cmn-eng.zip"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " 23610 cmn.txt\r\n"
+     ]
+    }
+   ],
+   "source": [
+    "!wc -l cmn.txt"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "MAX_LEN = 10"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "5508\n",
+      "(['i', 'won'], ['我', '赢', '了', '。'])\n",
+      "(['he', 'ran'], ['他', '跑', '了', '。'])\n",
+      "(['i', 'quit'], ['我', '退', '出', '。'])\n",
+      "(['i', 'm', 'ok'], ['我', '沒', '事', '。'])\n",
+      "(['i', 'm', 'up'], ['我', '已', '经', '起', '来', '了', '。'])\n",
+      "(['we', 'try'], ['我', '们', '来', '试', '试', '。'])\n",
+      "(['he', 'came'], ['他', '来', '了', '。'])\n",
+      "(['he', 'runs'], ['他', '跑', '。'])\n",
+      "(['i', 'agree'], ['我', '同', '意', '。'])\n",
+      "(['i', 'm', 'ill'], ['我', '生', '病', '了', '。'])\n"
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "lines = open('cmn.txt', encoding='utf-8').read().strip().split('\\n')\n",
+    "words_re = re.compile(r'\\w+')\n",
+    "\n",
+    "pairs = []\n",
+    "for l in lines:\n",
+    "    en_sent, cn_sent, _ = l.split('\\t')\n",
+    "    pairs.append((words_re.findall(en_sent.lower()), list(cn_sent)))\n",
+    "\n",
+    "# create a smaller dataset to make the demo process faster\n",
+    "filtered_pairs = []\n",
+    "\n",
+    "for x in pairs:\n",
+    "    if len(x[0]) < MAX_LEN and len(x[1]) < MAX_LEN and \\\n",
+    "       x[0][0] in ('i', 'you', 'he', 'she', 'we', 'they'):\n",
+    "        filtered_pairs.append(x)\n",
+    "\n",
+    "print(len(filtered_pairs))\n",
+    "for x in filtered_pairs[:10]: print(x) "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Build the vocabularies\n",
+    "\n",
+    "- The English sentences have been lowercased and stripped of punctuation.\n",
+    "- The Chinese sentences are not word-segmented; they are split character by character."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2539\n",
+      "2039\n"
+     ]
+    }
+   ],
+   "source": [
+    "en_vocab = {}\n",
+    "cn_vocab = {}\n",
+    "\n",
+    "# create special tokens for padding, begin of sentence, end of sentence\n",
+    "en_vocab['<pad>'], en_vocab['<bos>'], en_vocab['<eos>'] = 0, 1, 2\n",
+    "cn_vocab['<pad>'], cn_vocab['<bos>'], cn_vocab['<eos>'] = 0, 1, 2\n",
+    "\n",
+    "#print(en_vocab, cn_vocab)\n",
+    "\n",
+    "en_idx, cn_idx = 3, 3\n",
+    "\n",
+    "for en, cn in filtered_pairs:\n",
+    "    for w in en: \n",
+    "        if w not in en_vocab: \n",
+    "            en_vocab[w] = en_idx\n",
+    "            en_idx += 1\n",
+    "    for w in cn: \n",
+    "        if w not in cn_vocab: \n",
+    "            cn_vocab[w] = cn_idx\n",
+    "            cn_idx += 1\n",
+    "\n",
+    "print(len(list(en_vocab)))\n",
+    "print(len(list(cn_vocab)))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Create the padded dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(5508, 10)\n",
+      "(5508, 11)\n",
+      "(5508, 11)\n"
+     ]
+    }
+   ],
+   "source": [
+    "# create padded datasets\n",
+    "padded_en_sents = []\n",
+    "padded_cn_sents = []\n",
+    "padded_cn_label_sents = []\n",
+    "for en, cn in filtered_pairs:\n",
+    "    # pad the source sentence to MAX_LEN and reverse it\n",
+    "    padded_en_sent = en + ['<pad>'] * (MAX_LEN - len(en))\n",
+    "    padded_en_sent.reverse()\n",
+    "    padded_cn_sent = ['<bos>'] + cn + ['<pad>'] * (MAX_LEN - len(cn))\n",
+    "    padded_cn_label_sent = cn + ['<pad>'] * (MAX_LEN - len(cn)) + ['<eos>']\n",
+    "\n",
+    "    padded_en_sents.append([en_vocab[w] for w in padded_en_sent])\n",
+    "    padded_cn_sents.append([cn_vocab[w] for w in padded_cn_sent])\n",
+    "    padded_cn_label_sents.append([cn_vocab[w] for w in padded_cn_label_sent])\n",
+    "\n",
+    "train_en_sents = np.array(padded_en_sents)\n",
+    "train_cn_sents = np.array(padded_cn_sents)\n",
+    "train_cn_label_sents = np.array(padded_cn_label_sents)\n",
+    "\n",
+    "print(train_en_sents.shape)\n",
+    "print(train_cn_sents.shape)\n",
+    "print(train_cn_label_sents.shape)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Build the network"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "embedding_size = 128\n",
+    "hidden_size = 256\n",
+    "num_encoder_lstm_layers = 1\n",
+    "en_vocab_size = len(list(en_vocab))\n",
+    "cn_vocab_size = len(list(cn_vocab))\n",
+    "epochs = 30\n",
+    "batch_size = 16"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# encoder: simply learn representation of source sentence\n",
+    "class Encoder(paddle.nn.Layer):\n",
+    "    def __init__(self):\n",
+    "        super(Encoder, self).__init__()\n",
+    "        self.emb = paddle.nn.Embedding(size=[en_vocab_size, embedding_size],)\n",
+    "        self.lstm = paddle.nn.LSTM(input_size=embedding_size, \n",
+    "                                   hidden_size=hidden_size, \n",
+    "                                   num_layers=num_encoder_lstm_layers,\n",
+    "                                   dropout=0.5)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = self.emb(x)\n",
+    "        x, (_, _) = self.lstm(x)\n",
+    "        return x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# the decoder only moves one step of its LSTM per call; \n",
+    "# the recurrent loop is implemented inside the training loop\n",
+    "class AttentionDecoder(paddle.nn.Layer):\n",
+    "    def __init__(self):\n",
+    "        super(AttentionDecoder, self).__init__()\n",
+    "        self.emb = paddle.nn.Embedding(size=[cn_vocab_size, embedding_size],)\n",
+    "\n",
+    "        # the lstm layer to generate the target sentence representation\n",
+    "        self.lstm = paddle.nn.LSTM(input_size=embedding_size + hidden_size, \n",
+    "                                   hidden_size=hidden_size, \n",
+    "                                   dropout=0.5)\n",
+    "\n",
+    "        # for computing attention weights\n",
+    "        self.attention_linear1 = paddle.nn.Linear(hidden_size * 2, hidden_size)\n",
+    "        self.attention_linear2 = paddle.nn.Linear(hidden_size, 1)\n",
+    "\n",
+    "        # for computing output logits\n",
+    "        self.outlinear = paddle.nn.Linear(hidden_size, cn_vocab_size)\n",
+    "\n",
+    "    def forward(self, x, previous_hidden, previous_cell, encoder_outputs):\n",
+    "        x = self.emb(x)\n",
+    "\n",
+    "        attention_inputs = paddle.concat((encoder_outputs, \n",
+    "                                          paddle.tile(previous_hidden, repeat_times=[1, MAX_LEN, 1])),\n",
+    "                                         axis=-1\n",
+    "                                         )\n",
+    "\n",
+    "        attention_hidden = self.attention_linear1(attention_inputs)\n",
+    "        attention_hidden = F.tanh(attention_hidden)\n",
+    "        attention_logits = self.attention_linear2(attention_hidden)\n",
+    "        attention_logits = paddle.squeeze(attention_logits)\n",
+    "\n",
+    "        attention_weights = F.softmax(attention_logits)\n",
+    "        attention_weights = paddle.expand_as(paddle.unsqueeze(attention_weights, -1), \n",
+    "                                             encoder_outputs)\n",
+    "\n",
+    "        context_vector = paddle.multiply(encoder_outputs, attention_weights)\n",
+    "        context_vector = paddle.reduce_sum(context_vector, 1)\n",
+    "        context_vector = paddle.unsqueeze(context_vector, 1)\n",
+    "\n",
+    "        lstm_input = paddle.concat((x, context_vector), axis=-1)\n",
+    "\n",
+    "        # the LSTM expects states shaped: num_layers * batch * hidden\n",
+    "        previous_hidden = paddle.transpose(previous_hidden, [1, 0, 2])\n",
+    "        previous_cell = paddle.transpose(previous_cell, [1, 0, 2])\n",
+    "\n",
+    "        x, (hidden, cell) = self.lstm(lstm_input, (previous_hidden, previous_cell))\n",
+    "\n",
+    "        # change the returned states back to batch * num_layers * hidden\n",
+    "        hidden = paddle.transpose(hidden, [1, 0, 2])\n",
+    "        cell = paddle.transpose(cell, [1, 0, 2])\n",
+    "\n",
+    "        output = self.outlinear(hidden)\n",
+    "        output = paddle.squeeze(output)\n",
+    "        return output, (hidden, cell)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "epoch:0\n",
+      "iter 0, loss:[7.6200185]\n",
+      "iter 200, loss:[3.4169505]\n",
+      "epoch:1\n",
+      "iter 0, loss:[3.1581175]\n",
+      "iter 200, loss:[3.3032415]\n",
+      "epoch:2\n",
+      "iter 0, loss:[3.0002146]\n",
+      "iter 200, loss:[3.2185385]\n",
+      "epoch:3\n",
+      "iter 0, loss:[2.9653757]\n",
+      "iter 200, loss:[3.098806]\n",
+      "epoch:4\n",
+      "iter 0, loss:[2.6793027]\n",
+      "iter 200, loss:[2.5913079]\n",
+      "epoch:5\n",
+      "iter 0, loss:[2.6999655]\n",
+      "iter 200, loss:[2.379569]\n",
+      "epoch:6\n",
+      "iter 0, loss:[2.5435457]\n",
+      "iter 200, loss:[2.748782]\n",
+      "epoch:7\n",
+      "iter 0, loss:[2.2716467]\n",
+      "iter 200, loss:[2.608843]\n",
+      "epoch:8\n",
+      "iter 0, loss:[2.406243]\n",
+      "iter 200, loss:[2.0575686]\n",
+      "epoch:9\n",
+      "iter 0, loss:[1.8435733]\n",
+      "iter 200, loss:[2.2846653]\n",
+      "epoch:10\n",
+      "iter 0, loss:[1.7847126]\n",
+      "iter 200, loss:[1.9135032]\n",
+      "epoch:11\n",
+      "iter 0, loss:[1.7565953]\n",
+      "iter 200, loss:[1.8443459]\n",
+      "epoch:12\n",
+      "iter 0, loss:[1.4571258]\n",
+      "iter 200, loss:[1.7388061]\n",
+      "epoch:13\n",
+      "iter 0, loss:[1.4517817]\n",
+      "iter 200, loss:[1.6169605]\n",
+      "epoch:14\n",
+      "iter 0, loss:[1.4762214]\n",
+      "iter 200, loss:[1.4081928]\n",
+      "epoch:15\n",
+      "iter 0, loss:[1.5862186]\n",
+      "iter 200, loss:[1.2722157]\n",
+      "epoch:16\n",
+      "iter 0, loss:[1.2187248]\n",
+      "iter 200, loss:[1.4003304]\n",
+      "epoch:17\n",
+      "iter 0, loss:[1.2493218]\n",
+      "iter 200, loss:[1.2339343]\n",
+      "epoch:18\n",
+      "iter 0, loss:[1.1312847]\n",
+      "iter 200, loss:[0.964386]\n",
+      "epoch:19\n",
+      "iter 0, loss:[1.1658673]\n",
+      "iter 200, loss:[0.8626503]\n",
+      "epoch:20\n",
+      "iter 0, loss:[0.82285637]\n",
+      "iter 200, loss:[0.9566538]\n",
+      "epoch:21\n",
+      "iter 0, loss:[0.9608566]\n",
+      "iter 200, loss:[0.9659755]\n",
+      "epoch:22\n",
+      "iter 0, loss:[0.81096345]\n",
+      "iter 200, loss:[0.83643824]\n",
+      "epoch:23\n",
+      "iter 0, loss:[0.7429311]\n",
+      "iter 200, loss:[0.6180183]\n",
+      "epoch:24\n",
+      "iter 0, loss:[0.5948335]\n",
+      "iter 200, loss:[0.67955154]\n",
+      "epoch:25\n",
+      "iter 0, loss:[0.52399546]\n",
+      "iter 200, loss:[0.6174195]\n",
+      "epoch:26\n",
+      "iter 0, loss:[0.47286823]\n",
+      "iter 200, loss:[0.5419927]\n",
+      "epoch:27\n",
+      "iter 0, loss:[0.43044937]\n",
+      "iter 200, loss:[0.5685268]\n",
+      "epoch:28\n",
+      "iter 0, loss:[0.37578955]\n",
+      "iter 200, loss:[0.40272245]\n",
+      "epoch:29\n",
+      "iter 0, loss:[0.3834902]\n",
+      "iter 200, loss:[0.38272512]\n"
+     ]
+    }
+   ],
+   "source": [
+    "encoder = Encoder()\n",
+    "atten_decoder = AttentionDecoder()\n",
+    "\n",
+    "opt = paddle.optimizer.Adam(learning_rate=0.001, \n",
+    "                            parameters=encoder.parameters()+atten_decoder.parameters())\n",
+    "\n",
+    "for epoch in range(epochs):\n",
+    "    print(\"epoch:{}\".format(epoch))\n",
+    "\n",
+    "    # shuffle training data\n",
+    "    perm = np.random.permutation(len(train_en_sents))\n",
+    "    train_en_sents_shuffled = train_en_sents[perm]\n",
+    "    train_cn_sents_shuffled = train_cn_sents[perm]\n",
+    "    train_cn_label_sents_shuffled = train_cn_label_sents[perm]\n",
+    "\n",
+    "    for iteration in range(train_en_sents_shuffled.shape[0] // batch_size):\n",
+    "        x_data = train_en_sents_shuffled[(batch_size*iteration):(batch_size*(iteration+1))]\n",
+    "        sent = paddle.to_tensor(x_data)\n",
+    "        en_repr = encoder(sent)\n",
+    "\n",
+    "        x_cn_data = train_cn_sents_shuffled[(batch_size*iteration):(batch_size*(iteration+1))]\n",
+    "        x_cn_label_data = train_cn_label_sents_shuffled[(batch_size*iteration):(batch_size*(iteration+1))]\n",
+    "\n",
+    "        # batch * num_layer(=1 in this example) * hidden_size\n",
+    "        hidden = paddle.zeros([batch_size, 1, hidden_size])\n",
+    "        cell = paddle.zeros([batch_size, 1, hidden_size])\n",
+    "\n",
+    "        loss = paddle.zeros([1])\n",
+    "        for i in range(MAX_LEN + 1):\n",
+    "            cn_word = paddle.to_tensor(x_cn_data[:,i:i+1])\n",
+    "            cn_word_label = paddle.to_tensor(x_cn_label_data[:,i:i+1])\n",
+    "\n",
+    "            logits, (hidden, cell) = atten_decoder(cn_word, hidden, cell, en_repr)\n",
+    "            step_loss = F.softmax_with_cross_entropy(logits, cn_word_label)\n",
+    "            avg_step_loss = paddle.mean(step_loss)\n",
+    "            loss += avg_step_loss\n",
+    "\n",
+    "        loss = loss / (MAX_LEN + 1)\n",
+    "        if(iteration % 200 == 0):\n",
+    "            print(\"iter {}, loss:{}\".format(iteration, loss.numpy()))\n",
+    "\n",
+    "        loss.backward()\n",
+    "        opt.minimize(loss)\n",
+    "        encoder.clear_gradients()\n",
+    "        atten_decoder.clear_gradients()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Try the model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "encoder.eval()\n",
+    "atten_decoder.eval()\n",
+    "\n",
+    "num_of_examples_to_evaluate = 10\n",
+    "\n",
+    "indices = np.random.choice(len(train_en_sents), num_of_examples_to_evaluate, replace=False)\n",
+    "x_data = train_en_sents[indices]\n",
+    "sent = paddle.to_tensor(x_data)\n",
+    "en_repr = encoder(sent)\n",
+    "\n",
+    "word = np.array(\n",
+    "    [[cn_vocab['<bos>']]] * num_of_examples_to_evaluate\n",
+    ")\n",
+    "word = paddle.to_tensor(word)\n",
+    "\n",
+    "hidden = paddle.zeros([num_of_examples_to_evaluate, 1, hidden_size])\n",
+    "cell = paddle.zeros([num_of_examples_to_evaluate, 1, hidden_size])\n",
+    "\n",
+    "decoded_sent = []\n",
+    "for i in range(MAX_LEN + 1):\n",
+    "    logits, (hidden, cell) = atten_decoder(word, hidden, cell, en_repr)\n",
+    "\n",
+    "    word = paddle.argmax(logits, axis=1)\n",
+    "    decoded_sent.append(word.numpy())\n",
+    "    word = paddle.unsqueeze(word, axis=-1)\n",
+    "\n",
+    "results = np.stack(decoded_sent, axis=1)\n",
+    "for i in range(num_of_examples_to_evaluate):\n",
+    "    en_input = \" \".join(filtered_pairs[indices[i]][0])\n",
+    "    ground_truth_translate = \"\".join(filtered_pairs[indices[i]][1])\n",
+    "    model_translate = \"\"\n",
+    "    for k in results[i]:\n",
+    "        w = list(cn_vocab)[k]\n",
+    "        if w != '<pad>' and w != '<eos>':\n",
+    "            model_translate += w\n",
+    "    print(en_input)\n",
+    "    print(\"true: {}\".format(ground_truth_translate))\n",
+    "    print(\"pred: {}\".format(model_translate))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# The End\n",
+    "\n",
+    "Have fun with Paddle."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}