Skip to content

Commit

Permalink
GH-61: Wrap embedding visualiation into class. Fix requirement issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
tabergma committed Oct 4, 2018
1 parent d86d471 commit fa4285a
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 140 deletions.
2 changes: 1 addition & 1 deletion flair/visual/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .manifold import tSNE, visualize, prepare_word_embeddings, prepare_char_embeddings, word_contexts, char_contexts
from .manifold import Visualizer
from .activations import Highlighter
2 changes: 1 addition & 1 deletion flair/visual/activations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy


class Highlighter:
class Highlighter(object):

def __init__(self):
self.color_map = [
Expand Down
154 changes: 87 additions & 67 deletions flair/visual/manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,112 +3,132 @@
import numpy


def prepare_word_embeddings(embeddings, sentences):
X = []
class _Transform:
def __init__(self):
pass

print('computing embeddings')
for sentence in tqdm.tqdm(sentences):
embeddings.embed(sentence)
def fit(self, X):
return self.transform.fit_transform(X)

for i, token in enumerate(sentence):
X.append(token.embedding.detach().numpy()[None, :])

X = numpy.concatenate(X, 0)
class tSNE(_Transform):
def __init__(self):
super().__init__()

return X
self.transform = \
TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)


def word_contexts(sentences):
contexts = []
class Visualizer(object):

for sentence in sentences:
def visualize_word_emeddings(self, embeddings, sentences, output_file):
X = self.prepare_word_embeddings(embeddings, sentences)
contexts = self.word_contexts(sentences)

strs = [x.text for x in sentence.tokens]
trans_ = tSNE()
reduced = trans_.fit(X)

for i, token in enumerate(strs):
prop = '<b><font color="red"> {token} </font></b>'.format(
token=token)
self.visualize(reduced, contexts, output_file)

prop = ' '.join(strs[max(i - 4, 0):i]) + prop
prop = prop + ' '.join(strs[i + 1:min(len(strs), i + 5)])
def visualize_char_emeddings(self, embeddings, sentences, output_file):
X = self.prepare_char_embeddings(embeddings, sentences)
contexts = self.char_contexts(sentences)

contexts.append('<p>' + prop + '</p>')
trans_ = tSNE()
reduced = trans_.fit(X)

return contexts
self.visualize(reduced, contexts, output_file)

@staticmethod
def prepare_word_embeddings(embeddings, sentences):
X = []

def prepare_char_embeddings(embeddings, sentences):
X = []
print('computing embeddings')
for sentence in tqdm.tqdm(sentences):
embeddings.embed(sentence)

print('computing embeddings')
for sentence in tqdm.tqdm(sentences):
for i, token in enumerate(sentence):
X.append(token.embedding.detach().numpy()[None, :])

sentence = ' '.join([x.text for x in sentence])
X = numpy.concatenate(X, 0)

hidden = embeddings.lm.get_representation([sentence])
X.append(hidden.squeeze().detach().numpy())
return X

X = numpy.concatenate(X, 0)
@staticmethod
def word_contexts(sentences):
contexts = []

return X
for sentence in sentences:

strs = [x.text for x in sentence.tokens]

def char_contexts(sentences):
contexts = []
for i, token in enumerate(strs):
prop = '<b><font color="red"> {token} </font></b>'.format(
token=token)

for sentence in sentences:
sentence = ' '.join([token.text for token in sentence])
prop = ' '.join(strs[max(i - 4, 0):i]) + prop
prop = prop + ' '.join(strs[i + 1:min(len(strs), i + 5)])

for i, char in enumerate(sentence):
contexts.append('<p>' + prop + '</p>')

context = '<span style="background-color: yellow"><b>{}</b></span>'.format(char)
context = ''.join(sentence[max(i - 30, 0):i]) + context
context = context + ''.join(sentence[i + 1:min(len(sentence), i + 30)])
return contexts

