Skip to content

Commit 1bd2658

Browse files
Merge pull request #219 from megagonlabs/feature/auto_installing_in_ensure_model
auto installing models when ensure_model option is specified but the model package not installed
2 parents 32c5c79 + e5f5b28 commit 1bd2658

File tree

4 files changed

+39
-24
lines changed

4 files changed

+39
-24
lines changed

ginza/analyzer.py

+5-13
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,15 @@ def try_sudachi_import(split_mode: str):
3838
class Analyzer:
3939
def __init__(
4040
self,
41-
model_path: str,
42-
ensure_model: str,
41+
model_name_or_path: str,
4342
split_mode: str,
4443
hash_comment: str,
4544
output_format: str,
4645
require_gpu: bool,
4746
disable_sentencizer: bool,
4847
use_normalized_form: bool,
4948
) -> None:
50-
self.model_path = model_path
51-
self.ensure_model = ensure_model
49+
self.model_name_or_path = model_name_or_path
5250
self.split_mode = split_mode
5351
self.hash_comment = hash_comment
5452
self.output_format = output_format
@@ -68,22 +66,16 @@ def set_nlp(self) -> None:
6866
nlp = try_sudachi_import(self.split_mode)
6967
else:
7068
# Work-around for pickle error. Need to share model data.
71-
if self.model_path:
72-
nlp = spacy.load(self.model_path)
73-
elif self.ensure_model:
74-
nlp = spacy.load(self.ensure_model.replace("-", "_"))
69+
if self.model_name_or_path:
70+
nlp = spacy.load(self.model_name_or_path)
7571
else:
7672
try:
7773
nlp = spacy.load("ja_ginza_electra")
7874
except IOError as e:
7975
try:
8076
nlp = spacy.load("ja_ginza")
8177
except IOError as e:
82-
print(
83-
'Could not find the model. You need to install "ja-ginza-electra" or "ja-ginza" by executing pip like `pip install ja-ginza-electra`.',
84-
file=sys.stderr,
85-
)
86-
raise e
78+
raise OSError("E050", 'You need to install "ja-ginza" or "ja-ginza-electra" by executing `pip install ja-ginza`.')
8779

8880
if self.disable_sentencizer:
8981
nlp.add_pipe("disable_sentencizer", before="parser")

ginza/command_line.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from multiprocessing import Process, Queue, Event, cpu_count
33
from pathlib import Path
44
import queue
5+
import re
56
import sys
67
import traceback
78
from typing import Generator, Iterable, Optional, List
@@ -10,6 +11,8 @@
1011
from .analyzer import Analyzer
1112

1213
MINI_BATCH_SIZE = 100
14+
GINZA_MODEL_PATTERN = re.compile(r"^(ja_ginza|ja_ginza_electra)$")
15+
SPACY_MODEL_PATTERN = re.compile(r"^[a-z]{2}[-_].+[-_].+(sm|md|lg|trf)$")
1316

1417

1518
class _OutputWrapper:
@@ -61,7 +64,6 @@ def run(
6164
parallel_level: int = 1,
6265
files: List[str] = None,
6366
):
64-
assert model_path is None or ensure_model is None
6567
if output_format in ["3", "json"] and hash_comment != "analyze":
6668
print(
6769
f'hash_comment="{hash_comment}" not permitted for JSON output. Forced to use hash_comment="analyze".',
@@ -86,9 +88,31 @@ def run(
8688
print("GPU enabled", file=sys.stderr)
8789
parallel_level = level
8890

91+
assert model_path is None or ensure_model is None
92+
if ensure_model:
93+
ensure_model = ensure_model.replace("-", "_")
94+
try:
95+
from importlib import import_module
96+
import_module(ensure_model)
97+
except ModuleNotFoundError:
98+
if GINZA_MODEL_PATTERN.match(ensure_model):
99+
print("Installing", ensure_model, file=sys.stderr)
100+
import pip
101+
pip.main(["install", ensure_model])
102+
print("Successfully installed", ensure_model, file=sys.stderr)
103+
elif SPACY_MODEL_PATTERN.match(ensure_model):
104+
print("Installing", ensure_model, file=sys.stderr)
105+
from spacy.cli.download import download
106+
download(ensure_model)
107+
print("Successfully installed", ensure_model, file=sys.stderr)
108+
else:
109+
raise OSError("E050", f'You need to install "{ensure_model}" before executing ginza.')
110+
model_name_or_path = ensure_model
111+
else:
112+
model_name_or_path = model_path
113+
89114
analyzer = Analyzer(
90-
model_path,
91-
ensure_model,
115+
model_name_or_path,
92116
split_mode,
93117
hash_comment,
94118
output_format,
@@ -288,7 +312,7 @@ def main_ginzame():
288312

289313
@plac.annotations(
290314
model_path=("model directory path", "option", "b", str),
291-
ensure_model=("select model either ja_ginza or ja_ginza_electra", "option", "m", str, ["ja_ginza", "ja-ginza", "ja_ginza_electra", "ja-ginza-electra", None]),
315+
ensure_model=("select model package of ginza or spacy", "option", "m", str),
292316
split_mode=("split mode", "option", "s", str, ["A", "B", "C"]),
293317
hash_comment=("hash comment", "option", "c", str, ["print", "skip", "analyze"]),
294318
output_path=("output path", "option", "o", Path),

ginza/tests/test_analyzer.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
@pytest.fixture
3131
def analyzer() -> Analyzer:
3232
default_params = dict(
33-
model_path=None,
34-
ensure_model=None,
33+
model_name_or_path=None,
3534
split_mode=None,
3635
hash_comment="print",
3736
output_format="conllu",
@@ -80,15 +79,15 @@ def _tokens_json(result: str):
8079
return ret
8180

8281
class TestAnalyzer:
83-
def test_model_path(self, mocker, analyzer):
82+
def test_model_name_or_path_ja_ginza(self, mocker, analyzer):
8483
spacy_load_mock = mocker.patch("spacy.load")
85-
analyzer.model_path = "ja_ginza"
84+
analyzer.model_name_or_path = "ja_ginza"
8685
analyzer.set_nlp()
8786
spacy_load_mock.assert_called_once_with("ja_ginza")
8887

89-
def test_ensure_model(self, mocker, analyzer):
88+
def test_model_name_or_path_ja_ginza_electra(self, mocker, analyzer):
9089
spacy_load_mock = mocker.patch("spacy.load")
91-
analyzer.ensure_model = "ja_ginza_electra"
90+
analyzer.model_name_or_path = "ja_ginza_electra"
9291
analyzer.set_nlp()
9392
spacy_load_mock.assert_called_once_with("ja_ginza_electra")
9493

ginza/tests/test_command_line.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def test_model_path(self, model_path, exit_ok, input_file):
134134
("ja-ginza", True),
135135
("ja-ginza-electra", True),
136136
("ja_ginza_electra", True),
137-
("ja-ginza_electra", False),
137+
("ja-ginza_electra", True),
138138
("not-exist-model", False),
139139
],
140140
)

0 commit comments

Comments
 (0)