Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 55 additions & 9 deletions deepdoc/vision/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def create_operators(op_param_list, global_config=None):


def load_model(model_dir, nm, device_id: int | None = None):
"""Load and cache the ONNX OCR model ``<model_dir>/<nm>.onnx``, picking the CUDA or CPU execution provider."""
model_file_path = os.path.join(model_dir, nm + ".onnx")
model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path

Expand All @@ -83,15 +84,34 @@ def load_model(model_dir, nm, device_id: int | None = None):
model_file_path))

def cuda_is_available():
    """Return True only when onnxruntime-gpu can actually load its CUDA EP on this host.

    Two checks are performed, in order:

    1. ``torch`` must report CUDA as available and expose the requested
       device index (``device_id`` from the enclosing scope, or 0 when it is
       ``None``). Any failure while installing/importing/querying torch is
       treated as "no CUDA".
    2. On Linux only, the CUDA 12 / cuDNN 9 shared libraries that the
       onnxruntime-gpu CUDA provider loads must be resolvable by the dynamic
       loader.

    Returns:
        bool: ``True`` when the CUDA execution provider should be requested,
        ``False`` when the caller should use the CPU provider instead.
    """
    try:
        pip_install_torch()
        import torch
        target_id = 0 if device_id is None else device_id
        if not (torch.cuda.is_available() and torch.cuda.device_count() > target_id):
            return False
    except Exception:
        # Best-effort probe: a missing or broken torch install means "no GPU".
        return False

    # The library names probed below are Linux SONAMEs. On other platforms
    # (e.g. Windows, where the CUDA DLLs have different file names)
    # ctypes.CDLL would always raise OSError and wrongly force the CPU
    # provider, so the torch check above stays the only gate there.
    import sys
    if not sys.platform.startswith("linux"):
        return True

    # onnxruntime-gpu 1.23.x is built against CUDA 12 + cuDNN 9 and dlopens
    # libcublasLt.so.12 / libcudnn.so.9 at provider-load time. When the host
    # only ships CUDA 13 user-mode libs (via nvidia-container-toolkit on a
    # CUDA-13 host), the CUDA EP fails to register and ORT logs noisy errors
    # while silently falling back to CPU. Probe for the cu12 SONAMEs up-front
    # so we request CPU explicitly and skip the misleading warnings.
    # See https://github.com/infiniflow/ragflow/issues/15687
    import ctypes
    for soname in ("libcublasLt.so.12", "libcudnn.so.9"):
        try:
            ctypes.CDLL(soname)
        except OSError:
            logging.warning(
                f"{soname} not found; onnxruntime-gpu CUDA EP requires CUDA 12 + cuDNN 9. "
                "Falling back to CPUExecutionProvider for OCR. "
                "If you intended GPU inference, install matching libs or use a CUDA 12 host."
            )
            return False
    return True

options = ort.SessionOptions()
options.enable_cpu_mem_arena = False
Expand Down Expand Up @@ -137,7 +157,10 @@ def cuda_is_available():


class TextRecognizer:
"""Recognise the text inside each detected bounding box using the recognition ONNX model."""

def __init__(self, model_dir, device_id: int | None = None):
"""Load the recognition model from ``model_dir`` (optional explicit GPU device id)."""
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
self.rec_batch_num = 16
postprocess_params = {
Expand All @@ -150,6 +173,7 @@ def __init__(self, model_dir, device_id: int | None = None):
self.input_tensor = self.predictor.get_inputs()[0]

def resize_norm_img(self, img, max_wh_ratio):
"""Resize + normalise an input crop to the recognition model's expected shape."""
imgC, imgH, imgW = self.rec_image_shape

assert imgC == img.shape[2]
Expand All @@ -176,7 +200,7 @@ def resize_norm_img(self, img, max_wh_ratio):
return padding_im

def resize_norm_img_vl(self, img, image_shape):

"""Resize + normalise for the visual-language recognizer variant."""
imgC, imgH, imgW = image_shape
img = img[:, :, ::-1] # bgr2rgb
resized_image = cv2.resize(
Expand All @@ -186,6 +210,7 @@ def resize_norm_img_vl(self, img, image_shape):
return resized_image

def resize_norm_img_srn(self, img, image_shape):
"""Resize + normalise for the SRN recognizer variant."""
imgC, imgH, imgW = image_shape

img_black = np.zeros((imgH, imgW))
Expand All @@ -212,7 +237,7 @@ def resize_norm_img_srn(self, img, image_shape):
return np.reshape(img_black, (c, row, col)).astype(np.float32)

def srn_other_inputs(self, image_shape, num_heads, max_text_length):

"""Build the auxiliary SRN inputs (word positions and self-attention bias tensors)."""
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))