contexts.append(context)
@staticmethod
def prepare_char_embeddings(embeddings, sentences):
X = []

return contexts
print('computing embeddings')
for sentence in tqdm.tqdm(sentences):

sentence = ' '.join([x.text for x in sentence])

class _Transform:
def __init__(self):
pass
hidden = embeddings.lm.get_representation([sentence])
X.append(hidden.squeeze().detach().numpy())

def fit(self, X):
return self.transform.fit_transform(X)
X = numpy.concatenate(X, 0)

return X

class tSNE(_Transform):
def __init__(self):
@staticmethod
def char_contexts(sentences):
contexts = []

super().__init__()
for sentence in sentences:
sentence = ' '.join([token.text for token in sentence])

self.transform = \
TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
for i, char in enumerate(sentence):

context = '<span style="background-color: yellow"><b>{}</b></span>'.format(char)
context = ''.join(sentence[max(i - 30, 0):i]) + context
context = context + ''.join(sentence[i + 1:min(len(sentence), i + 30)])

contexts.append(context)

return contexts

def visualize(X, contexts, file):
import matplotlib.pyplot
import mpld3
@staticmethod
def visualize(X, contexts, file):
import matplotlib.pyplot
import mpld3

fig, ax = matplotlib.pyplot.subplots()
fig, ax = matplotlib.pyplot.subplots()

ax.grid(True, alpha=0.3)
ax.grid(True, alpha=0.3)

points = ax.plot(X[:, 0], X[:, 1], 'o', color='b',
mec='k', ms=5, mew=1, alpha=.6)
points = ax.plot(X[:, 0], X[:, 1], 'o', color='b',
mec='k', ms=5, mew=1, alpha=.6)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Hover mouse to reveal context', size=20)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Hover mouse to reveal context', size=20)

tooltip = mpld3.plugins.PointHTMLTooltip(
points[0],
contexts,
voffset=10,
hoffset=10
)
tooltip = mpld3.plugins.PointHTMLTooltip(
points[0],
contexts,
voffset=10,
hoffset=10
)

mpld3.plugins.connect(fig, tooltip)
mpld3.plugins.connect(fig, tooltip)

mpld3.save_html(fig, file)
mpld3.save_html(fig, file)
11 changes: 6 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ awscli==1.14.32
gensim==3.4.0
typing==3.6.4
pytest==3.6.4
tqdm
segtok
matplotlib
mpld3
sklearn
tqdm==4.26.0
segtok==1.5.7
matplotlib==3.0.0
mpld3==0.3
jinja2==2.10
sklearn==0.0
85 changes: 19 additions & 66 deletions tests/test_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,73 +8,49 @@
from flair.embeddings import CharLMEmbeddings, StackedEmbeddings
import numpy

from flair.visual.manifold import Visualizer, tSNE
from flair.visual.training_curves import Plotter


@pytest.mark.skip(reason='Skipping test by default due to long execution time.')
def test_benchmark():
import time
@pytest.mark.skipif("TRAVIS" in os.environ and os.environ["TRAVIS"] == "true", reason="Skipping this test on Travis CI.")
def test_visualize_word_emeddings():

with open('./resources/visual/snippet.txt') as f:
sentences = [x for x in f.read().split('\n') if x]

sentences = [Sentence(x) for x in sentences[:10]]
sentences = [Sentence(x) for x in sentences]

charlm_embedding_forward = CharLMEmbeddings('news-forward')
charlm_embedding_backward = CharLMEmbeddings('news-backward')

embeddings = StackedEmbeddings(
[charlm_embedding_backward, charlm_embedding_forward]
)

tic = time.time()

prepare_word_embeddings(embeddings, sentences)

current_elaped = time.time() - tic

print('current implementation: {} sec/ sentence'.format(current_elaped / 10))

embeddings_f = CharLMEmbeddings('news-forward')
embeddings_b = CharLMEmbeddings('news-backward')

tic = time.time()

