-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathocr.py
44 lines (36 loc) · 1.55 KB
/
ocr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from models import TesseractOCR, KrakenOCR, EasyOCR
from utils import clean_text
import swifter
def predict(image, ocr_model_name, config):
    """Run a single OCR model on one image and return raw and cleaned text.

    Args:
        image: Input passed unchanged to the selected model's ``predict``
            method. No loading or preprocessing is performed here.
        ocr_model_name: Name of the model to use; one of ``"TesseractOCR"``,
            ``"Kraken"`` or ``"EasyOCR"``.
        config: Configuration object forwarded to the model's constructor.

    Returns:
        dict: ``{"prediction": <raw model output>,
        "cleaned_prediction": <clean_text(raw model output)>}``.

    Raises:
        ValueError: If ``ocr_model_name`` is not one of the known models.
    """
    # Map names to classes (not instances) so only the requested model is
    # constructed. Building every model up front meant a failure in any one
    # constructor broke all of them, and did work for models never used.
    ocr_model_classes = {
        "TesseractOCR": TesseractOCR,
        "Kraken": KrakenOCR,
        "EasyOCR": EasyOCR,
    }
    try:
        ocr_model_class = ocr_model_classes[ocr_model_name]
    except KeyError:
        raise ValueError(
            f"OCR model '{ocr_model_name}' not available.\n"
            f"Available models are {list(ocr_model_classes)}"
        ) from None

    ocr_model = ocr_model_class(config)
    ocr_model.initialize()

    prediction = ocr_model.predict(image)
    cleaned_prediction = clean_text(prediction)
    return {"prediction": prediction, "cleaned_prediction": cleaned_prediction}
def batch_predict(images_df, ocr_model_name, config):
    """Run one OCR model over every image in a DataFrame.

    Args:
        images_df: DataFrame with a ``"preprocessed_image"`` column; each
            cell is passed to the selected model's ``predict`` method. The
            input frame is not modified (a copy is returned).
        ocr_model_name: Name of the model to use; one of ``"TesseractOCR"``,
            ``"Kraken"`` or ``"EasyOCR"``.
        config: Configuration object forwarded to the model's constructor.

    Returns:
        A copy of ``images_df`` with two added columns:
        ``prediction_<ocr_model_name>`` holding the raw model output and
        ``cleaned_prediction_<ocr_model_name>`` holding ``clean_text`` of it.

    Raises:
        ValueError: If ``ocr_model_name`` is not one of the known models.
        KeyError: If ``images_df`` has no ``"preprocessed_image"`` column.
    """
    # Map names to classes (not instances) so only the requested model is
    # constructed, instead of building every model on each call.
    ocr_model_classes = {
        "TesseractOCR": TesseractOCR,
        "Kraken": KrakenOCR,
        "EasyOCR": EasyOCR,
    }
    try:
        ocr_model_class = ocr_model_classes[ocr_model_name]
    except KeyError:
        raise ValueError(
            f"OCR model '{ocr_model_name}' not available.\n"
            f"Available models are {list(ocr_model_classes)}"
        ) from None

    # Validate the input before the (possibly expensive) model setup.
    if "preprocessed_image" not in images_df.columns:
        raise KeyError("images_df must contain a 'preprocessed_image' column")

    ocr_model = ocr_model_class(config)
    ocr_model.initialize()

    prediction_col = f"prediction_{ocr_model_name}"
    cleaned_col = f"cleaned_prediction_{ocr_model_name}"

    results_df = images_df.copy()
    # `.swifter` is the accessor registered by the module-level
    # `import swifter`. NOTE(review): swifter may choose a parallel backend,
    # which could require the model to be picklable; confirm for each model.
    results_df[prediction_col] = results_df["preprocessed_image"].swifter.apply(
        ocr_model.predict
    )
    results_df[cleaned_col] = results_df[prediction_col].apply(clean_text)
    return results_df