{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# An example of many-to-one (sequence classification)\n", "\n", "Original experiment from [Hochreiter & Schmidhuber (1997)](www.bioinf.jku.at/publications/older/2604.pdf).\n", "\n", "The goal is to classify sequences.\n", "Elements and targets are represented locally (input vectors with only one non-zero bit).\n", "The sequence starts with an `B`, ends with a `E` (the “trigger symbol”), and otherwise consists of randomly chosen symbols from the set `{a, b, c, d}` except for two elements at positions `t1` and `t2` that are either `X` or `Y`.\n", "For the `DifficultyLevel.HARD` case, the sequence length is randomly chosen between `100` and `110`, `t1` is randomly chosen between `10` and `20`, and `t2` is randomly chosen between `50` and `60`.\n", "There are `4` sequence classes `Q`, `R`, `S`, and `U`, which depend on the temporal order of `X` and `Y`.\n", "\n", "The rules are:\n", "\n", "```\n", "X, X -> Q,\n", "X, Y -> R,\n", "Y, X -> S,\n", "Y, Y -> U.\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Dataset Exploration" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from res.sequential_tasks import TemporalOrderExp6aSequence as QRSU" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create a data generator\n", "example_generator = QRSU.get_predefined_generator(\n", " difficulty_level=QRSU.DifficultyLevel.EASY,\n", " batch_size=32,\n", ")\n", "\n", "example_batch = example_generator[1]\n", "print(f'The return type is a {type(example_batch)} with length {len(example_batch)}.')\n", "print(f'The first item in the tuple is the batch of sequences with shape {example_batch[0].shape}.')\n", "print(f'The first element in the batch of sequences is:\\n {example_batch[0][0, :, :]}')\n", "print(f'The second item in the tuple is the corresponding batch of class labels with shape {example_batch[1].shape}.')\n", "print(f'The first element in the batch of class labels is:\\n {example_batch[1][0, :]}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Decoding the first sequence\n", "sequence_decoded = example_generator.decode_x(example_batch[0][0, :, :])\n", "print(f'The sequence is: {sequence_decoded}')\n", "\n", "# Decoding the class label of the first sequence\n", "class_label_decoded = example_generator.decode_y(example_batch[1][0])\n", "print(f'The class label is: {class_label_decoded}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Defining the Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "# Set the random seed for reproducible results\n", "torch.manual_seed(1)\n", "\n", "class SimpleRNN(nn.Module):\n", "    def __init__(self, input_size, hidden_size, output_size):\n", "        # This just calls the base class constructor\n", "        super().__init__()\n", "        # Neural network layers assigned as attributes of a Module subclass\n", "        # have their parameters registered for training automatically.\n", "        self.rnn = torch.nn.RNN(input_size, hidden_size, nonlinearity='relu', batch_first=True)\n", "        self.linear = torch.nn.Linear(hidden_size, output_size)\n", "\n", "    def forward(self, x):\n", "        # The RNN also returns its hidden state, but we don't use it.\n", "        # While the RNN can take a hidden state as input, it is passed\n", "        # a zero-initialized hidden state by default.\n", "        h = self.rnn(x)[0]\n", "        x = self.linear(h)\n", "        return x\n", "\n", "class SimpleLSTM(nn.Module):\n", "    def __init__(self, input_size, hidden_size, output_size):\n", "        super().__init__()\n", "        self.lstm = torch.nn.LSTM(input_size, hidden_size, batch_first=True)\n", "        self.linear = torch.nn.Linear(hidden_size, output_size)\n", "\n", "    def forward(self, x):\n", "        h = self.lstm(x)[0]\n", "        x = self.linear(h)\n", "        return x\n", "\n", "    def get_states_across_time(self, x):\n", "        # Feed the sequence one time step at a time, collecting the hidden\n", "        # and cell states so they can be visualized later.\n", "        h_c = None\n", "        h_list, c_list = list(), list()\n", "        with torch.no_grad():\n", "            for t in range(x.size(1)):\n", "                h_c = self.lstm(x[:, [t], :], h_c)[1]\n", "                h_list.append(h_c[0])\n", "                c_list.append(h_c[1])\n", "        h = torch.cat(h_list)\n", "        c = torch.cat(c_list)\n", "        return h, c" ] },
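{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check (a sketch with arbitrary sizes, not part of the original pipeline), we can feed a random batch through a freshly initialized `SimpleLSTM` and confirm that it produces one output vector per time step; the training loop below keeps only the last one." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A quick shape check (a sketch): the sizes below are arbitrary.\n", "_model = SimpleLSTM(input_size=8, hidden_size=4, output_size=4)\n", "_x = torch.randn(5, 10, 8)  # (batch, time, features) because batch_first=True\n", "print(_model(_x).shape)     # expected: torch.Size([5, 10, 4])" ] },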
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Defining the Training Loop" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train(model, train_data_gen, criterion, optimizer, device):\n", "    # Set the model to training mode. This will turn on layers that would\n", "    # otherwise behave differently during evaluation, such as dropout.\n", "    model.train()\n", "\n", "    # Store the number of sequences that were classified correctly\n", "    num_correct = 0\n", "\n", "    # Iterate over every batch of sequences. Note that the length of a data generator\n", "    # is defined as the number of batches required to produce a total of roughly 1000\n", "    # sequences given a batch size.\n", "    for batch_idx in range(len(train_data_gen)):\n", "\n", "        # Request a batch of sequences and class labels, convert them into tensors\n", "        # of the correct type, and then send them to the appropriate device.\n", "        data, target = train_data_gen[batch_idx]\n", "        data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device)\n", "\n", "        # Perform the forward pass of the model\n", "        output = model(data)  # Step ①\n", "\n", "        # Pick only the output corresponding to the last sequence element\n", "        # (the input is pre-padded, so the last element is always meaningful).\n", "        output = output[:, -1, :]\n", "\n", "        # Compute the value of the loss for this batch. For loss functions like CrossEntropyLoss,\n", "        # the second argument is expected to be a tensor of class indices rather than\n", "        # one-hot encoded class labels. One approach is to take advantage of the one-hot encoding\n", "        # of the target and call argmax along its second dimension to create a tensor of shape\n", "        # (batch_size,) containing the index of the class label that was hot for each sequence.\n", "        target = target.argmax(dim=1)\n", "\n", "        loss = criterion(output, target)  # Step ②\n", "\n", "        # Clear the gradient buffers of the optimized parameters.\n", "        # Otherwise, gradients from the previous batch would be accumulated.\n", "        optimizer.zero_grad()  # Step ③\n", "\n", "        loss.backward()  # Step ④\n", "\n", "        optimizer.step()  # Step ⑤\n", "\n", "        y_pred = output.argmax(dim=1)\n", "        num_correct += (y_pred == target).sum().item()\n", "\n", "    return num_correct, loss.item()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Defining the Testing Loop" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test(model, test_data_gen, criterion, device):\n", "    # Set the model to evaluation mode. This will turn off layers that would\n", "    # otherwise behave differently during training, such as dropout.\n", "    model.eval()\n", "\n", "    # Store the number of sequences that were classified correctly\n", "    num_correct = 0\n", "\n", "    # A context manager is used to disable gradient calculations during inference\n", "    # to reduce memory usage, as we typically don't need the gradients at this point.\n", "    with torch.no_grad():\n", "        for batch_idx in range(len(test_data_gen)):\n", "            data, target = test_data_gen[batch_idx]\n", "            data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device)\n", "\n", "            output = model(data)\n", "            # Pick only the output corresponding to the last sequence element\n", "            # (the input is pre-padded, so the last element is always meaningful).\n", "            output = output[:, -1, :]\n", "\n", "            target = target.argmax(dim=1)\n", "            loss = criterion(output, target)\n", "\n", "            y_pred = output.argmax(dim=1)\n", "            num_correct += (y_pred == target).sum().item()\n", "\n", "    return num_correct, loss.item()" ] },
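{ "cell_type": "markdown", "metadata": {}, "source": [ "The `argmax` trick used in both loops is worth seeing in isolation. Below is a minimal sketch with made-up values: `CrossEntropyLoss` takes raw logits of shape `(batch, n_classes)` together with a tensor of class indices, not one-hot labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A minimal sketch of the CrossEntropyLoss convention (made-up values).\n", "_criterion = torch.nn.CrossEntropyLoss()\n", "_logits = torch.tensor([[2.0, 0.5, 0.1, 0.2]])  # (batch=1, n_classes=4)\n", "_one_hot = torch.tensor([[1, 0, 0, 0]])         # one-hot class label\n", "_index = _one_hot.argmax(dim=1)                 # tensor([0]), what the loss expects\n", "print(_criterion(_logits, _index))" ] },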
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Putting It All Together" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from res.plot_lib import set_default, plot_state, print_colourbar" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "set_default()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs, verbose=True):\n", "    # Automatically determine the device that PyTorch should use for computation\n", "    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", "\n", "    # Move the model to the device used for training and testing\n", "    model.to(device)\n", "\n", "    # Track the value of the loss function and model accuracy across epochs\n", "    history_train = {'loss': [], 'acc': []}\n", "    history_test = {'loss': [], 'acc': []}\n", "\n", "    for epoch in range(max_epochs):\n", "        # Run the training loop and calculate the accuracy.\n", "        # Remember that the length of a data generator is the number of batches,\n", "        # so we multiply it by the batch size to recover the total number of sequences.\n", "        num_correct, loss = train(model, train_data_gen, criterion, optimizer, device)\n", "        accuracy = float(num_correct) / (len(train_data_gen) * train_data_gen.batch_size) * 100\n", "        history_train['loss'].append(loss)\n", "        history_train['acc'].append(accuracy)\n", "\n", "        # Do the same for the testing loop\n", "        num_correct, loss = test(model, test_data_gen, criterion, device)\n", "        accuracy = float(num_correct) / (len(test_data_gen) * test_data_gen.batch_size) * 100\n", "        history_test['loss'].append(loss)\n", "        history_test['acc'].append(accuracy)\n", "\n", "        if verbose or epoch + 1 == max_epochs:\n", "            print(f'[Epoch {epoch + 1}/{max_epochs}]'\n", "                  f\" loss: {history_train['loss'][-1]:.4f}, acc: {history_train['acc'][-1]:2.2f}%\"\n", "                  f\" - test_loss: {history_test['loss'][-1]:.4f}, test_acc: {history_test['acc'][-1]:2.2f}%\")\n", "\n", "    # Generate diagnostic plots for the loss and accuracy\n", "    fig, axes = plt.subplots(ncols=2, figsize=(9, 4.5))\n", "    for ax, metric in zip(axes, ['loss', 'acc']):\n", "        ax.plot(history_train[metric])\n", "        ax.plot(history_test[metric])\n", "        ax.set_xlabel('epoch', fontsize=12)\n", "        ax.set_ylabel(metric, fontsize=12)\n", "        ax.legend(['Train', 'Test'], loc='best')\n", "    plt.show()\n", "\n", "    return model" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Simple RNN: 10 Epochs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set up the training and test data generators\n", "difficulty = QRSU.DifficultyLevel.EASY\n", "batch_size = 32\n", "train_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "test_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "\n", "# Set up the RNN and training settings\n", "input_size = train_data_gen.n_symbols\n", "hidden_size = 4\n", "output_size = train_data_gen.n_classes\n", "model = SimpleRNN(input_size, hidden_size, output_size)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)\n", "max_epochs = 10\n", "\n", "# Train the model\n", "model = train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for parameter_group in list(model.parameters()):\n", "    print(parameter_group.size())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Simple LSTM: 10 Epochs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set up the training and test data generators\n", "difficulty = QRSU.DifficultyLevel.EASY\n", "batch_size = 32\n", "train_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "test_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "\n", "# Set up the LSTM and training settings\n", "input_size = train_data_gen.n_symbols\n", "hidden_size = 4\n", "output_size = train_data_gen.n_classes\n", "model = SimpleLSTM(input_size, hidden_size, output_size)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)\n", "max_epochs = 10\n", "\n", "# Train the model\n", "model = train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for parameter_group in list(model.parameters()):\n", "    print(parameter_group.size())" ] },
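{ "cell_type": "markdown", "metadata": {}, "source": [ "Comparing the two printouts above: for the same `hidden_size`, the LSTM has four times as many parameters in its recurrent layer as the RNN, because each of its four gates has its own block of the stacked weight matrices. A small sketch of the bookkeeping (it uses the `model` trained just above):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A small sketch of the parameter bookkeeping.\n", "# In nn.LSTM, weight_ih_l0 has shape (4 * hidden_size, input_size) and\n", "# weight_hh_l0 has shape (4 * hidden_size, hidden_size): one block per gate.\n", "total_params = sum(p.numel() for p in model.parameters())\n", "print(f'Total number of trainable parameters: {total_params}')" ] },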
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 8. RNN: Increasing Epochs to 100" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set up the training and test data generators\n", "difficulty = QRSU.DifficultyLevel.EASY\n", "batch_size = 32\n", "train_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "test_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "\n", "# Set up the RNN and training settings\n", "input_size = train_data_gen.n_symbols\n", "hidden_size = 4\n", "output_size = train_data_gen.n_classes\n", "model = SimpleRNN(input_size, hidden_size, output_size)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)\n", "max_epochs = 100\n", "\n", "# Train the model\n", "model = train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs, verbose=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9. LSTM: Increasing Epochs to 100" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set up the training and test data generators\n", "difficulty = QRSU.DifficultyLevel.EASY\n", "batch_size = 32\n", "train_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "test_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "\n", "# Set up the LSTM and training settings\n", "input_size = train_data_gen.n_symbols\n", "hidden_size = 4\n", "output_size = train_data_gen.n_classes\n", "model = SimpleLSTM(input_size, hidden_size, output_size)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)\n", "max_epochs = 100\n", "\n", "# Train the model\n", "model = train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs, verbose=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 10. Model Evaluation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import collections\n", "import random\n", "\n", "def evaluate_model(model, difficulty, seed=9001, verbose=False):\n", "    # Define a dictionary that maps class indices to labels\n", "    class_idx_to_label = {0: 'Q', 1: 'R', 2: 'S', 3: 'U'}\n", "\n", "    # Create a new data generator\n", "    data_generator = QRSU.get_predefined_generator(difficulty, seed=seed)\n", "\n", "    # Track the number of times a class appears\n", "    count_classes = collections.Counter()\n", "\n", "    # Keep correctly classified and misclassified sequences, and their\n", "    # true and predicted class labels, for diagnostic information.\n", "    correct = []\n", "    incorrect = []\n", "\n", "    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", "\n", "    model.eval()\n", "\n", "    with torch.no_grad():\n", "        for batch_idx in range(len(data_generator)):\n", "            # Use the newly created generator, not the global test generator\n", "            data, target = data_generator[batch_idx]\n", "            data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device)\n", "\n", "            data_decoded = data_generator.decode_x_batch(data.cpu().numpy())\n", "            target_decoded = data_generator.decode_y_batch(target.cpu().numpy())\n", "\n", "            output = model(data)\n", "            output = output[:, -1, :]\n", "\n", "            target = target.argmax(dim=1)\n", "            y_pred = output.argmax(dim=1)\n", "            y_pred_decoded = [class_idx_to_label[y.item()] for y in y_pred]\n", "\n", "            count_classes.update(target_decoded)\n", "            for i, (truth, prediction) in enumerate(zip(target_decoded, y_pred_decoded)):\n", "                if truth == prediction:\n", "                    correct.append((data_decoded[i], truth, prediction))\n", "                else:\n", "                    incorrect.append((data_decoded[i], truth, prediction))\n", "\n", "    num_sequences = sum(count_classes.values())\n", "    accuracy = float(len(correct)) / num_sequences * 100\n", "    print(f'The accuracy of the model is measured to be {accuracy:.2f}%.\\n')\n", "\n", "    # Report the accuracy by class\n", "    for label in sorted(count_classes):\n", "        num_correct = sum(1 for _, truth, _ in correct if truth == label)\n", "        print(f'{label}: {num_correct} / {count_classes[label]} correct')\n", "\n", "    # Report some randomly chosen correctly classified sequences for examination\n", "    print('\\nHere are some example sequences:')\n", "    for i in range(10):\n", "        sequence, truth, prediction = correct[random.randrange(len(correct))]\n", "        print(f'{sequence} -> {truth} was labelled {prediction}')\n", "\n", "    # Report misclassified sequences for investigation\n", "    if incorrect:\n", "        if verbose:\n", "            print('\\nThe following sequences were misclassified:')\n", "            for sequence, truth, prediction in incorrect:\n", "                print(f'{sequence} -> {truth} was labelled {prediction}')\n", "    else:\n", "        print('\\nThere were no misclassified sequences.')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "evaluate_model(model, difficulty)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 11. Visualize LSTM\n", "Setting `difficulty` to `MODERATE` and `hidden_size` to 12." ] },
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# For reproducibility\n", "torch.manual_seed(1)\n", "\n", "# Setup the training and test data generators\n", "difficulty = QRSU.DifficultyLevel.MODERATE\n", "batch_size = 32\n", "train_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "test_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "\n", "# Setup the RNN and training settings\n", "input_size = train_data_gen.n_symbols\n", "hidden_size = 12\n", "output_size = train_data_gen.n_classes\n", "model = SimpleLSTM(input_size, hidden_size, output_size)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)\n", "max_epochs = 100\n", "\n", "# Train the model\n", "model = train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs, verbose=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Get hidden (H) and cell (C) batch state given a batch input (X)\n", "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", "model.eval()\n", "with torch.no_grad():\n", " data = test_data_gen[0][0]\n", " X = torch.from_numpy(data).float().to(device)\n", " H_t, C_t = model.get_states_across_time(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"Color range is as follows:\")\n", "print_colourbar()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_state(X.cpu(), C_t, b=9, decoder=test_data_gen.decode_x) # 3, 6, 9" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_state(X.cpu(), H_t, b=9, decoder=test_data_gen.decode_x)" ] } ], "metadata": { "jupytext": { "encoding": "# -*- coding: utf-8 -*-" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 4 }