{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# An example of many-to-one (sequence classification)\n", "\n", "Original experiment from [Hochreiter & Schmidhuber (1997)](www.bioinf.jku.at/publications/older/2604.pdf).\n", "\n", "The goal is to classify sequences.\n", "Elements and targets are represented locally (input vectors with only one non-zero bit).\n", "The sequence starts with an `B`, ends with a `E` (the “trigger symbol”), and otherwise consists of randomly chosen symbols from the set `{a, b, c, d}` except for two elements at positions `t1` and `t2` that are either `X` or `Y`.\n", "For the `DifficultyLevel.HARD` case, the sequence length is randomly chosen between `100` and `110`, `t1` is randomly chosen between `10` and `20`, and `t2` is randomly chosen between `50` and `60`.\n", "There are `4` sequence classes `Q`, `R`, `S`, and `U`, which depend on the temporal order of `X` and `Y`.\n", "\n", "The rules are:\n", "\n", "```\n", "X, X -> Q,\n", "X, Y -> R,\n", "Y, X -> S,\n", "Y, Y -> U.\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Dataset Exploration" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from res.sequential_tasks import TemporalOrderExp6aSequence as QRSU" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create a data generator\n", "example_generator = QRSU.get_predefined_generator(\n", " difficulty_level=QRSU.DifficultyLevel.EASY,\n", " batch_size=32,\n", ")\n", "\n", "example_batch = example_generator[1]\n", "print(f'The return type is a {type(example_batch)} with length {len(example_batch)}.')\n", "print(f'The first item in the tuple is the batch of sequences with shape {example_batch[0].shape}.')\n", "print(f'The first element in the batch of sequences is:\\n {example_batch[0][0, :, :]}')\n", "print(f'The second item in the tuple is the corresponding batch of class labels with shape {example_batch[1].shape}.')\n", "print(f'The first element in the batch of class labels is:\\n {example_batch[1][0, :]}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Decoding the first sequence\n", "sequence_decoded = example_generator.decode_x(example_batch[0][0, :, :])\n", "print(f'The sequence is: {sequence_decoded}')\n", "\n", "# Decoding the class label of the first sequence\n", "class_label_decoded = example_generator.decode_y(example_batch[1][0])\n", "print(f'The class label is: {class_label_decoded}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Defining the Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "# Set the random seed for reproducible results\n", "torch.manual_seed(1)\n", "\n", "class SimpleRNN(nn.Module):\n", "    def __init__(self, input_size, hidden_size, output_size):\n", "        # This just calls the base class constructor\n", "        super().__init__()\n", "        # Neural network layers assigned as attributes of a Module subclass\n", "        # have their parameters registered for training automatically.\n", "        self.rnn = torch.nn.RNN(input_size, hidden_size, nonlinearity='relu', batch_first=True)\n", "        self.linear = torch.nn.Linear(hidden_size, output_size)\n", "\n", "    def forward(self, x):\n", "        # The RNN also returns its hidden state, but we don't use it.\n", "        # While the RNN can take a hidden state as input, it is passed\n", "        # a zero-initialized hidden state by default.\n", "        h = self.rnn(x)[0]\n", "        x = self.linear(h)\n", "        return x\n", "\n", "class SimpleLSTM(nn.Module):\n", "    def __init__(self, input_size, hidden_size, output_size):\n", "        super().__init__()\n", "        self.lstm = torch.nn.LSTM(input_size, hidden_size, batch_first=True)\n", "        self.linear = torch.nn.Linear(hidden_size, output_size)\n", "\n", "    def forward(self, x):\n", "        h = self.lstm(x)[0]\n", "        x = self.linear(h)\n", "        return x\n", "\n", "    def get_states_across_time(self, x):\n", "        # Feed the sequence one time step at a time, collecting the hidden\n", "        # and cell states so they can be visualized later.\n", "        h_c = None\n", "        h_list, c_list = list(), list()\n", "        with torch.no_grad():\n", "            for t in range(x.size(1)):\n", "                h_c = self.lstm(x[:, [t], :], h_c)[1]\n", "                h_list.append(h_c[0])\n", "                c_list.append(h_c[1])\n", "        h = torch.cat(h_list)\n", "        c = torch.cat(c_list)\n", "        return h, c" ] },
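{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check (a sketch with arbitrary sizes, not part of the original pipeline), we can feed a random batch through a freshly initialized `SimpleLSTM` and confirm that it produces one output vector per time step; the training loop below keeps only the last one." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A quick shape check (a sketch): the sizes below are arbitrary.\n", "_model = SimpleLSTM(input_size=8, hidden_size=4, output_size=4)\n", "_x = torch.randn(5, 10, 8)  # (batch, time, features) because batch_first=True\n", "print(_model(_x).shape)     # expected: torch.Size([5, 10, 4])" ] },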
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Defining the Training Loop" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train(model, train_data_gen, criterion, optimizer, device):\n", "    # Set the model to training mode. This will turn on layers that would\n", "    # otherwise behave differently during evaluation, such as dropout.\n", "    model.train()\n", "\n", "    # Store the number of sequences that were classified correctly\n", "    num_correct = 0\n", "\n", "    # Iterate over every batch of sequences. Note that the length of a data generator\n", "    # is defined as the number of batches required to produce a total of roughly 1000\n", "    # sequences given a batch size.\n", "    for batch_idx in range(len(train_data_gen)):\n", "\n", "        # Request a batch of sequences and class labels, convert them into tensors\n", "        # of the correct type, and then send them to the appropriate device.\n", "        data, target = train_data_gen[batch_idx]\n", "        data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device)\n", "\n", "        # Perform the forward pass of the model\n", "        output = model(data)  # Step ①\n", "\n", "        # Pick only the output corresponding to the last sequence element\n", "        # (the input is pre-padded, so the last element is always meaningful).\n", "        output = output[:, -1, :]\n", "\n", "        # Compute the value of the loss for this batch. For loss functions like CrossEntropyLoss,\n", "        # the second argument is expected to be a tensor of class indices rather than\n", "        # one-hot encoded class labels. One approach is to take advantage of the one-hot encoding\n", "        # of the target and call argmax along its second dimension to create a tensor of shape\n", "        # (batch_size,) containing the index of the class label that was hot for each sequence.\n", "        target = target.argmax(dim=1)\n", "\n", "        loss = criterion(output, target)  # Step ②\n", "\n", "        # Clear the gradient buffers of the optimized parameters.\n", "        # Otherwise, gradients from the previous batch would be accumulated.\n", "        optimizer.zero_grad()  # Step ③\n", "\n", "        loss.backward()  # Step ④\n", "\n", "        optimizer.step()  # Step ⑤\n", "\n", "        y_pred = output.argmax(dim=1)\n", "        num_correct += (y_pred == target).sum().item()\n", "\n", "    return num_correct, loss.item()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Defining the Testing Loop" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test(model, test_data_gen, criterion, device):\n", "    # Set the model to evaluation mode. This will turn off layers that would\n", "    # otherwise behave differently during training, such as dropout.\n", "    model.eval()\n", "\n", "    # Store the number of sequences that were classified correctly\n", "    num_correct = 0\n", "\n", "    # A context manager is used to disable gradient calculations during inference\n", "    # to reduce memory usage, as we typically don't need the gradients at this point.\n", "    with torch.no_grad():\n", "        for batch_idx in range(len(test_data_gen)):\n", "            data, target = test_data_gen[batch_idx]\n", "            data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device)\n", "\n", "            output = model(data)\n", "            # Pick only the output corresponding to the last sequence element\n", "            # (the input is pre-padded, so the last element is always meaningful).\n", "            output = output[:, -1, :]\n", "\n", "            target = target.argmax(dim=1)\n", "            loss = criterion(output, target)\n", "\n", "            y_pred = output.argmax(dim=1)\n", "            num_correct += (y_pred == target).sum().item()\n", "\n", "    return num_correct, loss.item()" ] },
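{ "cell_type": "markdown", "metadata": {}, "source": [ "The `argmax` trick used in both loops is worth seeing in isolation. Below is a minimal sketch with made-up values: `CrossEntropyLoss` takes raw logits of shape `(batch, n_classes)` together with a tensor of class indices, not one-hot labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A minimal sketch of the CrossEntropyLoss convention (made-up values).\n", "_criterion = torch.nn.CrossEntropyLoss()\n", "_logits = torch.tensor([[2.0, 0.5, 0.1, 0.2]])  # (batch=1, n_classes=4)\n", "_one_hot = torch.tensor([[1, 0, 0, 0]])         # one-hot class label\n", "_index = _one_hot.argmax(dim=1)                 # tensor([0]), what the loss expects\n", "print(_criterion(_logits, _index))" ] },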
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Putting It All Together" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from res.plot_lib import set_default, plot_state, print_colourbar" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "set_default()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs, verbose=True):\n", "    # Automatically determine the device that PyTorch should use for computation\n", "    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", "\n", "    # Move the model to the device used for training and testing\n", "    model.to(device)\n", "\n", "    # Track the value of the loss function and model accuracy across epochs\n", "    history_train = {'loss': [], 'acc': []}\n", "    history_test = {'loss': [], 'acc': []}\n", "\n", "    for epoch in range(max_epochs):\n", "        # Run the training loop and calculate the accuracy.\n", "        # Remember that the length of a data generator is the number of batches,\n", "        # so we multiply it by the batch size to recover the total number of sequences.\n", "        num_correct, loss = train(model, train_data_gen, criterion, optimizer, device)\n", "        accuracy = float(num_correct) / (len(train_data_gen) * train_data_gen.batch_size) * 100\n", "        history_train['loss'].append(loss)\n", "        history_train['acc'].append(accuracy)\n", "\n", "        # Do the same for the testing loop\n", "        num_correct, loss = test(model, test_data_gen, criterion, device)\n", "        accuracy = float(num_correct) / (len(test_data_gen) * test_data_gen.batch_size) * 100\n", "        history_test['loss'].append(loss)\n", "        history_test['acc'].append(accuracy)\n", "\n", "        if verbose or epoch + 1 == max_epochs:\n", "            print(f'[Epoch {epoch + 1}/{max_epochs}]'\n", "                  f\" loss: {history_train['loss'][-1]:.4f}, acc: {history_train['acc'][-1]:2.2f}%\"\n", "                  f\" - test_loss: {history_test['loss'][-1]:.4f}, test_acc: {history_test['acc'][-1]:2.2f}%\")\n", "\n", "    # Generate diagnostic plots for the loss and accuracy\n", "    fig, axes = plt.subplots(ncols=2, figsize=(9, 4.5))\n", "    for ax, metric in zip(axes, ['loss', 'acc']):\n", "        ax.plot(history_train[metric])\n", "        ax.plot(history_test[metric])\n", "        ax.set_xlabel('epoch', fontsize=12)\n", "        ax.set_ylabel(metric, fontsize=12)\n", "        ax.legend(['Train', 'Test'], loc='best')\n", "    plt.show()\n", "\n", "    return model" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Simple RNN: 10 Epochs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set up the training and test data generators\n", "difficulty = QRSU.DifficultyLevel.EASY\n", "batch_size = 32\n", "train_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "test_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "\n", "# Set up the RNN and training settings\n", "input_size = train_data_gen.n_symbols\n", "hidden_size = 4\n", "output_size = train_data_gen.n_classes\n", "model = SimpleRNN(input_size, hidden_size, output_size)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)\n", "max_epochs = 10\n", "\n", "# Train the model\n", "model = train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for parameter_group in list(model.parameters()):\n", "    print(parameter_group.size())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Simple LSTM: 10 Epochs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set up the training and test data generators\n", "difficulty = QRSU.DifficultyLevel.EASY\n", "batch_size = 32\n", "train_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "test_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "\n", "# Set up the LSTM and training settings\n", "input_size = train_data_gen.n_symbols\n", "hidden_size = 4\n", "output_size = train_data_gen.n_classes\n", "model = SimpleLSTM(input_size, hidden_size, output_size)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)\n", "max_epochs = 10\n", "\n", "# Train the model\n", "model = train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for parameter_group in list(model.parameters()):\n", "    print(parameter_group.size())" ] },
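{ "cell_type": "markdown", "metadata": {}, "source": [ "Comparing the two printouts above: for the same `hidden_size`, the LSTM has four times as many parameters in its recurrent layer as the RNN, because each of its four gates has its own block of the stacked weight matrices. A small sketch of the bookkeeping (it uses the `model` trained just above):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A small sketch of the parameter bookkeeping.\n", "# In nn.LSTM, weight_ih_l0 has shape (4 * hidden_size, input_size) and\n", "# weight_hh_l0 has shape (4 * hidden_size, hidden_size): one block per gate.\n", "total_params = sum(p.numel() for p in model.parameters())\n", "print(f'Total number of trainable parameters: {total_params}')" ] },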
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 8. RNN: Increasing Epochs to 100" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set up the training and test data generators\n", "difficulty = QRSU.DifficultyLevel.EASY\n", "batch_size = 32\n", "train_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "test_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "\n", "# Set up the RNN and training settings\n", "input_size = train_data_gen.n_symbols\n", "hidden_size = 4\n", "output_size = train_data_gen.n_classes\n", "model = SimpleRNN(input_size, hidden_size, output_size)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)\n", "max_epochs = 100\n", "\n", "# Train the model\n", "model = train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs, verbose=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9. LSTM: Increasing Epochs to 100" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set up the training and test data generators\n", "difficulty = QRSU.DifficultyLevel.EASY\n", "batch_size = 32\n", "train_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "test_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "\n", "# Set up the LSTM and training settings\n", "input_size = train_data_gen.n_symbols\n", "hidden_size = 4\n", "output_size = train_data_gen.n_classes\n", "model = SimpleLSTM(input_size, hidden_size, output_size)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)\n", "max_epochs = 100\n", "\n", "# Train the model\n", "model = train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs, verbose=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 10. Model Evaluation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import collections\n", "import random\n", "\n", "def evaluate_model(model, difficulty, seed=9001, verbose=False):\n", "    # Define a dictionary that maps class indices to labels\n", "    class_idx_to_label = {0: 'Q', 1: 'R', 2: 'S', 3: 'U'}\n", "\n", "    # Create a new data generator\n", "    data_generator = QRSU.get_predefined_generator(difficulty, seed=seed)\n", "\n", "    # Track the number of times a class appears\n", "    count_classes = collections.Counter()\n", "\n", "    # Keep correctly classified and misclassified sequences, and their\n", "    # true and predicted class labels, for diagnostic information.\n", "    correct = []\n", "    incorrect = []\n", "\n", "    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", "\n", "    model.eval()\n", "\n", "    with torch.no_grad():\n", "        for batch_idx in range(len(data_generator)):\n", "            # Use the newly created generator, not the global test generator\n", "            data, target = data_generator[batch_idx]\n", "            data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device)\n", "\n", "            data_decoded = data_generator.decode_x_batch(data.cpu().numpy())\n", "            target_decoded = data_generator.decode_y_batch(target.cpu().numpy())\n", "\n", "            output = model(data)\n", "            output = output[:, -1, :]\n", "\n", "            target = target.argmax(dim=1)\n", "            y_pred = output.argmax(dim=1)\n", "            y_pred_decoded = [class_idx_to_label[y.item()] for y in y_pred]\n", "\n", "            count_classes.update(target_decoded)\n", "            for i, (truth, prediction) in enumerate(zip(target_decoded, y_pred_decoded)):\n", "                if truth == prediction:\n", "                    correct.append((data_decoded[i], truth, prediction))\n", "                else:\n", "                    incorrect.append((data_decoded[i], truth, prediction))\n", "\n", "    num_sequences = sum(count_classes.values())\n", "    accuracy = float(len(correct)) / num_sequences * 100\n", "    print(f'The accuracy of the model is measured to be {accuracy:.2f}%.\\n')\n", "\n", "    # Report the accuracy by class\n", "    for label in sorted(count_classes):\n", "        num_correct = sum(1 for _, truth, _ in correct if truth == label)\n", "        print(f'{label}: {num_correct} / {count_classes[label]} correct')\n", "\n", "    # Report some randomly chosen correctly classified sequences for examination\n", "    print('\\nHere are some example sequences:')\n", "    for i in range(10):\n", "        sequence, truth, prediction = correct[random.randrange(len(correct))]\n", "        print(f'{sequence} -> {truth} was labelled {prediction}')\n", "\n", "    # Report misclassified sequences for investigation\n", "    if incorrect:\n", "        if verbose:\n", "            print('\\nThe following sequences were misclassified:')\n", "            for sequence, truth, prediction in incorrect:\n", "                print(f'{sequence} -> {truth} was labelled {prediction}')\n", "    else:\n", "        print('\\nThere were no misclassified sequences.')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "evaluate_model(model, difficulty)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 11. Visualize LSTM\n", "Setting `difficulty` to `MODERATE` and `hidden_size` to 12." ] },
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# For reproducibility\n", "torch.manual_seed(1)\n", "\n", "# Setup the training and test data generators\n", "difficulty = QRSU.DifficultyLevel.MODERATE\n", "batch_size = 32\n", "train_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "test_data_gen = QRSU.get_predefined_generator(difficulty, batch_size)\n", "\n", "# Setup the RNN and training settings\n", "input_size = train_data_gen.n_symbols\n", "hidden_size = 12\n", "output_size = train_data_gen.n_classes\n", "model = SimpleLSTM(input_size, hidden_size, output_size)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)\n", "max_epochs = 100\n", "\n", "# Train the model\n", "model = train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs, verbose=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Get hidden (H) and cell (C) batch state given a batch input (X)\n", "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", "model.eval()\n", "with torch.no_grad():\n", " data = test_data_gen[0][0]\n", " X = torch.from_numpy(data).float().to(device)\n", " H_t, C_t = model.get_states_across_time(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"Color range is as follows:\")\n", "print_colourbar()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_state(X.cpu(), C_t, b=9, decoder=test_data_gen.decode_x) # 3, 6, 9" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_state(X.cpu(), H_t, b=9, decoder=test_data_gen.decode_x)" ] } ], "metadata": { "jupytext": { "encoding": "# -*- coding: utf-8 -*-" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 4 }