wess09/module/ocr/al_ocr.py

235 lines
7.6 KiB
Python
Raw Normal View History

2020-07-14 23:59:48 +08:00
import os
import numpy as np
2026-03-22 19:22:53 +08:00
import cv2
from PIL import Image
from module.exception import RequestHumanTakeover
from module.logger import logger
2026-03-23 16:30:23 +08:00
from module.config.config import AzurLaneConfig
# Import the recognizer library. If that fails, print setup guidance (missing
# MSVC runtime, GPU acceleration problems, where to get support) and hand
# control back to a human, since the bot cannot repair this by itself.
try:
    from rapidocr import RapidOCR, OCRVersion
except Exception as exc:
    for _message in (
        f"Failed to load OCR dependencies: {exc}",
        # "Cannot load OCR dependencies, please install the Microsoft C++ runtime <url>"
        "无法加载 OCR 依赖,请安装微软 C++ 运行库 https://aka.ms/vs/17/release/vc_redist.x64.exe",
        # "It may also be caused by unsupported GPU acceleration, try turning it off"
        "也有可能是 GPU 不支持加速引起,请尝试关闭 GPU 加速",
        # "If none of the above works, join the group for support"
        "如果上述方法都无法解决,请加群获取支持",
    ):
        logger.critical(_message)
    raise RequestHumanTakeover

# Name of the active config, taken from the environment.
# NOTE(review): when ALAS_CONFIG_NAME is unset, None is passed to
# AzurLaneConfig -- confirm that case is handled there.
config_name = os.environ.get("ALAS_CONFIG_NAME")
config = AzurLaneConfig(config_name)
# True when the configured OCR device is "gpu"; the model holders read it to
# set the onnxruntime `use_dml` engine flag.
USE_GPU = (config.ocr_device == "gpu")
2026-03-23 16:30:23 +08:00
class _RecognizerHolder:
    """Own one recognition-only RapidOCR pipeline and the params used to build it.

    Detection and classification stages are disabled, so only the recognition
    model given by ``model_path``/``keys_path`` runs. Public attributes:
    ``params`` (the dict passed to RapidOCR) and ``model`` (the RapidOCR object).
    """

    def __init__(self, ocr_version, model_path, keys_path):
        self.params = {
            "Global.use_det": False,
            "Global.use_cls": False,
            "Det.model_path": None,
            "Cls.model_path": None,
            "Rec.ocr_version": ocr_version,
            "Rec.model_path": model_path,
            "Rec.rec_keys_path": keys_path,
            # The module-level USE_GPU is read at construction time, so holders
            # rebuilt after USE_GPU changes get the new setting.
            "EngineConfig.onnxruntime.use_dml": USE_GPU,
        }
        self.model = RapidOCR(params=self.params)


class CnModel(_RecognizerHolder):
    """Simplified-Chinese (zh-CN) recognizer, PP-OCRv5 architecture."""

    def __init__(self):
        super().__init__(
            OCRVersion.PPOCRV5,
            "bin/ocr_models/zh-CN/alocr-zh-cn-v3.dtk.onnx",
            "bin/ocr_models/zh-CN/cn.txt",
        )


class EnModel(_RecognizerHolder):
    """English (en-US) recognizer, PP-OCRv4 architecture."""

    def __init__(self):
        super().__init__(
            OCRVersion.PPOCRV4,
            "bin/ocr_models/en-US/alocr-en-us-v2.6.nvc.onnx",
            "bin/ocr_models/en-US/en.txt",
        )


class JpModel(_RecognizerHolder):
    """Recognizer for the JP model files, PP-OCRv5 architecture."""

    def __init__(self):
        super().__init__(
            OCRVersion.PPOCRV5,
            "bin/ocr_models/JP/JP.onnx",
            "bin/ocr_models/JP/ppocrv5_dict.txt",
        )


class TwModel(_RecognizerHolder):
    """Recognizer for the TW model files, PP-OCRv5 architecture."""

    def __init__(self):
        super().__init__(
            OCRVersion.PPOCRV5,
            "bin/ocr_models/TW/TW.onnx",
            "bin/ocr_models/TW/ppocrv5_dict.txt",
        )
2026-03-23 16:30:23 +08:00
def _build_models():
    """Construct a fresh ``(cn, en, jp, tw)`` tuple of recognizer holders, in that order."""
    return CnModel(), EnModel(), JpModel(), TwModel()


