From 6de719464733fe21cc1023e2a12cb8fd99a03f92 Mon Sep 17 00:00:00 2001 From: SWHL Date: Fri, 14 Feb 2025 14:13:48 +0800 Subject: [PATCH] chore: update files --- python/demo.py | 2 +- python/rapidocr/inference_engine/onnxruntime.py | 4 ++-- python/rapidocr/inference_engine/openvino.py | 7 ++++++- python/tests/test_main.py | 9 +++++++++ 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/python/demo.py b/python/demo.py index f90172b8..c4cbda48 100644 --- a/python/demo.py +++ b/python/demo.py @@ -18,7 +18,7 @@ # engine = engine = RapidOCR( # params={"Global.with_onnx": True, "EngineConfig.onnxruntime.use_cuda": True} # ) -engine = RapidOCR(params={"Global.with_paddle": True, "Global.lang": "ch"}) +engine = RapidOCR(params={"Global.with_openvino": True, "Global.lang": "ch"}) vis = VisRes() image_path = "tests/test_files/ch_en_num.jpg" diff --git a/python/rapidocr/inference_engine/onnxruntime.py b/python/rapidocr/inference_engine/onnxruntime.py index 433c70b9..60e3b0f8 100644 --- a/python/rapidocr/inference_engine/onnxruntime.py +++ b/python/rapidocr/inference_engine/onnxruntime.py @@ -6,7 +6,7 @@ import traceback from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np from onnxruntime import ( @@ -28,7 +28,7 @@ class EP(Enum): class OrtInferSession(InferSession): - def __init__(self, config: Dict[str, Any]): + def __init__(self, config: Dict[str, Any], mode: Optional[str] = None): self.logger = Logger(logger_name=__name__).get_log() model_path = config.get("model_path", None) diff --git a/python/rapidocr/inference_engine/openvino.py b/python/rapidocr/inference_engine/openvino.py index 168578a3..3a15dbd2 100644 --- a/python/rapidocr/inference_engine/openvino.py +++ b/python/rapidocr/inference_engine/openvino.py @@ -4,6 +4,7 @@ import os import traceback from pathlib import Path +from typing import Optional import numpy as np from omegaconf import DictConfig @@ -13,8 +14,9 @@ class OpenVINOInferSession(InferSession): - def __init__(self, config: DictConfig): + def __init__(self, config: DictConfig, mode: Optional[str] = None): super().__init__(config) + self.mode = mode core = Core() @@ -23,6 +25,9 @@ def __init__(self, config: DictConfig): default_model_url = self.get_model_url( config.engine_name, config.task_type, config.lang ) + if self.mode == "rec": + default_model_url = default_model_url.model_dir + model_path = self.DEFAULT_MODE_PATH / Path(default_model_url).name self.download_file(default_model_url, model_path) diff --git a/python/tests/test_main.py b/python/tests/test_main.py index 1d597398..6de89f8f 100644 --- a/python/tests/test_main.py +++ b/python/tests/test_main.py @@ -34,15 +34,24 @@ def get_engine(params: Optional[Dict[str, Any]] = None): return engine +def test_lang(): + engine = get_engine(params={"Global.lang": "ch", "Global.with_openvino": True}) + result = engine(img_path) + assert result.txts is not None + assert result.txts[0] == "正品促销" + + def test_engine_openvino(): engine = get_engine(params={"Global.with_openvino": True}) result = engine(img_path) + assert result.txts is not None assert result.txts[0] == "正品促销" def test_engine_paddle(): engine = RapidOCR(params={"Global.with_paddle": True}) result = engine(img_path) + assert result.txts is not None assert result.txts[0] == "正品促销"