Expand Down Expand Up @@ -243,6 +268,7 @@ def srn_other_inputs(self, image_shape, num_heads, max_text_length):
]

def process_image_srn(self, img, image_shape, num_heads, max_text_length):
"""Prepare the image and auxiliary tensors expected by the SRN recognizer."""
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]

Expand All @@ -259,6 +285,7 @@ def process_image_srn(self, img, image_shape, num_heads, max_text_length):

def resize_norm_img_sar(self, img, image_shape,
width_downsample_ratio=0.25):
"""Resize + normalise for the SAR recognizer variant."""
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
Expand Down Expand Up @@ -293,6 +320,7 @@ def resize_norm_img_sar(self, img, image_shape,
return padding_im, resize_shape, pad_shape, valid_ratio

def resize_norm_img_spin(self, img):
"""Resize + normalise for the SPIN recognizer variant."""
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
Expand All @@ -310,7 +338,7 @@ def resize_norm_img_spin(self, img):
return img

def resize_norm_img_svtr(self, img, image_shape):

"""Resize + normalise for the SVTR recognizer variant."""
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
Expand All @@ -321,7 +349,7 @@ def resize_norm_img_svtr(self, img, image_shape):
return resized_image

def resize_norm_img_abinet(self, img, image_shape):

"""Resize + normalise for the ABINet recognizer variant."""
imgC, imgH, imgW = image_shape

resized_image = cv2.resize(
Expand All @@ -339,7 +367,7 @@ def resize_norm_img_abinet(self, img, image_shape):
return resized_image

def norm_img_can(self, img, image_shape):

"""Resize + normalise for the CAN recognizer variant (greyscale)."""
img = cv2.cvtColor(
img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image

Expand All @@ -360,13 +388,14 @@ def norm_img_can(self, img, image_shape):
return img

def close(self):
    """Drop this recognizer's ``predictor`` attribute, if set, and run a GC pass.

    Safe to call repeatedly: when ``predictor`` is already gone only the log
    line and the garbage-collection pass happen. Returns ``None``.
    """
    logging.info('Close text recognizer.')
    # presumably ``predictor`` is the inference session created in __init__ -- confirm
    session_present = hasattr(self, "predictor")
    if session_present:
        delattr(self, "predictor")
    # Force a collection right away instead of waiting for the next automatic pass.
    gc.collect()

def __call__(self, img_list):
"""Run text recognition on a batch of cropped images and return ``(rec_res, elapsed_seconds)``."""
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
Expand Down Expand Up @@ -414,11 +443,15 @@ def __call__(self, img_list):
return rec_res, time.time() - st

def __del__(self):
    """Finalizer hook: delegate to :meth:`close` when the instance is being destroyed."""
    # NOTE(review): Python may run finalizers late, including during
    # interpreter shutdown, so treat this as a safety net rather than the
    # primary way to release the recognizer.
    self.close()


class TextDetector:
"""Detect text bounding boxes in an image using the detection ONNX model."""

def __init__(self, model_dir, device_id: int | None = None):
"""Load the detection model from ``model_dir`` (optional explicit GPU device id)."""
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': 960,
Expand Down Expand Up @@ -457,6 +490,7 @@ def __init__(self, model_dir, device_id: int | None = None):
self.preprocess_op = create_operators(pre_process_list)

def order_points_clockwise(self, pts):
"""Reorder polygon corner points clockwise: top-left, top-right, bottom-right, bottom-left."""
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
Expand All @@ -468,12 +502,14 @@ def order_points_clockwise(self, pts):
return rect

def clip_det_res(self, points, img_height, img_width):
    """Clamp every point of a polygon to the image bounds, in place.

    Args:
        points: indexable 2-D array of ``(x, y)`` rows exposing ``.shape``;
            it is modified in place.
        img_height: image height; ``y`` is limited to ``[0, img_height - 1]``.
        img_width: image width; ``x`` is limited to ``[0, img_width - 1]``.

    Returns:
        The same ``points`` object, with each coordinate clamped and then
        truncated through ``int()``.
    """
    last_col = img_width - 1
    last_row = img_height - 1
    for idx in range(points.shape[0]):
        x_val = points[idx, 0]
        y_val = points[idx, 1]
        points[idx, 0] = int(min(max(x_val, 0), last_col))
        points[idx, 1] = int(min(max(y_val, 0), last_row))
    return points

def filter_tag_det_res(self, dt_boxes, image_shape):
"""Drop too-small detections and order the remaining boxes clockwise."""
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
Expand All @@ -490,6 +526,7 @@ def filter_tag_det_res(self, dt_boxes, image_shape):
return dt_boxes

def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
"""Like :meth:`filter_tag_det_res` but only clip; never drop boxes for being small."""
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
Expand All @@ -501,12 +538,14 @@ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
return dt_boxes

def close(self):
    """Drop this detector's ``predictor`` attribute, if set, and run a GC pass.

    Safe to call repeatedly: when ``predictor`` is already gone only the log
    line and the garbage-collection pass happen. Returns ``None``.
    """
    logging.info("Close text detector.")
    # presumably ``predictor`` is the inference session created in __init__ -- confirm
    session_present = hasattr(self, "predictor")
    if session_present:
        delattr(self, "predictor")
    # Force a collection right away instead of waiting for the next automatic pass.
    gc.collect()

def __call__(self, img):
"""Run text detection on a single image and return ``(dt_boxes, elapsed_seconds)``."""
ori_im = img.copy()
data = {'image': img}

Expand Down Expand Up @@ -536,10 +575,13 @@ def __call__(self, img):
return dt_boxes, time.time() - st

def __del__(self):
    """Finalizer hook: delegate to :meth:`close` when the instance is being destroyed."""
    # NOTE(review): Python may run finalizers late, including during
    # interpreter shutdown, so treat this as a safety net rather than the
    # primary way to release the detector.
    self.close()


class OCR:
"""End-to-end OCR pipeline: detection → optional crop/rotate → recognition."""

def __init__(self, model_dir=None):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
Expand Down Expand Up @@ -667,6 +709,7 @@ def sorted_boxes(self, dt_boxes):
return _boxes

def detect(self, img, device_id: int | None = None):
"""Detect text boxes in ``img`` without running recognition; returns an iterator of ``(box, ("", 0))`` pairs."""
if device_id is None:
device_id = 0

Expand All @@ -682,6 +725,7 @@ def detect(self, img, device_id: int | None = None):
("", 0) for _ in range(len(dt_boxes))])

def recognize(self, ori_im, box, device_id: int | None = None):
"""Recognise the text inside ``box`` of ``ori_im``; returns the empty string if score < ``drop_score``."""
if device_id is None:
device_id = 0

Expand All @@ -694,6 +738,7 @@ def recognize(self, ori_im, box, device_id: int | None = None):
return text

def recognize_batch(self, img_list, device_id: int | None = None):
"""Recognise text for a list of pre-cropped images, dropping low-confidence results."""
if device_id is None:
device_id = 0
rec_res, elapse = self.text_recognizer[device_id](img_list)
Expand All @@ -706,6 +751,7 @@ def recognize_batch(self, img_list, device_id: int | None = None):
return texts

def __call__(self, img, device_id = 0, cls=True):
"""End-to-end OCR on ``img``: detect boxes, recognise each, return ``(filtered_boxes, [(text, score), ...], time_dict)``."""
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
if device_id is None:
device_id = 0
Expand Down