Skip to content

Commit

Permalink
chore: update files
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Feb 18, 2025
1 parent e044c3f commit f7c0ab1
Show file tree
Hide file tree
Showing 18 changed files with 188 additions and 134 deletions.
9 changes: 4 additions & 5 deletions python/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# from rapidocr_onnxruntime import RapidOCR, VisRes
# from rapidocr_torch import RapidOCR, VisRes


# from rapidocr_paddle import RapidOCR, VisRes
# from rapidocr_openvino import RapidOCR, VisRes

Expand All @@ -18,7 +17,8 @@
# engine = engine = RapidOCR(
# params={"Global.with_onnx": True, "EngineConfig.onnxruntime.use_cuda": True}
# )
engine = RapidOCR(params={"Global.with_torch": True, "Global.lang": "ch"})
# engine = RapidOCR(params={"Global.with_torch": True, "Global.lang": "ch"})
engine = RapidOCR(params={"Global.with_openvino": True, "Global.lang": "ch"})
vis = VisRes()

image_path = "tests/test_files/ch_en_num.jpg"
Expand All @@ -34,11 +34,10 @@
txts = result.txts
scores = result.scores

font_path = "resources/fonts/FZYTK.TTF"
vis_img = vis(img, result.boxes, result.txts, result.scores, font_path)
vis_img = vis(img, result.boxes, result.txts, result.scores)
cv2.imwrite("vis.png", vis_img)

words_results = result.word_results
words, words_scores, words_boxes = list(zip(*words_results))
vis_img = vis(img, words_boxes, words, words_scores, font_path)
vis_img = vis(img, words_boxes, words, words_scores)
cv2.imwrite("vis_single.png", vis_img)
4 changes: 3 additions & 1 deletion python/rapidocr/ch_ppocr_rec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from rapidocr.inference_engine.base import get_engine

from ..utils import Logger, download_file
from .utils import CTCLabelDecode, TextRecInput, TextRecOutput

DEFAULT_DICT_PATH = Path(__file__).parent.parent / "models" / "ppocr_keys_v1.txt"
Expand All @@ -30,6 +31,7 @@
class TextRecognizer:
def __init__(self, config: Dict[str, Any]):
self.session = get_engine(config.engine_name)(config, mode="rec")
self.logger = Logger(logger_name=__name__).get_log()

character = None
if self.session.have_key():
Expand All @@ -38,7 +40,7 @@ def __init__(self, config: Dict[str, Any]):
dict_path = config.get("rec_keys_path", None)
character_dict_path = dict_path if dict_path else DEFAULT_DICT_PATH
if not Path(character_dict_path).exists():
self.session.download_file(DEFAULT_DICT_URL, character_dict_path)
download_file(DEFAULT_DICT_URL, character_dict_path, self.logger)

