From e5d0134b73c97647fa44c48bc077fd8285eed17d Mon Sep 17 00:00:00 2001
From: Benjamin Akhras
Date: Mon, 3 Feb 2025 12:55:14 +0100
Subject: [PATCH] improve tests and switch to upstream flair until
 https://github.com/flairNLP/flair/pull/3608 is part of a new release

---
Review note: this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Hunk line counts match the @@ headers and the
diffstat, but leading whitespace was reconstructed (4 spaces assumed in the
pyproject.toml and uv.lock context lines; the `yield` in news_items is assumed
to sit after the `with` block) -- confirm against the target tree before
applying. Also note that pyproject.toml tracks the moving `master` ref; only
uv.lock records the resolved commit e00e0ff73e1036046e8d4e258a066a7a7dc07c89.

 natural_language_processing/tests/conftest.py | 30 +++++++++++++++++++
 .../tests/test_analysis.py                    | 14 +++++++++
 natural_language_processing/tests/test_nlp.py |  5 ----
 pyproject.toml                                |  2 +-
 uv.lock                                       |  8 ++---
 5 files changed, 47 insertions(+), 12 deletions(-)
 create mode 100644 natural_language_processing/tests/test_analysis.py
 delete mode 100644 natural_language_processing/tests/test_nlp.py

diff --git a/natural_language_processing/tests/conftest.py b/natural_language_processing/tests/conftest.py
index e69de29..cefb840 100644
--- a/natural_language_processing/tests/conftest.py
+++ b/natural_language_processing/tests/conftest.py
@@ -0,0 +1,30 @@
+import os
+import pytest
+import json
+
+from natural_language_processing.roberta_ner import RobertaNER
+from natural_language_processing.flair_ner import FlairNER
+
+
+@pytest.fixture(scope="session")
+def news_items():
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    story_json = os.path.join(dir_path, "story_list.json")
+    with open(story_json) as f:
+        data = json.load(f)
+    yield [item["content"] for cluster in data for item in cluster["news_items"] if "content" in item]
+
+
+@pytest.fixture(scope="session")
+def flair_analyzer():
+    yield FlairNER()
+
+
+@pytest.fixture(scope="session")
+def roberta_analyzer():
+    yield RobertaNER()
+
+
+@pytest.fixture(scope="session")
+def example_text():
+    yield "This is an example for NER, about the ACME Corporation which is producing Dynamite in Acme City, which is in Australia and run by Mr. Wile E. Coyote"
diff --git a/natural_language_processing/tests/test_analysis.py b/natural_language_processing/tests/test_analysis.py
new file mode 100644
index 0000000..b3e8a47
--- /dev/null
+++ b/natural_language_processing/tests/test_analysis.py
@@ -0,0 +1,14 @@
+from natural_language_processing.roberta_ner import RobertaNER
+from natural_language_processing.flair_ner import FlairNER
+
+
+def test_analyze_ner_flair(example_text: str, flair_analyzer: FlairNER):
+    result = flair_analyzer.predict(example_text)
+    expected = {"Australia": "LOC", "Wile E. Coyote": "PER"}
+    assert expected.items() <= result.items()
+
+
+def test_analyze_ner_roberta(example_text: str, roberta_analyzer: RobertaNER):
+    result = roberta_analyzer.predict(example_text)
+    expected = {"ACME Corporation": "ORG", "Acme City": "LOC", "Australia": "LOC", "Dynamite": "MISC"}
+    assert expected.items() <= result.items()
diff --git a/natural_language_processing/tests/test_nlp.py b/natural_language_processing/tests/test_nlp.py
deleted file mode 100644
index 5e5a387..0000000
--- a/natural_language_processing/tests/test_nlp.py
+++ /dev/null
@@ -1,5 +0,0 @@
-def test_app():
-    from natural_language_processing.__init__ import create_app
-
-    app = create_app()
-    assert app
diff --git a/pyproject.toml b/pyproject.toml
index 9baaaac..0dcdb32 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
     "requests",
     "granian",
     "sentencepiece",
-    "flair",
+    "flair @ git+https://github.com/flairNLP/flair@master",
     "torch",
     "huggingface_hub",
     "pydantic-settings",
diff --git a/uv.lock b/uv.lock
index 5fa4864..8350f98 100644
--- a/uv.lock
+++ b/uv.lock
@@ -319,7 +319,7 @@ wheels = [
 [[package]]
 name = "flair"
 version = "0.15.0"
-source = { registry = "https://pypi.org/simple" }
+source = { git = "https://github.com/flairNLP/flair?rev=master#e00e0ff73e1036046e8d4e258a066a7a7dc07c89" }
 dependencies = [
     { name = "bioc" },
     { name = "boto3" },
@@ -347,10 +347,6 @@ dependencies = [
     { name = "transformers", extra = ["sentencepiece"] },
     { name = "wikipedia-api" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/ee/f7/5ea640ad606ee73942b7f450125312e04440cf3bb57c234430a7c310f705/flair-0.15.0.tar.gz", hash = "sha256:815513edb2b72f15b54ee5659b6316ccb854b3e7fcadb52c807ff1de90e0ae87", size = 379223 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/92/20/fc57c4338ccf67cd4bdc42fce176c0c581c3d739311c062ed585ef849ef4/flair-0.15.0-py3-none-any.whl", hash = "sha256:bfb2f6ab2a355fbd94d03edc9d78f836f2e87fc3bc36d576d686c007513410f5", size = 1167239 },
-]
 
 [[package]]
 name = "flask"
@@ -930,7 +926,7 @@ dev = [
 
 [package.metadata]
 requires-dist = [
-    { name = "flair" },
+    { name = "flair", git = "https://github.com/flairNLP/flair?rev=master" },
     { name = "flask" },
     { name = "granian" },
     { name = "huggingface-hub" },