prepare_char_embeddings(embeddings_f, sentences)
prepare_char_embeddings(embeddings_b, sentences)
embeddings = StackedEmbeddings([charlm_embedding_backward, charlm_embedding_forward])

current_elaped = time.time() - tic
visualizer = Visualizer()
visualizer.visualize_word_emeddings(embeddings, sentences, './resources/visual/sentence_embeddings.html')

print('pytorch implementation: {} sec/ sentence'.format(current_elaped / 10))
# clean up directory
os.remove('./resources/visual/sentence_embeddings.html')


@pytest.mark.skipif("TRAVIS" in os.environ and os.environ["TRAVIS"] == "true", reason="Skipping this test on Travis CI.")
def test_show_word_embeddings():
def test_visualize_word_emeddings():

with open('./resources/visual/snippet.txt') as f:
sentences = [x for x in f.read().split('\n') if x]

sentences = [Sentence(x) for x in sentences]

charlm_embedding_forward = CharLMEmbeddings('news-forward')
charlm_embedding_backward = CharLMEmbeddings('news-backward')

embeddings = StackedEmbeddings([charlm_embedding_backward, charlm_embedding_forward])

X = prepare_word_embeddings(embeddings, sentences)
contexts = word_contexts(sentences)

trans_ = tSNE()
reduced = trans_.fit(X)

visualize(reduced, contexts, './resources/visual/sentence_embeddings.html')
visualizer = Visualizer()
visualizer.visualize_char_emeddings(charlm_embedding_forward, sentences, './resources/visual/sentence_embeddings.html')

# clean up directory
os.remove('./resources/visual/sentence_embeddings.html')


@pytest.mark.skipif("TRAVIS" in os.environ and os.environ["TRAVIS"] == "true", reason="Skipping this test on Travis CI.")
def test_show_char_embeddings():
def test_visualize():

with open('./resources/visual/snippet.txt') as f:
sentences = [x for x in f.read().split('\n') if x]
Expand All @@ -83,50 +59,27 @@ def test_show_char_embeddings():

embeddings = CharLMEmbeddings('news-forward')

X_forward = prepare_char_embeddings(embeddings, sentences)
visualizer = Visualizer()

X_forward = visualizer.prepare_char_embeddings(embeddings, sentences)

embeddings = CharLMEmbeddings('news-backward')

X_backward = prepare_char_embeddings(embeddings, sentences)
X_backward = visualizer.prepare_char_embeddings(embeddings, sentences)

X = numpy.concatenate([X_forward, X_backward], axis=1)

contexts = char_contexts(sentences)
contexts = visualizer.char_contexts(sentences)

trans_ = tSNE()
reduced = trans_.fit(X)

visualize(reduced, contexts, './resources/visual/char_embeddings.html')
visualizer.visualize(reduced, contexts, './resources/visual/char_embeddings.html')

# clean up directory
os.remove('./resources/visual/char_embeddings.html')


@pytest.mark.skipif("TRAVIS" in os.environ and os.environ["TRAVIS"] == "true", reason="Skipping this test on Travis CI.")
def test_show_uni_sentence_embeddings():

with open('./resources/visual/snippet.txt') as f:
sentences = [x for x in f.read().split('\n') if x]

sentences = [Sentence(x) for x in sentences]

embeddings = CharLMEmbeddings('news-forward')

X = prepare_char_embeddings(embeddings, sentences)

trans_ = tSNE()
reduced = trans_.fit(X)

l = len(sentences[0])

contexts = char_contexts(sentences)

visualize(reduced[:l], contexts[:l], './resources/visual/uni_sentence_embeddings.html')

# clean up directory
os.remove('./resources/visual/uni_sentence_embeddings.html')


def test_highlighter():
with open('./resources/visual/snippet.txt') as f:
sentences = [x for x in f.read().split('\n') if x]
Expand Down

0 comments on commit fa4285a

Please sign in to comment.