self.postprocess_op = CTCLabelDecode(
character=character, character_path=character_dict_path
Expand Down
2 changes: 2 additions & 0 deletions python/rapidocr/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ Global:
with_paddle: false
with_torch: false

font_path: null

EngineConfig:
onnxruntime:
intra_op_num_threads: -1
Expand Down
14 changes: 0 additions & 14 deletions python/rapidocr/default_models.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
onnxruntime:
PP-OCRv4:
det:
default: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/det/ch_PP-OCRv4_det_infer.onnx
ch_PP-OCRv4_det_infer.onnx: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/det/ch_PP-OCRv4_det_infer.onnx
en_PP-OCRv3_det_infer.onnx: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/det/en_PP-OCRv3_det_infer.onnx
Multilingual_PP-OCRv3_det_infer.onnx: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/det/Multilingual_PP-OCRv3_det_infer.onnx
rec:
default: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer.onnx
arabic_PP-OCRv4_rec_infer.onnx: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/rec/arabic_PP-OCRv4_rec_infer.onnx
ch_PP-OCRv4_rec_infer.onnx: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer.onnx
chinese_cht_PP-OCRv3_rec_infer.onnx: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/rec/chinese_cht_PP-OCRv3_rec_infer.onnx
Expand All @@ -25,14 +23,10 @@ onnxruntime:
openvino:
PP-OCRv4:
det:
default: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/det/ch_PP-OCRv4_det_infer.onnx
ch_PP-OCRv4_det_infer.onnx: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/det/ch_PP-OCRv4_det_infer.onnx
en_PP-OCRv3_det_infer.onnx: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/det/en_PP-OCRv3_det_infer.onnx
Multilingual_PP-OCRv3_det_infer.onnx: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/det/Multilingual_PP-OCRv3_det_infer.onnx
rec:
default:
model_dir: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer.onnx
dict_url: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/paddle/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer/ppocr_keys_v1.txt
arabic_PP-OCRv4_rec_infer:
model_dir: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/onnx/PP-OCRv4/rec/arabic_PP-OCRv4_rec_infer.onnx
dict_url: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/paddle/PP-OCRv4/rec/arabic_PP-OCRv4_rec_infer/arabic_dict.txt
Expand Down Expand Up @@ -75,14 +69,10 @@ openvino:
paddlepaddle:
PP-OCRv4:
det:
default: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/paddle/PP-OCRv4/det/ch_PP-OCRv4_det_infer
ch_PP-OCRv4_det_infer: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/paddle/PP-OCRv4/det/ch_PP-OCRv4_det_infer
en_PP-OCRv3_det_infer: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/paddle/PP-OCRv4/det/en_PP-OCRv3_det_infer
Multilingual_PP-OCRv3_det_infer: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/paddle/PP-OCRv4/det/Multilingual_PP-OCRv3_det_infer
rec:
default:
model_dir: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/paddle/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer
dict_url: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/paddle/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer/ppocr_keys_v1.txt
arabic_PP-OCRv4_rec_infer:
model_dir: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/paddle/PP-OCRv4/rec/arabic_PP-OCRv4_rec_infer
dict_url: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/paddle/PP-OCRv4/rec/arabic_PP-OCRv4_rec_infer/arabic_dict.txt
Expand Down Expand Up @@ -125,13 +115,9 @@ paddlepaddle:
torch:
PP-OCRv4:
det:
default: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/torch/ch_PP-OCRv4_det_infer.pth
ch_PP-OCRv4_det_infer.pth: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/torch/ch_PP-OCRv4_det_infer.pth
ch_PP-OCRv4_det_server_infer.pth: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/torch/ch_PP-OCRv4_det_server_infer.pth
rec:
default:
model_dir: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/torch/ch_PP-OCRv4_rec_infer.pth
dict_url: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/paddle/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer/ppocr_keys_v1.txt
ch_PP-OCRv4_rec_infer.pth:
model_dir: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/torch/ch_PP-OCRv4_rec_infer.pth
dict_url: https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/master/paddle/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer/ppocr_keys_v1.txt
Expand Down
31 changes: 0 additions & 31 deletions python/rapidocr/inference_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from typing import Union

import numpy as np
import requests
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

from ..utils.logger import Logger

Expand Down Expand Up @@ -103,31 +101,6 @@ def _verify_model(model_path: Union[str, Path, None]):
def have_key(self, key: str = "character") -> bool:
pass

@classmethod
def download_file(cls, url: str, save_path: Union[str, Path]):
if not Path(save_path).parent.exists():
Path(save_path).parent.mkdir(parents=True, exist_ok=True)

if Path(save_path).exists():
cls.logger.info("Model already exists in %s", save_path)
return

cls.logger.info("Downloading model from %s to %s", url, save_path)

response = requests.get(url, stream=True, timeout=60)
status_code = response.status_code

if status_code != 200:
raise DownloadFileError("Something went wrong while downloading models")

total_size_in_bytes = int(response.headers.get("content-length", 1))
block_size = 1024 # 1 Kibibyte
with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as pb:
with open(save_path, "wb") as file:
for data in response.iter_content(block_size):
pb.update(len(data))
file.write(data)

@classmethod
def get_model_url(cls, engine_name: str, task_type: str, lang: str) -> str:
model_dict = cls.model_info[engine_name]["PP-OCRv4"][task_type]
Expand All @@ -136,7 +109,3 @@ def get_model_url(cls, engine_name: str, task_type: str, lang: str) -> str:
if lang == prefix:
return model_dict[k]
raise KeyError("Model not found")


class DownloadFileError(Exception):
pass
3 changes: 2 additions & 1 deletion python/rapidocr/inference_engine/onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_device,
)

from ..utils import download_file
from ..utils.logger import Logger
from .base import InferSession

Expand All @@ -38,7 +39,7 @@ def __init__(self, config: Dict[str, Any], mode: Optional[str] = None):
config.engine_name, config.task_type, config.lang
)
model_path = self.DEFAULT_MODE_PATH / Path(default_model_url).name
self.download_file(default_model_url, model_path)
download_file(default_model_url, model_path, self.logger)

self._verify_model(model_path)

Expand Down
4 changes: 3 additions & 1 deletion python/rapidocr/inference_engine/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from omegaconf import DictConfig
from openvino.runtime import Core

from ..utils import Logger, download_file
from .base import InferSession


class OpenVINOInferSession(InferSession):
def __init__(self, config: DictConfig, mode: Optional[str] = None):
super().__init__(config)
self.mode = mode
self.logger = Logger(logger_name=__name__).get_log()

core = Core()

Expand All @@ -29,7 +31,7 @@ def __init__(self, config: DictConfig, mode: Optional[str] = None):
default_model_url = default_model_url.model_dir

