From 0e509b64664e242832eaa0cae54e55cbd1c28d9f Mon Sep 17 00:00:00 2001 From: tabergma Date: Thu, 8 Nov 2018 13:34:59 +0100 Subject: [PATCH] GH-157: Ad method to generate text. --- flair/models/language_model.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/flair/models/language_model.py b/flair/models/language_model.py index e5851f3e46..c428524efa 100644 --- a/flair/models/language_model.py +++ b/flair/models/language_model.py @@ -1,3 +1,4 @@ +import logging import torch.nn as nn import torch import math @@ -6,6 +7,9 @@ from flair.data import Dictionary +log = logging.getLogger(__name__) + + class LanguageModel(nn.Module): """Container module with an encoder, a recurrent module, and a decoder.""" @@ -148,3 +152,29 @@ def save(self, file): 'dropout': self.dropout } torch.save(model_state, file, pickle_protocol=4) + + def generate_text(self, number_of_characters=1000): + log.info('Generating text ...') + + characters = [] + + idx2item = self.dictionary.idx2item + + # initial hidden state + hidden = self.init_hidden(1) + input = torch.rand(1, 1).mul(len(idx2item)).long() + + for i in range(number_of_characters): + prediction, rnn_output, hidden = self.forward(input, hidden) + word_weights = prediction.squeeze().data.div(1.0).exp().cpu() + word_idx = torch.multinomial(word_weights, 1)[0] + input.data.fill_(word_idx) + word = idx2item[word_idx].decode('UTF-8') + characters.append(word) + + if i % 100 == 0: + log.info(f'{i}/{number_of_characters} chars') + + # print generated text + log.info('Generated text:') + log.info(''.join(characters))