Skip to content

Commit

Permalink
GH-157: Ad method to generate text.
Browse files Browse the repository at this point in the history
  • Loading branch information
tabergma committed Nov 8, 2018
1 parent 219c63e commit 0e509b6
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions flair/models/language_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import torch.nn as nn
import torch
import math
Expand All @@ -6,6 +7,9 @@
from flair.data import Dictionary


log = logging.getLogger(__name__)


class LanguageModel(nn.Module):
"""Container module with an encoder, a recurrent module, and a decoder."""

Expand Down Expand Up @@ -148,3 +152,29 @@ def save(self, file):
'dropout': self.dropout
}
torch.save(model_state, file, pickle_protocol=4)

def generate_text(self, number_of_characters=1000):
log.info('Generating text ...')

characters = []

idx2item = self.dictionary.idx2item

# initial hidden state
hidden = self.init_hidden(1)
input = torch.rand(1, 1).mul(len(idx2item)).long()

for i in range(number_of_characters):
prediction, rnn_output, hidden = self.forward(input, hidden)
word_weights = prediction.squeeze().data.div(1.0).exp().cpu()
word_idx = torch.multinomial(word_weights, 1)[0]
input.data.fill_(word_idx)
word = idx2item[word_idx].decode('UTF-8')
characters.append(word)

if i % 100 == 0:
log.info(f'{i}/{number_of_characters} chars')

# print generated text
log.info('Generated text:')
log.info(''.join(characters))

0 comments on commit 0e509b6

Please sign in to comment.