diff --git a/beginner_source/text_sentiment_ngrams_tutorial.py b/beginner_source/text_sentiment_ngrams_tutorial.py index d842a05832..67108e6987 100644 --- a/beginner_source/text_sentiment_ngrams_tutorial.py +++ b/beginner_source/text_sentiment_ngrams_tutorial.py @@ -1,66 +1,133 @@ """ -Text Classification with TorchText +Text classification with the torchtext library ================================== -This tutorial shows how to use the text classification datasets -in ``torchtext``, including +In this tutorial, we will show how to use the torchtext library to build the dataset for text classification analysis. Users will have the flexibility to -:: + - Access the raw data as an iterator + - Build a data processing pipeline to convert the raw text strings into ``torch.Tensor`` that can be used to train the model + - Shuffle and iterate the data with `torch.utils.data.DataLoader `__ +""" - - AG_NEWS, - - SogouNews, - - DBpedia, - - YelpReviewPolarity, - - YelpReviewFull, - - YahooAnswers, - - AmazonReviewPolarity, - - AmazonReviewFull -This example shows how to train a supervised learning algorithm for -classification using one of these ``TextClassification`` datasets. +###################################################################### +# Access to the raw dataset iterators +# ----------------------------------- +# +# The torchtext library provides a few raw dataset iterators, which yield the raw text strings. For example, the ``AG_NEWS`` dataset iterators yield the raw data as a tuple of label and text. -Load data with ngrams --------------------- +import torch +from torchtext.datasets import AG_NEWS +train_iter = AG_NEWS(split='train') -A bag of ngrams feature is applied to capture some partial information -about the local word order. In practice, bi-gram or tri-gram are applied -to provide more benefits as word groups than only one word. An example: -:: +###################################################################### +# :: +# +# next(train_iter) +# >>> (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - +# Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green +# again.") +# +# next(train_iter) +# >>> (3, 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private +# investment firm Carlyle Group,\\which has a reputation for making well-timed +# and occasionally\\controversial plays in the defense industry, has quietly +# placed\\its bets on another part of the market.') +# +# next(train_iter) +# >>> (3, "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring +# crude prices plus worries\\about the economy and the outlook for earnings are +# expected to\\hang over the stock market next week during the depth of +# the\\summer doldrums.") +# - "load data with ngrams" - Bi-grams results: "load data", "data with", "with ngrams" - Tri-grams results: "load data with", "data with ngrams" -``TextClassification`` Dataset supports the ngrams method. By setting -ngrams to 2, the example text in the dataset will be a list of single -words plus bi-grams string. +###################################################################### +# Prepare data processing pipelines +# --------------------------------- +# +# We have revisited the very basic components of the torchtext library, including the vocab, word vectors, and tokenizer. Those are the basic data processing building blocks for raw text strings. +# +# Here is an example of typical NLP data processing with the tokenizer and vocabulary. The first step is to build a vocabulary with the raw training dataset. Users can have a customized vocab by setting up arguments in the constructor of the ``Vocab`` class, for example, the minimum frequency ``min_freq`` for a token to be included. -""" -import torch -import torchtext -from torchtext.datasets import text_classification -NGRAMS = 2 -import os -if not os.path.isdir('./.data'): - os.mkdir('./.data') -train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS']( - root='./.data', ngrams=NGRAMS, vocab=None) -BATCH_SIZE = 16 +from torchtext.data.utils import get_tokenizer +from collections import Counter +from torchtext.vocab import Vocab + +tokenizer = get_tokenizer('basic_english') +train_iter = AG_NEWS(split='train') +counter = Counter() +for (label, line) in train_iter: + counter.update(tokenizer(line)) +vocab = Vocab(counter, min_freq=1) + + +###################################################################### +# The vocabulary block converts a list of tokens into integers. +# +# :: +# +# [vocab[token] for token in ['here', 'is', 'an', 'example']] +# >>> [476, 22, 31, 5298] +# +# Prepare the text processing pipeline with the tokenizer and vocabulary. The text and label pipelines will be used to process the raw data strings from the dataset iterators. +text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)] +label_pipeline = lambda x: int(x) - 1 + + +###################################################################### +# The text pipeline converts a text string into a list of integers based on the lookup table defined in the vocabulary. The label pipeline converts the label into integers. For example, +# +# :: +# +# text_pipeline('here is the an example') +# >>> [475, 21, 2, 30, 5286] +# label_pipeline('10') +# >>> 9 +# + + +
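+######################################################################
+# As an optional sanity check (a minimal sketch, not part of the original
+# tutorial pipeline), the two pipelines can be applied to one raw sample drawn
+# from a fresh ``AG_NEWS`` iterator. The exact token ids depend on the
+# vocabulary built above.
+
+sample_label, sample_text = next(iter(AG_NEWS(split='train')))
+print(label_pipeline(sample_label))     # zero-based class index of the sample
+print(text_pipeline(sample_text)[:10])  # ids of the first ten tokens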
+###################################################################### +# Generate data batch and iterator +# -------------------------------- +# +# `torch.utils.data.DataLoader `__ +# is recommended for PyTorch users (a tutorial is `here `__). +# It works with a map-style dataset that implements the ``__getitem__()`` and ``__len__()`` protocols, and represents a map from indices/keys to data samples. It also works with an iterable dataset when the ``shuffle`` argument is set to ``False``. +# +# Before sending them to the model, the ``collate_fn`` function works on a batch of samples generated from ``DataLoader``. The input to ``collate_fn`` is a batch of data with the batch size in ``DataLoader``, and ``collate_fn`` processes them according to the data processing pipelines declared previously. Pay attention here and make sure that ``collate_fn`` is declared as a top-level ``def``. This ensures that the function is available in each worker. +# +# In this example, the text entries in the original data batch input are packed into a list and concatenated as a single tensor for the input of ``nn.EmbeddingBag``. The offsets are a tensor of delimiters representing the beginning index of each individual sequence in the text tensor. The label is a tensor saving the labels of the individual text entries.
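+# Before looking at the full ``collate_fn`` below, here is a minimal sketch
+# (with made-up sequence lengths) of how the offsets are derived, mirroring
+# the computation inside ``collate_batch``:
+#
+# ::
+#
+#     lengths = [4, 2, 3]      # number of tokens in each text entry
+#     offsets = [0] + lengths  # [0, 4, 2, 3]
+#     torch.tensor(offsets[:-1]).cumsum(dim=0)
+#     >>> tensor([0, 4, 6])    # start index of each entry in the packed text tensor
+#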
+ + +from torch.utils.data import DataLoader device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +def collate_batch(batch): + label_list, text_list, offsets = [], [], [0] + for (_label, _text) in batch: + label_list.append(label_pipeline(_label)) + processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64) + text_list.append(processed_text) + offsets.append(processed_text.size(0)) + label_list = torch.tensor(label_list, dtype=torch.int64) + offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) + text_list = torch.cat(text_list) + return label_list.to(device), text_list.to(device), offsets.to(device) + +train_iter = AG_NEWS(split='train') +dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch) + ###################################################################### # Define the model # ---------------- # -# The model is composed of the -# `EmbeddingBag `__ -# layer and the linear layer (see the figure below). ``nn.EmbeddingBag`` -# computes the mean value of a “bag” of embeddings. The text entries here -# have different lengths. ``nn.EmbeddingBag`` requires no padding here -# since the text lengths are saved in offsets. +# The model is composed of the `nn.EmbeddingBag `__ layer plus a linear layer for the classification purpose. ``nn.EmbeddingBag`` with the default mode of "mean" computes the mean value of a “bag” of embeddings. Although the text entries here have different lengths, the ``nn.EmbeddingBag`` module requires no padding here since the text lengths are saved in offsets. # # Additionally, since ``nn.EmbeddingBag`` accumulates the average across # the embeddings on the fly, ``nn.EmbeddingBag`` can enhance the @@ -69,11 +136,12 @@ # .. image:: ../_static/img/text_sentiment_ngrams_model.png # -import torch.nn as nn -import torch.nn.functional as F -class TextSentiment(nn.Module): +from torch import nn + +class TextClassificationModel(nn.Module): + def __init__(self, vocab_size, embed_dim, num_class): - super().__init__() + super(TextClassificationModel, self).__init__() self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True) self.fc = nn.Linear(embed_dim, num_class) self.init_weights() @@ -93,8 +161,7 @@ def forward(self, text, offsets): # Initiate an instance # -------------------- # -# The AG_NEWS dataset has four labels and therefore the number of classes -# is four. +# The ``AG_NEWS`` dataset has four labels and therefore the number of classes is four. # # :: # @@ -103,51 +170,14 @@ def forward(self, text, offsets): # 3 : Business # 4 : Sci/Tec # -# The vocab size is equal to the length of vocab (including single word -# and ngrams). The number of classes is equal to the number of labels, -# which is four in AG_NEWS case. +# We build a model with an embedding dimension of 64. The vocab size is equal to the length of the vocabulary instance. The number of classes is equal to the number of labels, which is four in the ``AG_NEWS`` case. # -VOCAB_SIZE = len(train_dataset.get_vocab()) -EMBED_DIM = 32 -NUN_CLASS = len(train_dataset.get_labels()) -model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device) - - -###################################################################### -# Functions used to generate batch -# -------------------------------- -# - - -###################################################################### -# Since the text entries have different lengths, a custom function -# generate_batch() is used to generate data batches and offsets. The -# function is passed to ``collate_fn`` in ``torch.utils.data.DataLoader``. -# The input to ``collate_fn`` is a list of tensors with the size of -# batch_size, and the ``collate_fn`` function packs them into a -# mini-batch. Pay attention here and make sure that ``collate_fn`` is -# declared as a top level def. This ensures that the function is available -# in each worker. -# -# The text entries in the original data batch input are packed into a list -# and concatenated as a single tensor as the input of ``nn.EmbeddingBag``. -# The offsets is a tensor of delimiters to represent the beginning index -# of the individual sequence in the text tensor. Label is a tensor saving -# the labels of individual text entries. -# -def generate_batch(batch): - label = torch.tensor([entry[0] for entry in batch]) - text = [entry[1] for entry in batch] - offsets = [0] + [len(entry) for entry in text] - # torch.Tensor.cumsum returns the cumulative sum - # of elements in the dimension dim. - # torch.Tensor([1.0, 2.0, 3.0]).cumsum(dim=0) - - offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) - text = torch.cat(text) - return text, offsets, label +train_iter = AG_NEWS(split='train') +num_class = len(set([label for (label, text) in train_iter])) +vocab_size = len(vocab) +emsize = 64 +model = TextClassificationModel(vocab_size, emsize, num_class).to(device)
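+######################################################################
+# As an optional smoke test (a minimal sketch, assuming the ``dataloader``
+# built above with ``collate_batch``), one batch can be pushed through the
+# untrained model to confirm that the output has the shape
+# ``(batch_size, num_class)``.
+
+sample_labels, sample_text, sample_offsets = next(iter(dataloader))
+with torch.no_grad():
+    sample_output = model(sample_text, sample_offsets)
+print(sample_output.shape)  # torch.Size([8, 4]) for batch_size=8 and 4 classes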
###################################################################### @@ -156,144 +186,170 @@ def generate_batch(batch): # -###################################################################### -# `torch.utils.data.DataLoader `__ -# is recommended for PyTorch users, and it makes data loading in parallel -# easily (a tutorial is -# `here `__). -# We use ``DataLoader`` here to load AG_NEWS datasets and send it to the -# model for training/validation. -# - -from torch.utils.data import DataLoader +import time -def train_func(sub_train_): +def train(dataloader): + model.train() + total_acc, total_count = 0, 0 + log_interval = 500 + start_time = time.time() - # Train the model - train_loss = 0 - train_acc = 0 - data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True, - collate_fn=generate_batch) - for i, (text, offsets, cls) in enumerate(data): + for idx, (label, text, offsets) in enumerate(dataloader): optimizer.zero_grad() - text, offsets, cls = text.to(device), offsets.to(device), cls.to(device) - output = model(text, offsets) - loss = criterion(output, cls) - train_loss += loss.item() + predicted_label = model(text, offsets) + loss = criterion(predicted_label, label) loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) optimizer.step() - train_acc += (output.argmax(1) == cls).sum().item() - - # Adjust the learning rate - scheduler.step() + total_acc += (predicted_label.argmax(1) == label).sum().item() + total_count += label.size(0) + if idx % log_interval == 0 and idx > 0: + elapsed = time.time() - start_time + print('| epoch {:3d} | {:5d}/{:5d} batches ' + '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader), + total_acc/total_count)) + total_acc, total_count = 0, 0 + start_time = time.time() + +def evaluate(dataloader): + model.eval() + total_acc, total_count = 0, 0 - return train_loss / len(sub_train_), train_acc / len(sub_train_) - -def test(data_): - loss = 0 - acc = 0 - data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch) - for text, offsets, cls in data: - text, offsets, cls = text.to(device), offsets.to(device), cls.to(device) - with torch.no_grad(): - output = model(text, offsets) - loss = criterion(output, cls) - loss += loss.item() - acc += (output.argmax(1) == cls).sum().item() - - return loss / len(data_), acc / len(data_) + with torch.no_grad(): + for idx, (label, text, offsets) in enumerate(dataloader): + predicted_label = model(text, offsets) + loss = criterion(predicted_label, label) + total_acc += (predicted_label.argmax(1) == label).sum().item() + total_count += label.size(0) + return total_acc/total_count ###################################################################### # Split the dataset and run the model # ----------------------------------- # -# Since the original AG_NEWS has no valid dataset, we split the training +# Since the original ``AG_NEWS`` has no validation dataset, we split the training # dataset into train/valid sets with a split ratio of 0.95 (train) and # 0.05 (valid). Here we use # `torch.utils.data.dataset.random_split `__ # function in PyTorch core library. # # `CrossEntropyLoss `__ -# criterion combines nn.LogSoftmax() and nn.NLLLoss() in a single class. +# criterion combines ``nn.LogSoftmax()`` and ``nn.NLLLoss()`` in a single class. # It is useful when training a classification problem with C classes. # `SGD `__ -# implements stochastic gradient descent method as optimizer. The initial -# learning rate is set to 4.0. +# implements the stochastic gradient descent method as the optimizer. The initial +# learning rate is set to 5.0. # `StepLR `__ # is used here to adjust the learning rate through epochs. # -import time -from torch.utils.data.dataset import random_split -N_EPOCHS = 5 -min_valid_loss = float('inf') - -criterion = torch.nn.CrossEntropyLoss().to(device) -optimizer = torch.optim.SGD(model.parameters(), lr=4.0) -scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9) - -train_len = int(len(train_dataset) * 0.95) -sub_train_, sub_valid_ = \ - random_split(train_dataset, [train_len, len(train_dataset) - train_len]) - -for epoch in range(N_EPOCHS): - - start_time = time.time() - train_loss, train_acc = train_func(sub_train_) - valid_loss, valid_acc = test(sub_valid_) - secs = int(time.time() - start_time) - mins = secs / 60 - secs = secs % 60 - - print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs)) - print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)') - print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)') +from torch.utils.data.dataset import random_split +# Hyperparameters +EPOCHS = 10 # epoch +LR = 5 # learning rate +BATCH_SIZE = 64 # batch size for training + +criterion = torch.nn.CrossEntropyLoss() +optimizer = torch.optim.SGD(model.parameters(), lr=LR) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1) +total_accu = None +train_iter, test_iter = AG_NEWS() +train_dataset = list(train_iter) +test_dataset = list(test_iter) +num_train = int(len(train_dataset) * 0.95) +split_train_, split_valid_ = \ + random_split(train_dataset, [num_train, len(train_dataset) - num_train]) + +train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, + shuffle=True, collate_fn=collate_batch) +valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, + shuffle=True, collate_fn=collate_batch) +test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, + shuffle=True, collate_fn=collate_batch) + +for epoch in range(1, EPOCHS + 1): + epoch_start_time = time.time() + train(train_dataloader) + accu_val = evaluate(valid_dataloader) + if total_accu is not None and total_accu > accu_val: + scheduler.step() + else: + total_accu = accu_val + print('-' * 59) + print('| end of epoch {:3d} | time: 
{:5.2f}s | ' + 'valid accuracy {:8.3f} '.format(epoch, + time.time() - epoch_start_time, + accu_val)) + print('-' * 59) ###################################################################### -# Running the model on GPU with the following information: -# -# Epoch: 1 \| time in 0 minutes, 11 seconds +# Running the model on GPU with the following printout: # # :: # -# Loss: 0.0263(train) | Acc: 84.5%(train) -# Loss: 0.0001(valid) | Acc: 89.0%(valid) -# -# -# Epoch: 2 \| time in 0 minutes, 10 seconds -# -# :: -# -# Loss: 0.0119(train) | Acc: 93.6%(train) -# Loss: 0.0000(valid) | Acc: 89.6%(valid) -# -# -# Epoch: 3 \| time in 0 minutes, 9 seconds -# -# :: -# -# Loss: 0.0069(train) | Acc: 96.4%(train) -# Loss: 0.0000(valid) | Acc: 90.5%(valid) -# -# -# Epoch: 4 \| time in 0 minutes, 11 seconds -# -# :: -# -# Loss: 0.0038(train) | Acc: 98.2%(train) -# Loss: 0.0000(valid) | Acc: 90.4%(valid) -# -# -# Epoch: 5 \| time in 0 minutes, 11 seconds -# -# :: -# -# Loss: 0.0022(train) | Acc: 99.0%(train) -# Loss: 0.0000(valid) | Acc: 91.0%(valid) -# +# | epoch 1 | 500/ 1782 batches | accuracy 0.684 +# | epoch 1 | 1000/ 1782 batches | accuracy 0.852 +# | epoch 1 | 1500/ 1782 batches | accuracy 0.877 +# ----------------------------------------------------------- +# | end of epoch 1 | time: 8.33s | valid accuracy 0.867 +# ----------------------------------------------------------- +# | epoch 2 | 500/ 1782 batches | accuracy 0.895 +# | epoch 2 | 1000/ 1782 batches | accuracy 0.900 +# | epoch 2 | 1500/ 1782 batches | accuracy 0.903 +# ----------------------------------------------------------- +# | end of epoch 2 | time: 8.18s | valid accuracy 0.890 +# ----------------------------------------------------------- +# | epoch 3 | 500/ 1782 batches | accuracy 0.914 +# | epoch 3 | 1000/ 1782 batches | accuracy 0.914 +# | epoch 3 | 1500/ 1782 batches | accuracy 0.916 +# ----------------------------------------------------------- +# | end of epoch 3 | time: 8.20s | valid accuracy 0.897 +# ----------------------------------------------------------- +# | epoch 4 | 500/ 1782 batches | accuracy 0.926 +# | epoch 4 | 1000/ 1782 batches | accuracy 0.924 +# | epoch 4 | 1500/ 1782 batches | accuracy 0.921 +# ----------------------------------------------------------- +# | end of epoch 4 | time: 8.18s | valid accuracy 0.895 +# ----------------------------------------------------------- +# | epoch 5 | 500/ 1782 batches | accuracy 0.938 +# | epoch 5 | 1000/ 1782 batches | accuracy 0.935 +# | epoch 5 | 1500/ 1782 batches | accuracy 0.937 +# ----------------------------------------------------------- +# | end of epoch 5 | time: 8.16s | valid accuracy 0.902 +# ----------------------------------------------------------- +# | epoch 6 | 500/ 1782 batches | accuracy 0.939 +# | epoch 6 | 1000/ 1782 batches | accuracy 0.939 +# | epoch 6 | 1500/ 1782 batches | accuracy 0.938 +# ----------------------------------------------------------- +# | end of epoch 6 | time: 8.16s | valid accuracy 0.906 +# ----------------------------------------------------------- +# | epoch 7 | 500/ 1782 batches | accuracy 0.941 +# | epoch 7 | 1000/ 1782 batches | accuracy 0.939 +# | epoch 7 | 1500/ 1782 batches | accuracy 0.939 +# ----------------------------------------------------------- +# | end of epoch 7 | time: 8.19s | valid accuracy 0.903 +# ----------------------------------------------------------- +# | epoch 8 | 500/ 1782 batches | accuracy 0.942 +# | epoch 8 | 1000/ 1782 batches | accuracy 0.941 +# | epoch 8 | 1500/ 1782 batches | accuracy 0.942 +# 
----------------------------------------------------------- # | end of epoch 8 | time: 8.16s | valid accuracy 0.904 # ----------------------------------------------------------- # | epoch 9 | 500/ 1782 batches | accuracy 0.942 # | epoch 9 | 1000/ 1782 batches | accuracy 0.941 # | epoch 9 | 1500/ 1782 batches | accuracy 0.942 # ----------------------------------------------------------- # | end of epoch 9 | time: 8.16s | valid accuracy 0.904 # ----------------------------------------------------------- # | epoch 10 | 500/ 1782 batches | accuracy 0.940 # | epoch 10 | 1000/ 1782 batches | accuracy 0.942 # | epoch 10 | 1500/ 1782 batches | accuracy 0.942 # ----------------------------------------------------------- # | end of epoch 10 | time: 8.15s | valid accuracy 0.904 # ----------------------------------------------------------- ###################################################################### @@ -301,17 +357,20 @@ def test(data_): # Evaluate the model with test dataset # ------------------------------------ # -print('Checking the results of test dataset...') -test_loss, test_acc = test(test_dataset) -print(f'\tLoss: {test_loss:.4f}(test)\t|\tAcc: {test_acc * 100:.1f}%(test)') ###################################################################### -# Checking the results of test dataset… +# Checking the results of the test dataset… + +print('Checking the results of test dataset.') +accu_test = evaluate(test_dataloader) +print('test accuracy {:8.3f}'.format(accu_test)) + +################################################ # # :: # -# Loss: 0.0237(test) | Acc: 90.5%(test) +# test accuracy 0.906 # @@ -319,25 +378,18 @@ def test(data_): # Test on a random news # --------------------- # -# Use the best model so far and test a golf news. The label information is -# available -# `here `__. +# Use the best model so far and test it on a piece of golf news. # -import re -from torchtext.data.utils import ngrams_iterator -from torchtext.data.utils import get_tokenizer -ag_news_label = {1 : "World", - 2 : "Sports", - 3 : "Business", - 4 : "Sci/Tec"} +ag_news_label = {1: "World", + 2: "Sports", + 3: "Business", + 4: "Sci/Tec"} -def predict(text, model, vocab, ngrams): - tokenizer = get_tokenizer("basic_english") +def predict(text, text_pipeline): with torch.no_grad(): - text = torch.tensor([vocab[token] - for token in ngrams_iterator(tokenizer(text), ngrams)]) + text = torch.tensor(text_pipeline(text)) output = model(text, torch.tensor([0])) return output.argmax(1).item() + 1 @@ -353,17 +405,14 @@ def predict(text, model, vocab, ngrams): was even more impressive considering he’d never played the \ front nine at TPC Southwind." -vocab = train_dataset.get_vocab() model = model.to("cpu") -print("This is a %s news" %ag_news_label[predict(ex_text_str, model, vocab, 2)]) - -###################################################################### -# This is a Sports news -# +print("This is a %s news" %ag_news_label[predict(ex_text_str, text_pipeline)]) ################################################ # # :: # +# This is a Sports news #