Skip to content

Commit

Permalink
chore: update files
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Feb 8, 2025
1 parent fc395a5 commit 844b2e6
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 20 deletions.
2 changes: 1 addition & 1 deletion python/rapidocr/ch_ppocr_cls/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import cv2
import numpy as np

from rapidocr.utils import OrtInferSession
from rapidocr.inference_engine import OrtInferSession

from .utils import ClsPostProcess, TextClsOutput

Expand Down
2 changes: 1 addition & 1 deletion python/rapidocr/ch_ppocr_det/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import numpy as np

from rapidocr.utils import OrtInferSession
from rapidocr.inference_engine import OrtInferSession

from .utils import DBPostProcess, DetPreProcess, TextDetOutput

Expand Down
2 changes: 1 addition & 1 deletion python/rapidocr/ch_ppocr_rec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import cv2
import numpy as np

from rapidocr.utils import OrtInferSession
from rapidocr.inference_engine import OrtInferSession

from .utils import CTCLabelDecode, TextRecInput, TextRecOutput

Expand Down
1 change: 1 addition & 0 deletions python/rapidocr/inference_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from .onnxruntime import OrtInferSession
30 changes: 30 additions & 0 deletions python/rapidocr/inference_engine/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import abc
from pathlib import Path
from typing import Union

import numpy as np


class InferSession(abc.ABC):
@abc.abstractmethod
def __init__(self, config):
pass

@abc.abstractmethod
def __call__(self, input_content: np.ndarray) -> np.ndarray:
pass

@staticmethod
def _verify_model(model_path: Union[str, Path, None]):
if model_path is None:
raise ValueError("model_path is None!")

model_path = Path(model_path)
if not model_path.exists():
raise FileNotFoundError(f"{model_path} does not exists.")

if not model_path.is_file():
raise FileExistsError(f"{model_path} is not a file.")
18 changes: 3 additions & 15 deletions python/rapidocr/inference_engine/onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import platform
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple

import numpy as np
from onnxruntime import (
Expand All @@ -18,6 +17,7 @@
)

from ..utils.logger import get_logger
from .base import InferSession


class EP(Enum):
Expand All @@ -26,7 +26,7 @@ class EP(Enum):
DIRECTML_EP = "DmlExecutionProvider"


class OrtInferSession:
class OrtInferSession(InferSession):
def __init__(self, config: Dict[str, Any]):
self.logger = get_logger("OrtInferSession")

Expand Down Expand Up @@ -214,18 +214,6 @@ def have_key(self, key: str = "character") -> bool:
return True
return False

@staticmethod
def _verify_model(model_path: Union[str, Path, None]):
if model_path is None:
raise ValueError("model_path is None!")

model_path = Path(model_path)
if not model_path.exists():
raise FileNotFoundError(f"{model_path} does not exists.")

if not model_path.is_file():
raise FileExistsError(f"{model_path} is not a file.")


class ONNXRuntimeError(Exception):
pass
62 changes: 62 additions & 0 deletions python/rapidocr/inference_engine/torch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,65 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from pathlib import Path
from typing import Dict, Optional, Union

import numpy as np
import torch
import yaml

root_dir = Path(__file__).resolve().parent.parent
DEFAULT_CFG_PATH = root_dir / "arch_config.yaml"


def read_yaml(yaml_path: Union[str, Path]) -> Dict[str, Dict]:
with open(yaml_path, "rb") as f:
data = yaml.load(f, Loader=yaml.Loader)
return data


from rapidocr_torch.modeling.architectures.base_model import BaseModel

from ..utils.logger import get_logger


class TorchInferSession:
def __init__(self, config, mode: Optional[str] = None) -> None:
all_arch_config = read_yaml(DEFAULT_CFG_PATH)

self.logger = get_logger("TorchInferSession")
self.mode = mode
model_path = Path(config["model_path"])
self._verify_model(model_path)
file_name = model_path.stem
if file_name not in all_arch_config:
raise ValueError(f"architecture {file_name} is not in config.yaml")
arch_config = all_arch_config[file_name]
self.predictor = BaseModel(arch_config)
self.predictor.load_state_dict(torch.load(model_path, weights_only=True))
self.predictor.eval()
self.use_gpu = False
if config["use_cuda"]:
self.predictor.cuda()
self.use_gpu = True

def __call__(self, img: np.ndarray):
with torch.no_grad():
inp = torch.from_numpy(img)
if self.use_gpu:
inp = inp.cuda()
# 适配跟onnx对齐取值逻辑
outputs = self.predictor(inp).unsqueeze(0)
return outputs.cpu().numpy()

@staticmethod
def _verify_model(model_path):
model_path = Path(model_path)
if not model_path.exists():
raise FileNotFoundError(f"{model_path} does not exists.")
if not model_path.is_file():
raise FileExistsError(f"{model_path} is not a file.")


class TorchInferError(Exception):
pass
2 changes: 1 addition & 1 deletion python/rapidocr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def __init__(self, config_path: Optional[str] = None, **kwargs):
updater = UpdateParameters()
config = updater(config, **kwargs)

det_lang, rec_lang = parse_lang(config.Global.lang)
# 根据选定的语言加载对应的模型
det_lang, rec_lang = parse_lang(config.Global.lang)

self.print_verbose = config.Global.print_verbose
self.text_score = config.Global.text_score
Expand Down
1 change: 0 additions & 1 deletion python/rapidocr/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from ..inference_engine.onnxruntime import OrtInferSession
from .load_image import LoadImage, LoadImageError
from .logger import get_logger
from .parse_parameters import UpdateParameters, init_args, parse_lang, update_model_path
Expand Down

0 comments on commit 844b2e6

Please sign in to comment.