cn_model, en_model, jp_model, tw_model = _build_models()


def reset_ocr_model():
    """Rebuild every recognizer using the current ``config.ocr_device`` setting.

    The new holders are published only after all four were built. If any
    constructor raises, the previous models and the previous USE_GPU value
    remain in place and the exception propagates to the caller.
    """
    global cn_model, en_model, jp_model, tw_model, USE_GPU
    previous_use_gpu = USE_GPU
    # The holders read the module-level USE_GPU while constructing, so it has
    # to be updated before they are built.
    # NOTE(review): `config` is the object created at import; whether
    # ocr_device reflects later edits depends on AzurLaneConfig -- confirm.
    USE_GPU = config.ocr_device == 'gpu'
    logger.info(f"Resetting OCR models, USE_GPU={USE_GPU}")
    try:
        models = _build_models()
    except Exception:
        # Do not leave a half-switched state (some globals new, some old).
        USE_GPU = previous_use_gpu
        raise
    cn_model, en_model, jp_model, tw_model = models
2026-03-22 19:22:53 +08:00
class AlOcr:
    """Recognition-only OCR front end backed by the module-level RapidOCR holders.

    The ``name`` keyword picks the recognizer: "cn"/"zhcn" -> cn_model,
    "jp" -> jp_model, "tw" -> tw_model, anything else (default "en") -> en_model.
    Every recognized image is also written to DEBUG_FOLDER as a best-effort
    debugging aid.
    """

    # Folder that receives a snapshot of every recognized image.
    DEBUG_FOLDER = "ocr_debug"
    # Only the newest DEBUG_MAX_FILES snapshots are kept.
    DEBUG_MAX_FILES = 100
    # Cap on the OCR-text part of a snapshot filename, so long results cannot
    # push the name past filesystem length limits.
    DEBUG_NAME_MAX_LEN = 50

    def __init__(self, **kwargs):
        self.model = None
        self.name = kwargs.get("name", "en")
        # Not read anywhere in this class; kept for callers that may inspect it.
        self.params = {}
        self._model_loaded = False
        logger.info(
            f"Created AlOcr instance: name='{self.name}', kwargs={kwargs}, PID={os.getpid()}"
        )

    def init(self):
        """Point ``self.model`` at the current global recognizer for ``self.name``."""
        # The module globals are read at call time (not at construction), so
        # models rebuilt by reset_ocr_model() are picked up.
        if self.name in ("cn", "zhcn"):
            holder = cn_model
        elif self.name == "jp":
            holder = jp_model
        elif self.name == "tw":
            holder = tw_model
        else:
            holder = en_model
        self.model = holder.model
        self._model_loaded = True

    def _ensure_loaded(self):
        """Refresh ``self.model`` from the current globals before recognizing.

        The lookup is only a few name accesses, so it runs on every call rather
        than once. Otherwise an instance used before reset_ocr_model() would keep
        recognizing with the stale, pre-reset model.
        """
        self.init()

    @classmethod
    def _debug_filename_text(cls, result):
        """Turn an OCR result into a filesystem-safe filename fragment.

        Line breaks become spaces. Only characters for which ``str.isalnum``
        holds (this includes non-ASCII letters), plus " ", "_" and "-", are kept.
        The result is truncated to DEBUG_NAME_MAX_LEN and stripped. "empty" is
        returned when nothing usable remains.
        """
        text = str(result).replace("\n", " ").replace("\r", " ").strip()
        text = "".join(c for c in text if c.isalnum() or c in (" ", "_", "-"))
        text = text[: cls.DEBUG_NAME_MAX_LEN].strip()
        return text or "empty"

    def _prune_debug_folder(self):
        """Delete the oldest files in DEBUG_FOLDER so at most DEBUG_MAX_FILES remain."""
        folder = self.DEBUG_FOLDER
        files = [
            path
            for path in (os.path.join(folder, f) for f in os.listdir(folder))
            if os.path.isfile(path)
        ]
        excess = len(files) - self.DEBUG_MAX_FILES
        if excess <= 0:
            return
        files.sort(key=os.path.getmtime)
        for path in files[:excess]:
            try:
                os.remove(path)
            except OSError:
                # Already gone or locked; pruning is best-effort.
                pass

    def _save_debug_image(self, img, result):
        """Save a best-effort snapshot of ``img`` into DEBUG_FOLDER, named after ``result``.

        Accepted inputs:
          - a numpy array, encoded to PNG with OpenCV;
          - a PIL image, saved directly;
          - a path to an existing file, which is copied.
        Other inputs are ignored. This method never raises: failures are logged
        as warnings so that a debugging aid cannot break recognition.
        """
        import shutil
        import time

        try:
            # Folder creation is inside the try so that a permission error is
            # only logged, like every other failure in this method.
            os.makedirs(self.DEBUG_FOLDER, exist_ok=True)
            # Millisecond timestamp for uniqueness and chronological sorting.
            now = int(time.time() * 1000)
            filename = f"{self.name}_{self._debug_filename_text(result)}_{now}.png"
            filepath = os.path.join(self.DEBUG_FOLDER, filename)
            if isinstance(img, np.ndarray):
                # On Windows, cv2.imwrite cannot open non-ASCII paths: it returns
                # False and writes nothing. The filename may contain CJK OCR
                # text, so encode in memory and write with Python file I/O.
                # NOTE(review): OpenCV assumes BGR channel order; if callers pass
                # RGB arrays, the snapshot colors are swapped -- confirm.
                ok, encoded = cv2.imencode(".png", img)
                if ok:
                    with open(filepath, "wb") as f:
                        f.write(encoded.tobytes())
            elif isinstance(img, Image.Image):
                img.save(filepath)
            elif isinstance(img, str) and os.path.exists(img):
                shutil.copy(img, filepath)
            self._prune_debug_folder()
        except Exception as e:
            # We don't want to crash the main process due to debug saving failure
            logger.warning(f"Failed to save OCR debug image: {e}")

    def _recognize(self, img):
        """Run the loaded recognizer on one image.

        Returns the first recognized text, or "" when there is none. The image
        is also saved as a debug snapshot.
        """
        res = self.model(img)
        txts = getattr(res, "txts", None)
        txt = txts[0] if txts else ""
        self._save_debug_image(img, txt)
        return txt

    def ocr(self, img_fp):
        """Recognize the text in one image and return it ("" if nothing was found).

        Exceptions from the recognizer are logged and re-raised.
        """
        logger.debug("[VERBOSE] AlOcr.ocr: Ensure loaded...")
        self._ensure_loaded()
        try:
            return self._recognize(img_fp)
        except Exception as e:
            logger.error(f"AlOcr.ocr exception: {e}")
            raise

    def ocr_for_single_line(self, img_fp):
        """Same as ocr()."""
        return self.ocr(img_fp)

    def ocr_for_single_lines(self, img_list):
        """Recognize each image in ``img_list`` and return the texts in input order.

        The first failure is logged together with its index and re-raised.
        """
        self._ensure_loaded()
        results = []
        for i, img in enumerate(img_list):
            try:
                results.append(self._recognize(img))
            except Exception as e:
                logger.error(f"AlOcr.ocr_for_single_lines exception on image {i}: {e}")
                raise
        return results

    def set_cand_alphabet(self, cand_alphabet):
        """Accept a candidate alphabet and ignore it.

        This is a no-op here; filtering happens in the atomic_* methods instead.
        """
        pass

    @staticmethod
    def _filter_alphabet(text, cand_alphabet):
        """Keep only the characters of ``text`` found in ``cand_alphabet``.

        When ``cand_alphabet`` is falsy, ``text`` is returned unchanged.
        """
        if not cand_alphabet:
            return text
        return "".join(c for c in text if c in cand_alphabet)

    def atomic_ocr(self, img_fp, cand_alphabet=None):
        """Run ocr(), then drop characters not in ``cand_alphabet`` (if given)."""
        return self._filter_alphabet(self.ocr(img_fp), cand_alphabet)

    def atomic_ocr_for_single_line(self, img_fp, cand_alphabet=None):
        """Run ocr_for_single_line(), then drop characters not in ``cand_alphabet`` (if given)."""
        return self._filter_alphabet(self.ocr_for_single_line(img_fp), cand_alphabet)

    def atomic_ocr_for_single_lines(self, img_list, cand_alphabet=None):
        """Run ocr_for_single_lines(), then filter every result by ``cand_alphabet`` (if given)."""
        results = self.ocr_for_single_lines(img_list)
        if cand_alphabet:
            results = [self._filter_alphabet(res, cand_alphabet) for res in results]
        return results