Skip to content

Commit

Permalink
Implement Analogy and Similarity Queries.
Browse files Browse the repository at this point in the history
  • Loading branch information
sebpuetz committed Jun 2, 2020
1 parent 4dbfb50 commit 79be265
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 11 deletions.
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def run(self):
["src/finalfusion/subword/explicit_indexer.c"])
extensions = [hash_indexers, ngrams, explicit_indexer]

install_requires = ["numpy", "toml"]
if sys.version_info.major == 3 and sys.version_info.minor == 6:
install_requires.append("dataclasses")

setup(name='finalfusion',
author="Sebastian Pütz <[email protected]>, Daniël de Kok <[email protected]>",
classifiers=[
Expand All @@ -81,7 +85,7 @@ def run(self):
cmdclass={'build_ext': cython_build_ext},
description="Interface to finalfusion embeddings",
ext_modules=extensions,
install_requires=["numpy", "toml"],
install_requires=install_requires,
license='BlueOak-1.0.0',
packages=find_packages('src'),
include_package_data=True,
Expand Down
147 changes: 146 additions & 1 deletion src/finalfusion/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""
Finalfusion Embeddings
"""
import heapq
from dataclasses import field, dataclass
from os import PathLike
from typing import Optional, Tuple, List, Union, Any, Iterator
from typing import Optional, Tuple, List, Union, Any, Iterator, Set

import numpy as np

Expand Down Expand Up @@ -418,6 +420,120 @@ def bucket_to_explicit(self) -> 'Embeddings':
storage=NdArray(storage),
norms=self.norms)

def analogy( # pylint: disable=too-many-arguments
self,
word1: str,
word2: str,
word3: str,
k: int = 1,
skip: Set[str] = None) -> Optional[List['SimilarityResult']]:
"""
Perform an analogy query.
This method returns words that are close in vector space the analogy
query `word1` is to `word2` as `word3` is to `?`. More concretely,
it searches embeddings that are similar to:
*embedding(word2) - embedding(word1) + embedding(word3)*
Words specified in ``skip`` are not considered as answers. If ``skip``
is None, the query words ``word1``, ``word2`` and ``word3`` are
excluded.
At most, ``k`` results are returned. ``None`` is returned when no
embedding could be computed for any of the tokens.
Parameters
----------
word1 : str
Word1 is to...
word2 : str
word2 like...
word3 : str
word3 is to the return value
skip : Set[str]
Set of strings which should not be considered as answers. Defaults
to ``None`` which excludes the query strings. To allow the query
strings as answers, pass an empty set.
k : int
Number of answers to return, defaults to 1.
Returns
-------
answers : List[SimilarityResult]
List of answers.
"""
embed_a = self.embedding(word1)
embed_b = self.embedding(word2)
embed_c = self.embedding(word3)
if embed_a is None or embed_b is None or embed_c is None:
return None
diff = embed_b - embed_a
embed_d = embed_c + diff
embed_d /= np.linalg.norm(embed_d)
return self._similarity(
embed_d, k, {word1, word2, word3} if skip is None else skip)

def word_similarity(self, query: str,
k: int = 10) -> Optional[List['SimilarityResult']]:
"""
Retrieves the nearest neighbors of the query string.
The similarity between the embedding of the query and other embeddings
is defined by the dot product of the embeddings. If the vectors are
unit vectors, this is the cosine similarity.
At most, ``k`` results are returned.
Parameters
----------
query : str
The query string
k : int
The number of neighbors to return, defaults to 10.
Returns
-------
neighbours : List[Tuple[str, float], optional
List of tuples with neighbour and similarity measure. None if no
embedding can be found for ``query``.
"""
embed = self.embedding(query)
if embed is None:
return None
return self._similarity(embed, k, {query})

def embedding_similarity(self,
query: np.ndarray,
k: int = 10,
skip: Optional[Set[str]] = None
) -> Optional[List['SimilarityResult']]:
"""
Retrieves the nearest neighbors of the query embedding.
The similarity between the query embedding and other embeddings is
defined by the dot product of the embeddings. If the vectors are unit
vectors, this is the cosine similarity.
At most, ``k`` results are returned.
Parameters
----------
query : str
The query array.
k : int
The number of neighbors to return, defaults to 10.
skip : Set[str], optional
Set of strings that should not be considered as neighbours.
Returns
-------
neighbours : List[Tuple[str, float], optional
List of tuples with neighbour and similarity measure. None if no
embedding can be found for ``query``.
"""
return self._similarity(query, k, set() if skip is None else skip)

def __contains__(self, item):
return item in self._vocab

Expand All @@ -427,6 +543,24 @@ def __iter__(self) -> Union[Iterator[Tuple[str, np.ndarray]], Iterator[
return zip(self._vocab, self._storage, self._norms)
return zip(self._vocab, self._storage)

def _similarity(self, query: np.ndarray, k: int,
skips: Set[str]) -> List['SimilarityResult']:
words = self.storage[:len(self.vocab)] # type: np.ndarray
sims = words.dot(query)
skip_indices = set(skip for skip in (self.vocab.word_index.get(skip)
for skip in skips)
if skip is not None)
partition = sims.argpartition(-k -
len(skip_indices))[-k -
len(skip_indices):]

heap = [] # type: List[SimilarityResult]
for idx in partition:
if idx not in skip_indices:
heapq.heappush(
heap, SimilarityResult(self.vocab.words[idx], sims[idx]))
return heapq.nlargest(k, heap)

def _embedding(self,
idx: Union[int, List[int]],
out: Optional[np.ndarray] = None
Expand Down Expand Up @@ -524,3 +658,14 @@ def load_finalfusion(file: Union[str, bytes, int, PathLike],
f'Expected norms chunk, not {str(chunk_id)}')

return Embeddings(storage, vocab, norms, metadata)


@dataclass(order=True)
class SimilarityResult:
"""
Container for a Similarity result.
The word can be accessed through ``result.word``, the similarity through ``result.similarity``.
"""
word: str = field(compare=False)
similarity: float
29 changes: 20 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@ def tests_root():

@pytest.fixture
def simple_vocab_fifu(tests_root):
yield finalfusion.vocab.load_vocab(tests_root / "data/simple_vocab.fifu")
yield finalfusion.vocab.load_vocab(tests_root / "data" /
"simple_vocab.fifu")


@pytest.fixture
def analogy_fifu(tests_root):
yield finalfusion.load_finalfusion(tests_root / "data" /
"simple_vocab.fifu")


@pytest.fixture
Expand Down Expand Up @@ -45,23 +52,27 @@ def bucket_vocab_embeddings_fifu(tests_root):

@pytest.fixture
def embeddings_text(tests_root):
yield finalfusion.compat.load_text(
os.path.join(tests_root, "data/embeddings.txt"))
yield finalfusion.compat.load_text(tests_root / "data" / "embeddings.txt")


@pytest.fixture
def embeddings_text_dims(tests_root):
yield finalfusion.compat.load_text_dims(
os.path.join(tests_root, "data/embeddings.dims.txt"))
yield finalfusion.compat.load_text_dims(tests_root / "data" /
"embeddings.dims.txt")


@pytest.fixture
def embeddings_w2v(tests_root):
yield finalfusion.compat.load_word2vec(
os.path.join(tests_root, "data/embeddings.w2v"))
yield finalfusion.compat.load_word2vec(tests_root / "data" /
"embeddings.w2v")


@pytest.fixture
def embeddings_ft(tests_root):
yield finalfusion.compat.load_fasttext(
os.path.join(tests_root, "data/fasttext.bin"))
yield finalfusion.compat.load_fasttext(tests_root / "data" /
"fasttext.bin")


@pytest.fixture
def similarity_fifu(tests_root):
yield finalfusion.load_finalfusion(tests_root / "data" / "similarity.fifu")
Binary file added tests/data/similarity.fifu
Binary file not shown.
61 changes: 61 additions & 0 deletions tests/test_analogies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest

ANALOGY_ORDER = [
"Deutschland",
"Westdeutschland",
"Sachsen",
"Mitteldeutschland",
"Brandenburg",
"Polen",
"Norddeutschland",
"Dänemark",
"Schleswig-Holstein",
"Österreich",
"Bayern",
"Thüringen",
"Bundesrepublik",
"Ostdeutschland",
"Preußen",
"Deutschen",
"Hessen",
"Potsdam",
"Mecklenburg",
"Niedersachsen",
"Hamburg",
"Süddeutschland",
"Bremen",
"Russland",
"Deutschlands",
"BRD",
"Litauen",
"Mecklenburg-Vorpommern",
"DDR",
"West-Berlin",
"Saarland",
"Lettland",
"Hannover",
"Rostock",
"Sachsen-Anhalt",
"Pommern",
"Schweden",
"Deutsche",
"deutschen",
"Westfalen",
]


def test_analogies(analogy_fifu):
for idx, analogy in enumerate(
analogy_fifu.analogy("Paris", "Frankreich", "Berlin", 40)):
assert ANALOGY_ORDER[idx] == analogy.word

assert analogy_fifu.analogy("Paris", "Frankreich", "Paris", 1,
{"Paris"})[0].word == "Frankreich"
assert analogy_fifu.analogy("Paris", "Frankreich", "Paris",
1)[0].word != "Frankreich"
assert analogy_fifu.analogy("Frankreich", "Frankreich", "Frankreich", 1,
set())[0].word == "Frankreich"
assert analogy_fifu.analogy("Frankreich", "Frankreich", "Frankreich", 1,
{"Frankreich"})[0].word != "Frankreich"

assert analogy_fifu.analogy("Paris", "OOV", "Paris", 1) is None
90 changes: 90 additions & 0 deletions tests/test_similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pytest
import numpy

SIMILARITY_ORDER_STUTTGART_10 = [
"Karlsruhe",
"Mannheim",
"München",
"Darmstadt",
"Heidelberg",
"Wiesbaden",
"Kassel",
"Düsseldorf",
"Leipzig",
"Berlin",
]

SIMILARITY_ORDER = [
"Potsdam",
"Hamburg",
"Leipzig",
"Dresden",
"München",
"Düsseldorf",
"Bonn",
"Stuttgart",
"Weimar",
"Berlin-Charlottenburg",
"Rostock",
"Karlsruhe",
"Chemnitz",
"Breslau",
"Wiesbaden",
"Hannover",
"Mannheim",
"Kassel",
"Köln",
"Danzig",
"Erfurt",
"Dessau",
"Bremen",
"Charlottenburg",
"Magdeburg",
"Neuruppin",
"Darmstadt",
"Jena",
"Wien",
"Heidelberg",
"Dortmund",
"Stettin",
"Schwerin",
"Neubrandenburg",
"Greifswald",
"Göttingen",
"Braunschweig",
"Berliner",
"Warschau",
"Berlin-Spandau",
]


def test_similarity_berlin_40(similarity_fifu):
for idx, sim in enumerate(similarity_fifu.word_similarity("Berlin", 40)):
assert SIMILARITY_ORDER[idx] == sim.word


def test_similarity_stuttgart_10(similarity_fifu):
for idx, sim in enumerate(similarity_fifu.word_similarity("Stuttgart",
10)):
assert SIMILARITY_ORDER_STUTTGART_10[idx] == sim.word


def test_embedding_similarity_stuttgart_10(similarity_fifu):
stuttgart = similarity_fifu.embedding("Stuttgart")
sims = similarity_fifu.embedding_similarity(stuttgart, k=10)
assert sims[0].word == "Stuttgart"

for idx, sim in enumerate(sims[1:]):
assert SIMILARITY_ORDER_STUTTGART_10[idx] == sim.word

for idx, sim in enumerate(
similarity_fifu.embedding_similarity(stuttgart,
skip={"Stuttgart"},
k=10)):
assert SIMILARITY_ORDER_STUTTGART_10[idx] == sim.word


def test_embedding_similarity_incompatible_shapes(similarity_fifu):
incompatible_embed = numpy.ones(1, dtype=numpy.float32)
with pytest.raises(ValueError):
similarity_fifu.embedding_similarity(incompatible_embed)

0 comments on commit 79be265

Please sign in to comment.