Commit

precommit check fails without none type check in embeddings assignment
BeneHTWG authored Jan 10, 2025
1 parent cbfc6c3 commit 83a9de6
Showing 2 changed files with 25 additions and 4 deletions.
8 changes: 8 additions & 0 deletions src/htwgnlp/embeddings.py
@@ -35,6 +35,9 @@ def __init__(self) -> None:
     def embedding_values(self) -> np.ndarray:
         """Returns the embedding values.
 
+        Raises:
+            ValueError: if the embeddings have not been loaded yet
+
         Returns:
             np.ndarray: the embedding values as a numpy array of shape (n, d), where n is the vocabulary size and d is the number of dimensions
         """
@@ -77,6 +80,9 @@ def get_embeddings(self, word: str) -> np.ndarray | None:
         Args:
             word (str): the word to get the embedding vector for
 
+        Raises:
+            ValueError: if the embeddings have not been loaded yet
+
         Returns:
             np.ndarray | None: the embedding vector for the given word in the form of a numpy array of shape (d,), where d is the number of dimensions, or None if the word is not in the vocabulary
         """
@@ -125,6 +131,7 @@ def get_most_similar_words(
             metric (Literal["euclidean", "cosine"], optional): the metric to use for computing the similarity. Defaults to "euclidean".
 
         Raises:
+            ValueError: if the embeddings have not been loaded yet
             ValueError: if the metric is not "euclidean" or "cosine"
             AssertionError: if the word is not in the vocabulary
@@ -146,6 +153,7 @@ def find_closest_word(
             metric (Literal["euclidean", "cosine"], optional): the metric to use for computing the similarity. Defaults to "euclidean".
 
         Raises:
+            ValueError: if the embeddings have not been loaded yet
             ValueError: if the metric is not "euclidean" or "cosine"
 
         Returns:
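The hunks above show only the docstring additions; the guard itself sits outside the visible diff context. As a minimal sketch (not the actual htwgnlp implementation), the None type check the commit title refers to could look like the following, assuming the class keeps the loaded embeddings in a pandas DataFrame attribute `_embeddings_df` that stays None until `_load_raw_embeddings` is called (the attribute and method names appear in the tests below; everything else is illustrative):

import numpy as np
import pandas as pd


class WordEmbeddings:
    """Sketch only: illustrates the None-type guard, not the real implementation."""

    def __init__(self) -> None:
        # None until _load_raw_embeddings runs, hence the explicit guards below
        self._embeddings_df: pd.DataFrame | None = None

    @property
    def embedding_values(self) -> np.ndarray:
        # Without this narrowing check, mypy rejects calling .to_numpy() on an
        # Optional DataFrame, which is the kind of pre-commit failure the
        # commit message describes.
        if self._embeddings_df is None:
            raise ValueError("embeddings have not been loaded yet")
        return self._embeddings_df.to_numpy()

    def get_embeddings(self, word: str) -> np.ndarray | None:
        if self._embeddings_df is None:
            raise ValueError("embeddings have not been loaded yet")
        # returns None for out-of-vocabulary words, as documented above
        if word not in self._embeddings_df.index:
            return None
        return self._embeddings_df.loc[word].to_numpy()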
21 changes: 17 additions & 4 deletions tests/htwgnlp/test_embeddings.py
@@ -15,6 +15,11 @@ def embeddings():
     return WordEmbeddings()
 
 
+@pytest.fixture
+def non_loaded_embeddings():
+    return WordEmbeddings()
+
+
 @pytest.fixture
 def loaded_embeddings(embeddings):
     embeddings._load_raw_embeddings("notebooks/data/embeddings.pkl")
@@ -47,12 +52,16 @@ def test_load_embeddings_to_dataframe(loaded_embeddings):
     assert loaded_embeddings._embeddings_df.shape == (243, 300)
 
 
-def test_embedding_values(loaded_embeddings):
+def test_embedding_values(loaded_embeddings, non_loaded_embeddings):
+    with pytest.raises(ValueError):
+        non_loaded_embeddings.embedding_values
     assert isinstance(loaded_embeddings.embedding_values, np.ndarray)
     assert loaded_embeddings.embedding_values.shape == (243, 300)
 
 
-def test_get_embeddings(loaded_embeddings):
+def test_get_embeddings(loaded_embeddings, non_loaded_embeddings):
+    with pytest.raises(ValueError):
+        non_loaded_embeddings.get_embeddings("happy")
     assert isinstance(loaded_embeddings.get_embeddings("happy"), np.ndarray)
     assert loaded_embeddings.get_embeddings("happy").shape == (300,)
     assert loaded_embeddings.get_embeddings("non_existent_word") is None
@@ -104,13 +113,17 @@ def test_cosine_similarity(loaded_embeddings, test_vector):
     )
 
 
-def test_find_closest_word(loaded_embeddings, test_vector):
+def test_find_closest_word(loaded_embeddings, test_vector, non_loaded_embeddings):
+    with pytest.raises(ValueError):
+        non_loaded_embeddings.find_closest_word(test_vector)
     for metric in ["euclidean", "cosine"]:
         assert isinstance(loaded_embeddings.find_closest_word(test_vector, metric), str)
         assert loaded_embeddings.find_closest_word(test_vector, metric) == "Bahamas"
 
 
-def test_get_most_similar_words(loaded_embeddings):
+def test_get_most_similar_words(loaded_embeddings, non_loaded_embeddings):
+    with pytest.raises(ValueError):
+        non_loaded_embeddings.get_most_similar_words("Germany")
     assert isinstance(loaded_embeddings.get_most_similar_words("Germany"), list)
     assert loaded_embeddings.get_most_similar_words("Germany") == [
         "Austria",
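Taken together, the tests pin down the new contract: every public accessor raises ValueError until embeddings have been loaded. A short usage sketch, reusing the pickle path and API names from the tests above (the import path is assumed from the file location src/htwgnlp/embeddings.py):

import pytest

from htwgnlp.embeddings import WordEmbeddings

embeddings = WordEmbeddings()

# before loading: accessing values raises instead of failing on a None attribute
with pytest.raises(ValueError):
    embeddings.embedding_values

embeddings._load_raw_embeddings("notebooks/data/embeddings.pkl")

print(embeddings.embedding_values.shape)             # (243, 300)
print(embeddings.get_embeddings("happy").shape)      # (300,)
print(embeddings.get_most_similar_words("Germany"))  # ["Austria", ...]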