model_path = self.DEFAULT_MODE_PATH / Path(default_model_url).name
self.download_file(default_model_url, model_path)
download_file(default_model_url, model_path, self.logger)

self._verify_model(model_path)
model_onnx = core.read_model(model_path)
Expand Down
5 changes: 3 additions & 2 deletions python/rapidocr/inference_engine/paddlepaddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import paddle
from paddle import inference

from ..utils import download_file
from ..utils.logger import Logger
from .base import InferSession

Expand All @@ -35,13 +36,13 @@ def __init__(self, config, mode: Optional[str] = None) -> None:
pdmodel_path = (
self.DEFAULT_MODE_PATH / Path(default_model_url).name / PDMODEL_NAME
)
self.download_file(pd_model_url, pdmodel_path)
download_file(pd_model_url, pdmodel_path, self.logger)

pdiparams_url = f"{default_model_url}/{PDIPARAMS_NAME}"
pdiparams_path = (
self.DEFAULT_MODE_PATH / Path(default_model_url).name / PDIPARAMS_NAME
)
self.download_file(pdiparams_url, pdiparams_path)
download_file(pdiparams_url, pdiparams_path, self.logger)
else:
pdmodel_path = model_dir / "inference.pdmodel"
pdiparams_path = model_dir / "inference.pdiparams"
Expand Down
22 changes: 16 additions & 6 deletions python/rapidocr/inference_engine/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ..networks.architectures.base_model import BaseModel
from ..utils.logger import Logger
from ..utils.utils import download_file
from .base import InferSession

root_dir = Path(__file__).resolve().parent.parent
Expand All @@ -19,13 +20,22 @@
class TorchInferSession(InferSession):
def __init__(self, config, mode: Optional[str] = None) -> None:
self.logger = Logger(logger_name=__name__).get_log()
self.mode = mode

all_arch_config = OmegaConf.load(DEFAULT_CFG_PATH)
model_path = config.get("model_path", None)
if model_path is None:
default_model_url = self.get_model_url(
config.engine_name, config.task_type, config.lang
)
if self.mode == "rec":
default_model_url = default_model_url["model_dir"]

model_path = self.DEFAULT_MODE_PATH / Path(default_model_url).name
download_file(default_model_url, model_path, self.logger)

self.mode = mode
model_path = Path(config["model_path"])
self._verify_model(model_path)

all_arch_config = OmegaConf.load(DEFAULT_CFG_PATH)
file_name = model_path.stem
if file_name not in all_arch_config:
raise ValueError(f"architecture {file_name} is not in arch_config.yaml")
Expand All @@ -36,7 +46,7 @@ def __init__(self, config, mode: Optional[str] = None) -> None:
self.predictor.eval()

self.use_gpu = False
if config["use_cuda"]:
if config.engine_cfg.use_cuda:
self.predictor.cuda()
self.use_gpu = True

Expand All @@ -47,8 +57,8 @@ def __call__(self, img: np.ndarray):
inp = inp.cuda()

# 适配跟onnx对齐取值逻辑
outputs = self.predictor(inp).unsqueeze(0)
return outputs.cpu().numpy()
outputs = self.predictor(inp).cpu().numpy()
return outputs

def have_key(self, key: str = "character") -> bool:
return False
Expand Down
3 changes: 0 additions & 3 deletions python/rapidocr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,6 @@ def main():
)
logger.info(result)

if args.print_cost:
logger.info(elapse_list)

if args.vis_res:
vis = VisRes()
Path(args.vis_save_path).mkdir(parents=True, exist_ok=True)
Expand Down
10 changes: 5 additions & 5 deletions python/rapidocr/networks/heads/rec_multi_head.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from torch import nn

from ..necks.rnn import Im2Seq, SequenceEncoder
from .rec_ctc_head import CTCHead


class FCTranspose(nn.Module):
Expand Down Expand Up @@ -43,15 +44,14 @@ def __init__(self, in_channels, out_channels_list, **kwargs):
head_args = self.head_list[idx][name].get("Head", {})
if head_args is None:
head_args = {}
self.ctc_head = eval(name)(

self.ctc_head = CTCHead(
in_channels=self.ctc_encoder.out_channels,
out_channels=out_channels_list["CTCLabelDecode"],
**head_args
**head_args,
)
else:
raise NotImplementedError(
"{} is not supported in MultiHead yet".format(name)
)
raise NotImplementedError(f"{name} is not supported in MultiHead yet")

def forward(self, x, data=None):
ctc_encoder = self.ctc_encoder(x)
Expand Down
1 change: 1 addition & 0 deletions python/rapidocr/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
from .parse_parameters import ParseParams, init_args, parse_lang
from .process_img import add_round_letterbox, increase_min_side, reduce_max_side
from .typings import RapidOCROutput
from .utils import download_file
from .vis_res import VisRes
Loading

0 comments on commit f7c0ab1

Please sign in to comment.