-
-
Notifications
You must be signed in to change notification settings - Fork 404
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
100 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
# -*- encoding: utf-8 -*- | ||
# @Author: SWHL | ||
# @Contact: [email protected] | ||
from .onnxruntime import OrtInferSession |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# -*- encoding: utf-8 -*- | ||
# @Author: SWHL | ||
# @Contact: [email protected] | ||
import abc | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
import numpy as np | ||
|
||
|
||
class InferSession(abc.ABC): | ||
@abc.abstractmethod | ||
def __init__(self, config): | ||
pass | ||
|
||
@abc.abstractmethod | ||
def __call__(self, input_content: np.ndarray) -> np.ndarray: | ||
pass | ||
|
||
@staticmethod | ||
def _verify_model(model_path: Union[str, Path, None]): | ||
if model_path is None: | ||
raise ValueError("model_path is None!") | ||
|
||
model_path = Path(model_path) | ||
if not model_path.exists(): | ||
raise FileNotFoundError(f"{model_path} does not exists.") | ||
|
||
if not model_path.is_file(): | ||
raise FileExistsError(f"{model_path} is not a file.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,65 @@ | ||
# -*- encoding: utf-8 -*- | ||
# @Author: SWHL | ||
# @Contact: [email protected] | ||
from pathlib import Path | ||
from typing import Dict, Optional, Union | ||
|
||
import numpy as np | ||
import torch | ||
import yaml | ||
|
||
root_dir = Path(__file__).resolve().parent.parent | ||
DEFAULT_CFG_PATH = root_dir / "arch_config.yaml" | ||
|
||
|
||
def read_yaml(yaml_path: Union[str, Path]) -> Dict[str, Dict]: | ||
with open(yaml_path, "rb") as f: | ||
data = yaml.load(f, Loader=yaml.Loader) | ||
return data | ||
|
||
|
||
from rapidocr_torch.modeling.architectures.base_model import BaseModel | ||
|
||
from ..utils.logger import get_logger | ||
|
||
|
||
class TorchInferSession: | ||
def __init__(self, config, mode: Optional[str] = None) -> None: | ||
all_arch_config = read_yaml(DEFAULT_CFG_PATH) | ||
|
||
self.logger = get_logger("TorchInferSession") | ||
self.mode = mode | ||
model_path = Path(config["model_path"]) | ||
self._verify_model(model_path) | ||
file_name = model_path.stem | ||
if file_name not in all_arch_config: | ||
raise ValueError(f"architecture {file_name} is not in config.yaml") | ||
arch_config = all_arch_config[file_name] | ||
self.predictor = BaseModel(arch_config) | ||
self.predictor.load_state_dict(torch.load(model_path, weights_only=True)) | ||
self.predictor.eval() | ||
self.use_gpu = False | ||
if config["use_cuda"]: | ||
self.predictor.cuda() | ||
self.use_gpu = True | ||
|
||
def __call__(self, img: np.ndarray): | ||
with torch.no_grad(): | ||
inp = torch.from_numpy(img) | ||
if self.use_gpu: | ||
inp = inp.cuda() | ||
# 适配跟onnx对齐取值逻辑 | ||
outputs = self.predictor(inp).unsqueeze(0) | ||
return outputs.cpu().numpy() | ||
|
||
@staticmethod | ||
def _verify_model(model_path): | ||
model_path = Path(model_path) | ||
if not model_path.exists(): | ||
raise FileNotFoundError(f"{model_path} does not exists.") | ||
if not model_path.is_file(): | ||
raise FileExistsError(f"{model_path} is not a file.") | ||
|
||
|
||
class TorchInferError(Exception): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,6 @@ | ||
# -*- encoding: utf-8 -*- | ||
# @Author: SWHL | ||
# @Contact: [email protected] | ||
from ..inference_engine.onnxruntime import OrtInferSession | ||
from .load_image import LoadImage, LoadImageError | ||
from .logger import get_logger | ||
from .parse_parameters import UpdateParameters, init_args, parse_lang, update_model_path | ||
|