Skip to content

Commit

Permalink
improve tests and switch to upstream flair until flairNLP/flair#3608
Browse files Browse the repository at this point in the history
…is part of a new release
  • Loading branch information
b3n4kh committed Feb 3, 2025
1 parent 36b7c5f commit e5d0134
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 12 deletions.
30 changes: 30 additions & 0 deletions natural_language_processing/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
import pytest
import json

from natural_language_processing.roberta_ner import RobertaNER
from natural_language_processing.flair_ner import FlairNER


@pytest.fixture(scope="session")
def news_items():
dir_path = os.path.dirname(os.path.realpath(__file__))
story_json = os.path.join(dir_path, "story_list.json")
with open(story_json) as f:
data = json.load(f)
yield [item["content"] for cluster in data for item in cluster["news_items"] if "content" in item]


@pytest.fixture(scope="session")
def flair_analyzer():
yield FlairNER()


@pytest.fixture(scope="session")
def roberta_analyzer():
yield RobertaNER()


@pytest.fixture(scope="session")
def example_text():
yield "This is an example for NER, about the ACME Corporation which is producing Dynamite in Acme City, which is in Australia and run by Mr. Wile E. Coyote"
14 changes: 14 additions & 0 deletions natural_language_processing/tests/test_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from natural_language_processing.roberta_ner import RobertaNER
from natural_language_processing.flair_ner import FlairNER


def test_analyze_ner_flair(example_text: str, flair_analyzer: FlairNER):
result = flair_analyzer.predict(example_text)
expected = {"Australia": "LOC", "Wile E. Coyote": "PER"}
assert expected.items() <= result.items()


def test_analyze_ner_roberta(example_text: str, roberta_analyzer: RobertaNER):
result = roberta_analyzer.predict(example_text)
expected = {"ACME Corporation": "ORG", "Acme City": "LOC", "Australia": "LOC", "Dynamite": "MISC"}
assert expected.items() <= result.items()
5 changes: 0 additions & 5 deletions natural_language_processing/tests/test_nlp.py

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"requests",
"granian",
"sentencepiece",
"flair",
"flair @ git+https://github.com/flairNLP/flair@master",
"torch",
"huggingface_hub",
"pydantic-settings",
Expand Down
8 changes: 2 additions & 6 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit e5d0134

Please sign in to comment.