diff --git a/flair/visual/__init__.py b/flair/visual/__init__.py
index 25affcab43..5027694df3 100644
--- a/flair/visual/__init__.py
+++ b/flair/visual/__init__.py
@@ -1,2 +1,2 @@
-from .manifold import tSNE, visualize, prepare_word_embeddings, prepare_char_embeddings, word_contexts, char_contexts
+from .manifold import Visualizer
 from .activations import Highlighter
\ No newline at end of file
diff --git a/flair/visual/activations.py b/flair/visual/activations.py
index 33740ac81a..39c876d0f2 100644
--- a/flair/visual/activations.py
+++ b/flair/visual/activations.py
@@ -1,7 +1,7 @@
 import numpy
 
 
-class Highlighter:
+class Highlighter(object):
 
     def __init__(self):
         self.color_map = [
diff --git a/flair/visual/manifold.py b/flair/visual/manifold.py
index 6da7b23cec..bd80d7dc3f 100644
--- a/flair/visual/manifold.py
+++ b/flair/visual/manifold.py
@@ -3,112 +3,132 @@
 import numpy
 
 
-def prepare_word_embeddings(embeddings, sentences):
-    X = []
+class _Transform:
+    def __init__(self):
+        pass
 
-    print('computing embeddings')
-    for sentence in tqdm.tqdm(sentences):
-        embeddings.embed(sentence)
+    def fit(self, X):
+        return self.transform.fit_transform(X)
 
-        for i, token in enumerate(sentence):
-            X.append(token.embedding.detach().numpy()[None, :])
 
-    X = numpy.concatenate(X, 0)
+class tSNE(_Transform):
+    def __init__(self):
+        super().__init__()
 
-    return X
+        self.transform = \
+            TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
 
 
-def word_contexts(sentences):
-    contexts = []
+class Visualizer(object):
 
-    for sentence in sentences:
+    def visualize_word_embeddings(self, embeddings, sentences, output_file):
+        X = self.prepare_word_embeddings(embeddings, sentences)
+        contexts = self.word_contexts(sentences)
 
-        strs = [x.text for x in sentence.tokens]
+        trans_ = tSNE()
+        reduced = trans_.fit(X)
 
-        for i, token in enumerate(strs):
-            prop = '<b><font color="red"> {token} </font></b>'.format(
-                token=token)
+        self.visualize(reduced, contexts, output_file)
 
-            prop = ' '.join(strs[max(i - 4, 0):i]) + prop
-            prop = prop + ' '.join(strs[i + 1:min(len(strs), i + 5)])
+    def visualize_char_embeddings(self, embeddings, sentences, output_file):
+        X = self.prepare_char_embeddings(embeddings, sentences)
+        contexts = self.char_contexts(sentences)
 
-            contexts.append('<p>' + prop + '</p>')
+        trans_ = tSNE()
+        reduced = trans_.fit(X)
 
-    return contexts
+        self.visualize(reduced, contexts, output_file)
 
+    @staticmethod
+    def prepare_word_embeddings(embeddings, sentences):
+        X = []
 
-def prepare_char_embeddings(embeddings, sentences):
-    X = []
+        print('computing embeddings')
+        for sentence in tqdm.tqdm(sentences):
+            embeddings.embed(sentence)
 
-    print('computing embeddings')
-    for sentence in tqdm.tqdm(sentences):
+            for i, token in enumerate(sentence):
+                X.append(token.embedding.detach().numpy()[None, :])
 
-        sentence = ' '.join([x.text for x in sentence])
+        X = numpy.concatenate(X, 0)
 
-        hidden = embeddings.lm.get_representation([sentence])
-        X.append(hidden.squeeze().detach().numpy())
+        return X
 
-    X = numpy.concatenate(X, 0)
+    @staticmethod
+    def word_contexts(sentences):
+        contexts = []
 
-    return X
+        for sentence in sentences:
+            strs = [x.text for x in sentence.tokens]
 
-def char_contexts(sentences):
-    contexts = []
+            for i, token in enumerate(strs):
+                prop = '<b><font color="red"> {token} </font></b>'.format(
+                    token=token)
 
-    for sentence in sentences:
-        sentence = ' '.join([token.text for token in sentence])
+                prop = ' '.join(strs[max(i - 4, 0):i]) + prop
+                prop = prop + ' '.join(strs[i + 1:min(len(strs), i + 5)])
 
-        for i, char in enumerate(sentence):
+                contexts.append('<p>' + prop + '</p>')
 
+        return contexts
 
+    @staticmethod
+    def prepare_char_embeddings(embeddings, sentences):
+        X = []
 
-            context = '<span style="background-color: yellow"><b>{}</b></span>'.format(char)
-            context = ''.join(sentence[max(i - 30, 0):i]) + context
-            context = context + ''.join(sentence[i + 1:min(len(sentence), i + 30)])
+        print('computing embeddings')
+        for sentence in tqdm.tqdm(sentences):
+            sentence = ' '.join([x.text for x in sentence])
 
-            contexts.append(context)
+            hidden = embeddings.lm.get_representation([sentence])
+            X.append(hidden.squeeze().detach().numpy())
 
-    return contexts
+        X = numpy.concatenate(X, 0)
 
+        return X
 
-class _Transform:
-    def __init__(self):
-        pass
+    @staticmethod
+    def char_contexts(sentences):
+        contexts = []
 
-    def fit(self, X):
-        return self.transform.fit_transform(X)
+        for sentence in sentences:
+            sentence = ' '.join([token.text for token in sentence])
 
-class tSNE(_Transform):
-    def __init__(self):
+            for i, char in enumerate(sentence):
 
-        super().__init__()
+                context = '<span style="background-color: yellow"><b>{}</b></span>'.format(char)
+                context = ''.join(sentence[max(i - 30, 0):i]) + context
+                context = context + ''.join(sentence[i + 1:min(len(sentence), i + 30)])
 
-        self.transform = \
-            TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
+                contexts.append(context)
 
+        return contexts
 
-def visualize(X, contexts, file):
-    import matplotlib.pyplot
-    import mpld3
+    @staticmethod
+    def visualize(X, contexts, file):
+        import matplotlib.pyplot
+        import mpld3
 
-    fig, ax = matplotlib.pyplot.subplots()
+        fig, ax = matplotlib.pyplot.subplots()
 
-    ax.grid(True, alpha=0.3)
+        ax.grid(True, alpha=0.3)
 
-    points = ax.plot(X[:, 0], X[:, 1], 'o', color='b',
-                     mec='k', ms=5, mew=1, alpha=.6)
+        points = ax.plot(X[:, 0], X[:, 1], 'o', color='b',
+                         mec='k', ms=5, mew=1, alpha=.6)
 
-    ax.set_xlabel('x')
-    ax.set_ylabel('y')
-    ax.set_title('Hover mouse to reveal context', size=20)
+        ax.set_xlabel('x')
+        ax.set_ylabel('y')
+        ax.set_title('Hover mouse to reveal context', size=20)
 
-    tooltip = mpld3.plugins.PointHTMLTooltip(
-        points[0],
-        contexts,
-        voffset=10,
-        hoffset=10
-    )
+        tooltip = mpld3.plugins.PointHTMLTooltip(
+            points[0],
+            contexts,
+            voffset=10,
+            hoffset=10
+        )
 
-    mpld3.plugins.connect(fig, tooltip)
+        mpld3.plugins.connect(fig, tooltip)
 
-    mpld3.save_html(fig, file)
+        mpld3.save_html(fig, file)
diff --git a/requirements.txt b/requirements.txt
index 350fc679cb..3adfd24aab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,8 +3,9 @@ awscli==1.14.32
 gensim==3.4.0
 typing==3.6.4
 pytest==3.6.4
-tqdm
-segtok
-matplotlib
-mpld3
-sklearn
\ No newline at end of file
+tqdm==4.26.0
+segtok==1.5.7
+matplotlib==3.0.0
+mpld3==0.3
+jinja2==2.10
+sklearn==0.0
\ No newline at end of file
diff --git a/tests/test_visual.py b/tests/test_visual.py
index 44f8ca0671..46c3b1eb86 100644
--- a/tests/test_visual.py
+++ b/tests/test_visual.py
@@ -8,48 +8,32 @@ from flair.embeddings import CharLMEmbeddings, StackedEmbeddings
 import numpy
 
+from flair.visual.manifold import Visualizer, tSNE
 from flair.visual.training_curves import Plotter
 
 
-@pytest.mark.skip(reason='Skipping test by default due to long execution time.')
-def test_benchmark():
-    import time
+@pytest.mark.skipif("TRAVIS" in os.environ and os.environ["TRAVIS"] == "true", reason="Skipping this test on Travis CI.")
+def test_visualize_word_embeddings():
 
     with open('./resources/visual/snippet.txt') as f:
         sentences = [x for x in f.read().split('\n') if x]
 
-    sentences = [Sentence(x) for x in sentences[:10]]
+    sentences = [Sentence(x) for x in sentences]
 
     charlm_embedding_forward = CharLMEmbeddings('news-forward')
     charlm_embedding_backward = CharLMEmbeddings('news-backward')
 
-    embeddings = StackedEmbeddings(
-        [charlm_embedding_backward, charlm_embedding_forward]
-    )
-
-    tic = time.time()
-
-    prepare_word_embeddings(embeddings, sentences)
-
-    current_elaped = time.time() - tic
-
-    print('current implementation: {} sec/ sentence'.format(current_elaped / 10))
-
-    embeddings_f = CharLMEmbeddings('news-forward')
-    embeddings_b = CharLMEmbeddings('news-backward')
-
-    tic = time.time()
-
-    prepare_char_embeddings(embeddings_f, sentences)
-    prepare_char_embeddings(embeddings_b, sentences)
+    embeddings = StackedEmbeddings([charlm_embedding_backward, charlm_embedding_forward])
 
-    current_elaped = time.time() - tic
+    visualizer = Visualizer()
+    visualizer.visualize_word_embeddings(embeddings, sentences, './resources/visual/sentence_embeddings.html')
 
-    print('pytorch implementation: {} sec/ sentence'.format(current_elaped / 10))
+    # clean up directory
+    os.remove('./resources/visual/sentence_embeddings.html')
 
 
 @pytest.mark.skipif("TRAVIS" in os.environ and os.environ["TRAVIS"] == "true", reason="Skipping this test on Travis CI.")
-def test_show_word_embeddings():
+def test_visualize_char_embeddings():
 
     with open('./resources/visual/snippet.txt') as f:
         sentences = [x for x in f.read().split('\n') if x]
@@ -57,24 +41,16 @@ def test_show_word_embeddings():
     sentences = [Sentence(x) for x in sentences]
 
     charlm_embedding_forward = CharLMEmbeddings('news-forward')
-    charlm_embedding_backward = CharLMEmbeddings('news-backward')
-
-    embeddings = StackedEmbeddings([charlm_embedding_backward, charlm_embedding_forward])
 
-    X = prepare_word_embeddings(embeddings, sentences)
-    contexts = word_contexts(sentences)
-
-    trans_ = tSNE()
-    reduced = trans_.fit(X)
-
-    visualize(reduced, contexts, './resources/visual/sentence_embeddings.html')
+    visualizer = Visualizer()
+    visualizer.visualize_char_embeddings(charlm_embedding_forward, sentences, './resources/visual/sentence_embeddings.html')
 
     # clean up directory
     os.remove('./resources/visual/sentence_embeddings.html')
 
 
 @pytest.mark.skipif("TRAVIS" in os.environ and os.environ["TRAVIS"] == "true", reason="Skipping this test on Travis CI.")
-def test_show_char_embeddings():
+def test_visualize():
 
     with open('./resources/visual/snippet.txt') as f:
         sentences = [x for x in f.read().split('\n') if x]
@@ -83,50 +59,27 @@
 
     embeddings = CharLMEmbeddings('news-forward')
 
-    X_forward = prepare_char_embeddings(embeddings, sentences)
+    visualizer = Visualizer()
+
+    X_forward = visualizer.prepare_char_embeddings(embeddings, sentences)
 
     embeddings = CharLMEmbeddings('news-backward')
 
-    X_backward = prepare_char_embeddings(embeddings, sentences)
+    X_backward = visualizer.prepare_char_embeddings(embeddings, sentences)
 
     X = numpy.concatenate([X_forward, X_backward], axis=1)
 
-    contexts = char_contexts(sentences)
+    contexts = visualizer.char_contexts(sentences)
 
     trans_ = tSNE()
     reduced = trans_.fit(X)
 
-    visualize(reduced, contexts, './resources/visual/char_embeddings.html')
+    visualizer.visualize(reduced, contexts, './resources/visual/char_embeddings.html')
 
     # clean up directory
     os.remove('./resources/visual/char_embeddings.html')
 
 
-@pytest.mark.skipif("TRAVIS" in os.environ and os.environ["TRAVIS"] == "true", reason="Skipping this test on Travis CI.")
-def test_show_uni_sentence_embeddings():
-
-    with open('./resources/visual/snippet.txt') as f:
-        sentences = [x for x in f.read().split('\n') if x]
-
-    sentences = [Sentence(x) for x in sentences]
-
-    embeddings = CharLMEmbeddings('news-forward')
-
-    X = prepare_char_embeddings(embeddings, sentences)
-
-    trans_ = tSNE()
-    reduced = trans_.fit(X)
-
-    l = len(sentences[0])
-
-    contexts = char_contexts(sentences)
-
-    visualize(reduced[:l], contexts[:l], './resources/visual/uni_sentence_embeddings.html')
-
-    # clean up directory
-    os.remove('./resources/visual/uni_sentence_embeddings.html')
-
-
 def test_highlighter():
     with open('./resources/visual/snippet.txt') as f:
         sentences = [x for x in f.read().split('\n') if x]
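
---

For reviewers, a minimal usage sketch of the `Visualizer` API this diff introduces, mirroring `test_visualize_word_embeddings` above. The CharLM model names and the snippet fixture path are taken from the tests; the output filename is illustrative.

```python
from flair.data import Sentence
from flair.embeddings import CharLMEmbeddings, StackedEmbeddings
from flair.visual.manifold import Visualizer

# One Sentence per non-empty line of the snippet fixture used by the tests.
with open('./resources/visual/snippet.txt') as f:
    sentences = [Sentence(x) for x in f.read().split('\n') if x]

# Stack backward and forward character LM embeddings, as the tests do.
embeddings = StackedEmbeddings([
    CharLMEmbeddings('news-backward'),
    CharLMEmbeddings('news-forward'),
])

# Embeds every token, projects to 2-D with t-SNE, and writes an interactive
# mpld3 HTML plot whose hover tooltips show each token in context.
visualizer = Visualizer()
visualizer.visualize_word_embeddings(embeddings, sentences, 'sentence_embeddings.html')
```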