Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auto installing models when ensure_model option is specified but the model package not installed #219

Merged
18 changes: 5 additions & 13 deletions ginza/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,15 @@ def try_sudachi_import(split_mode: str):
class Analyzer:
def __init__(
self,
model_path: str,
ensure_model: str,
model_name_or_path: str,
split_mode: str,
hash_comment: str,
output_format: str,
require_gpu: bool,
disable_sentencizer: bool,
use_normalized_form: bool,
) -> None:
self.model_path = model_path
self.ensure_model = ensure_model
self.model_name_or_path = model_name_or_path
self.split_mode = split_mode
self.hash_comment = hash_comment
self.output_format = output_format
Expand All @@ -68,22 +66,16 @@ def set_nlp(self) -> None:
nlp = try_sudachi_import(self.split_mode)
else:
# Work-around for pickle error. Need to share model data.
if self.model_path:
nlp = spacy.load(self.model_path)
elif self.ensure_model:
nlp = spacy.load(self.ensure_model.replace("-", "_"))
if self.model_name_or_path:
nlp = spacy.load(self.model_name_or_path)
else:
try:
nlp = spacy.load("ja_ginza_electra")
except IOError as e:
try:
nlp = spacy.load("ja_ginza")
except IOError as e:
print(
'Could not find the model. You need to install "ja-ginza-electra" or "ja-ginza" by executing pip like `pip install ja-ginza-electra`.',
file=sys.stderr,
)
raise e
raise OSError("E050", 'You need to install "ja-ginza" or "ja-ginza-electra" by executing `pip install ja-ginza`.')

if self.disable_sentencizer:
nlp.add_pipe("disable_sentencizer", before="parser")
Expand Down
32 changes: 28 additions & 4 deletions ginza/command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from multiprocessing import Process, Queue, Event, cpu_count
from pathlib import Path
import queue
import re
import sys
import traceback
from typing import Generator, Iterable, Optional, List
Expand All @@ -10,6 +11,8 @@
from .analyzer import Analyzer

MINI_BATCH_SIZE = 100
GINZA_MODEL_PATTERN = re.compile(r"^(ja_ginza|ja_ginza_electra)$")
SPACY_MODEL_PATTERN = re.compile(r"^[a-z]{2}[-_].+[-_].+(sm|md|lg|trf)$")


class _OutputWrapper:
Expand Down Expand Up @@ -61,7 +64,6 @@ def run(
parallel_level: int = 1,
files: List[str] = None,
):
assert model_path is None or ensure_model is None
if output_format in ["3", "json"] and hash_comment != "analyze":
print(
f'hash_comment="{hash_comment}" not permitted for JSON output. Forced to use hash_comment="analyze".',
Expand All @@ -86,9 +88,31 @@ def run(
print("GPU enabled", file=sys.stderr)
parallel_level = level

assert model_path is None or ensure_model is None
if ensure_model:
ensure_model = ensure_model.replace("-", "_")
try:
from importlib import import_module
import_module(ensure_model)
except ModuleNotFoundError:
if GINZA_MODEL_PATTERN.match(ensure_model):
print("Installing", ensure_model, file=sys.stderr)
import pip
pip.main(["install", ensure_model])
print("Successfully installed", ensure_model, file=sys.stderr)
elif SPACY_MODEL_PATTERN.match(ensure_model):
print("Installing", ensure_model, file=sys.stderr)
from spacy.cli.download import download
download(ensure_model)
print("Successfully installed", ensure_model, file=sys.stderr)
else:
raise OSError("E050", f'You need to install "{ensure_model}" before executing ginza.')
model_name_or_path = ensure_model
else:
model_name_or_path = model_path

analyzer = Analyzer(
model_path,
ensure_model,
model_name_or_path,
split_mode,
hash_comment,
output_format,
Expand Down Expand Up @@ -288,7 +312,7 @@ def main_ginzame():

@plac.annotations(
model_path=("model directory path", "option", "b", str),
ensure_model=("select model either ja_ginza or ja_ginza_electra", "option", "m", str, ["ja_ginza", "ja-ginza", "ja_ginza_electra", "ja-ginza-electra", None]),
ensure_model=("select model package of ginza or spacy", "option", "m", str),
split_mode=("split mode", "option", "s", str, ["A", "B", "C"]),
hash_comment=("hash comment", "option", "c", str, ["print", "skip", "analyze"]),
output_path=("output path", "option", "o", Path),
Expand Down
11 changes: 5 additions & 6 deletions ginza/tests/test_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
@pytest.fixture
def analyzer() -> Analyzer:
default_params = dict(
model_path=None,
ensure_model=None,
model_name_or_path=None,
split_mode=None,
hash_comment="print",
output_format="conllu",
Expand Down Expand Up @@ -80,15 +79,15 @@ def _tokens_json(result: str):
return ret

class TestAnalyzer:
def test_model_path(self, mocker, analyzer):
def test_model_name_or_path_ja_ginza(self, mocker, analyzer):
spacy_load_mock = mocker.patch("spacy.load")
analyzer.model_path = "ja_ginza"
analyzer.model_name_or_path = "ja_ginza"
analyzer.set_nlp()
spacy_load_mock.assert_called_once_with("ja_ginza")

def test_ensure_model(self, mocker, analyzer):
def test_model_name_or_path_ja_ginza_electra(self, mocker, analyzer):
spacy_load_mock = mocker.patch("spacy.load")
analyzer.ensure_model = "ja_ginza_electra"
analyzer.model_name_or_path = "ja_ginza_electra"
analyzer.set_nlp()
spacy_load_mock.assert_called_once_with("ja_ginza_electra")

Expand Down
2 changes: 1 addition & 1 deletion ginza/tests/test_command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_model_path(self, model_path, exit_ok, input_file):
("ja-ginza", True),
("ja-ginza-electra", True),
("ja_ginza_electra", True),
("ja-ginza_electra", False),
("ja-ginza_electra", True),
("not-exist-model", False),
],
)
Expand Down