Skip to content

Commit 6da3974

Browse files
authored
Merge pull request #119 from deeperrrr/batch_process
feat: 增加slanet-plus表格结构模型批量推理功能
2 parents 35ca6d7 + 6a57308 commit 6da3974

4 files changed

Lines changed: 211 additions & 17 deletions

File tree

batch_demo.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# -*- encoding: utf-8 -*-
2+
# @Author: deeperrrr
3+
# @Contact: 3545615231@qq.com
4+
import cv2
5+
import numpy as np
6+
from pathlib import Path
7+
from typing import List
8+
from tqdm import tqdm
9+
10+
from rapidocr import EngineType, RapidOCR
11+
from rapid_table import ModelType, RapidTable, RapidTableInput
12+
13+
# --- Demo configuration -------------------------------------------------
img_dir_path = "/data/images"  # folder that holds the input images
batch_size = 4  # number of images handed to the table engine per call
ocr_results = []  # filled later with one [boxes, txts, scores] entry per image

# OCR engine: detection, classification and recognition all use the torch backend.
ocr_engine = RapidOCR(
    params={
        "Det.engine_type": EngineType.TORCH,
        "Cls.engine_type": EngineType.TORCH,
        "Rec.engine_type": EngineType.TORCH,
    }
)

# Table-structure engine. To try the other model, swap in:
# input_args = RapidTableInput(model_type=ModelType.UNITABLE)
input_args = RapidTableInput(model_type=ModelType.SLANETPLUS)
table_engine = RapidTable(input_args)
27+
28+
29+
def load_images_original_size(img_dir: str) -> List[np.ndarray]:
    """Load every JPEG/PNG image found directly inside ``img_dir``.

    Images are decoded with ``cv2.imread`` and kept at their original size.

    Args:
        img_dir: Directory to scan (non-recursively) for ``*.jpg``,
            ``*.jpeg`` and ``*.png`` files.

    Returns:
        The decoded images, ordered by file path. Files that cannot be
        decoded are skipped with a warning.

    Raises:
        FileNotFoundError: If ``img_dir`` does not exist.
    """
    import warnings

    img_dir = Path(img_dir)
    if not img_dir.exists():
        raise FileNotFoundError(f"目录不存在: {img_dir}")

    # Sort so the output order (and therefore downstream result indices) is
    # reproducible; glob order is otherwise filesystem dependent.
    image_paths = sorted(
        path for ext in ("*.jpg", "*.jpeg", "*.png") for path in img_dir.glob(ext)
    )

    images = []
    for img_path in tqdm(image_paths, desc="加载图像"):
        img = cv2.imread(str(img_path))
        if img is None:
            # cv2.imread reports an unreadable file by returning None rather
            # than raising; keeping it would break the OCR step later on.
            warnings.warn(f"Skipping unreadable image: {img_path}")
            continue
        images.append(img)
    return images
43+
44+
45+
def dynamic_batch_process(table_engine: RapidTable, images: List[np.ndarray], ocr_results: List[List],
                          batch_size: int = 1):
    """Run table recognition over ``images`` in chunks of ``batch_size``.

    Args:
        table_engine: Engine called as ``table_engine(imgs, ocrs, batch_size)``.
        images: Input images.
        ocr_results: One OCR result per image, in the same order as ``images``.
        batch_size: Number of images sent to the engine per call (>= 1).

    Returns:
        A flat list with one engine result per input image, in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or the two input
            lists differ in length.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(images) != len(ocr_results):
        raise ValueError(
            f"images and ocr_results must have the same length, "
            f"got {len(images)} and {len(ocr_results)}"
        )

    all_results = []
    for i in tqdm(range(0, len(images), batch_size), desc=f"表格批量推理, batch_size={batch_size}"):
        batch_imgs = images[i:i + batch_size]
        batch_ocrs = ocr_results[i:i + batch_size]
        if batch_size == 1:
            # With batch_size == 1 the engine uses its single-image path: it
            # expects one image (not a list) and returns one result object.
            all_results.append(table_engine(batch_imgs[0], batch_ocrs[0], batch_size))
        else:
            all_results.extend(table_engine(batch_imgs, batch_ocrs, batch_size))
    return all_results
54+
55+
56+
images = load_images_original_size(img_dir_path)

# Run OCR once per image and keep only the fields the table engine consumes.
for page in tqdm(images, desc="OCR处理"):
    ocr_out = ocr_engine(page)
    ocr_results.append([ocr_out.boxes, ocr_out.txts, ocr_out.scores])

# Batched table-structure recognition (batch_size is the module-level value set above).
results = dynamic_batch_process(table_engine, images, ocr_results, batch_size)

# Save one visualisation per recognised table.
for idx, table_result in enumerate(results):
    table_result.vis(save_dir="outputs", save_name=f"vis_{idx}")

rapid_table/main.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,27 +64,71 @@ def _init_table_structer(self):
6464

6565
return PPTableStructurer(asdict(self.cfg))
6666

67+
def _batch_process(
    self,
    img_contents: List[Union[str, np.ndarray, bytes, Path]],
    ocr_results: Optional[List] = None,
    batch_size: int = 4,
) -> List[RapidTableOutput]:
    """Recognise several table images, running the structure model in batches.

    Args:
        img_contents: Images in any form accepted by ``self.load_img``.
        ocr_results: Optional list with one OCR result per image. When it is
            ``None`` (or an entry is ``None``) the image is handed to
            ``self.get_ocr_results`` without OCR data, exactly as the
            single-image path does.
        batch_size: Maximum number of images passed to the structure model
            per call.

    Returns:
        One ``RapidTableOutput`` per input image, in input order. ``elapse``
        on each output is the total wall time divided by the number of
        outputs.

    Raises:
        ValueError: If ``ocr_results`` is given but shorter than
            ``img_contents``.
    """
    s = time.perf_counter()

    images = [self.load_img(img_content) for img_content in img_contents]
    if not images:
        return []

    if ocr_results is None:
        ocr_results = [None] * len(images)
    elif len(ocr_results) < len(images):
        raise ValueError(
            f"ocr_results has {len(ocr_results)} entries for {len(images)} images"
        )

    batch_dt_boxes = []
    batch_rec_res = []
    for img, ocr_result in zip(images, ocr_results):
        # Same helper as the single-image path, so a missing OCR result is
        # handled the same way instead of failing on indexing.
        dt_boxes, rec_res = self.get_ocr_results(img, ocr_result)
        batch_dt_boxes.append(dt_boxes)
        batch_rec_res.append(rec_res)

    # Batched table-structure recognition: feed at most `batch_size` images
    # per model call so the parameter actually bounds the batch.
    step = max(int(batch_size), 1)
    batch_results = []
    for start in range(0, len(images), step):
        batch_results.extend(
            self.table_structure(images[start:start + step], batch_size)
        )

    output_results = []
    for img, (pred_structures, cell_bboxes, _), dt_boxes, rec_res in zip(
        images, batch_results, batch_dt_boxes, batch_rec_res
    ):
        logic_points = self.table_matcher.decode_logic_points(pred_structures)
        pred_html = self.get_table_matcher(
            pred_structures, cell_bboxes, dt_boxes, rec_res
        )
        output_results.append(
            RapidTableOutput(img, pred_html, cell_bboxes, logic_points, 0)
        )

    if output_results:
        per_image_elapse = (time.perf_counter() - s) / len(output_results)
        for result in output_results:
            result.elapse = per_image_elapse

    return output_results
106+
67107
def __call__(
    self,
    img_content: Union[str, np.ndarray, bytes, Path],
    ocr_results: Optional[Tuple[np.ndarray, Tuple[str], Tuple[float]]] = None,
    batch_size: int = 1,
) -> Union[RapidTableOutput, List[RapidTableOutput]]:
    """Recognise one table image, or a list of them when ``batch_size > 1``.

    Args:
        img_content: A single image for the default path. When
            ``batch_size > 1`` this is forwarded to ``_batch_process``,
            which expects a list of images.
        ocr_results: OCR result for the image (or a list of them in batch
            mode); ``None`` lets ``get_ocr_results`` decide what to do.
        batch_size: ``1`` selects the single-image path; any larger value
            selects batch mode and is forwarded as the batch size.

    Returns:
        A single ``RapidTableOutput`` when ``batch_size <= 1``, otherwise
        the list returned by ``_batch_process``.
    """
    if batch_size > 1:
        # Forward the caller's batch size; omitting it would silently fall
        # back to the default of _batch_process.
        return self._batch_process(img_content, ocr_results, batch_size)

    s = time.perf_counter()

    img = self.load_img(img_content)

    dt_boxes, rec_res = self.get_ocr_results(img, ocr_results)
    pred_structures, cell_bboxes, logic_points = self.get_table_rec_results(img)

    pred_html = self.get_table_matcher(
        pred_structures, cell_bboxes, dt_boxes, rec_res
    )

    elapse = time.perf_counter() - s
    return RapidTableOutput(img, pred_html, cell_bboxes, logic_points, elapse)
85129

86130
def get_ocr_results(
87-
self, img: np.ndarray, ocr_results: Tuple[np.ndarray, Tuple[str], Tuple[float]]
131+
self, img: np.ndarray, ocr_results: Tuple[np.ndarray, Tuple[str], Tuple[float]]
88132
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
89133
if ocr_results is not None:
90134
return get_boxes_recs(ocr_results, img.shape[:2])

rapid_table/table_structure/pp_structure/main.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ...inference_engine.base import get_engine
2222
from ..utils import get_struct_str
2323
from .post_process import TableLabelDecode
24-
from .pre_process import TablePreprocess
24+
from .pre_process import TablePreprocess, BatchTablePreprocess
2525

2626

2727
class PPTableStructurer:
@@ -32,11 +32,15 @@ def __init__(self, cfg: Dict[str, Any]):
3232
self.cfg = cfg
3333

3434
self.preprocess_op = TablePreprocess()
35+
self.batch_preprocess_op = BatchTablePreprocess()
3536

3637
self.character = self.session.get_character_list()
3738
self.postprocess_op = TableLabelDecode(self.character)
3839

39-
def __call__(self, ori_img: np.ndarray) -> Tuple[List[str], np.ndarray, float]:
40+
def __call__(self, ori_img: np.ndarray, batch_size: int = 4) -> Tuple[List[str], np.ndarray, float]:
41+
if batch_size > 1:
42+
return self.batch_process(ori_img)
43+
4044
s = time.perf_counter()
4145

4246
img, shape_list = self.preprocess_op(ori_img)
@@ -55,8 +59,52 @@ def __call__(self, ori_img: np.ndarray) -> Tuple[List[str], np.ndarray, float]:
5559
elapse = time.perf_counter() - s
5660
return table_struct_str, cell_bboxes, elapse
5761

62+
def batch_process(self, img_list: List[np.ndarray]) -> List[Tuple[List[str], np.ndarray, float]]:
    """Run table-structure inference on a list of images with one session call.

    Args:
        img_list: Images to process. Entries must not be ``None``.

    Returns:
        One ``(table_struct_str, cell_bboxes, elapse)`` tuple per input
        image, in input order. ``elapse`` is the total wall time divided by
        the number of images. An empty input yields an empty list.

    Raises:
        ValueError: If ``img_list`` contains ``None``. The batch
            preprocessor would skip such entries, after which output ``i``
            would no longer correspond to ``img_list[i]``.
    """
    if len(img_list) == 0:
        return []

    missing = [idx for idx, img in enumerate(img_list) if img is None]
    if missing:
        raise ValueError(f"img_list contains None at positions {missing}")

    starttime = time.perf_counter()

    preprocessed_images, shape_lists = self.batch_preprocess_op(img_list)
    preprocessed_images = np.array(preprocessed_images)

    bbox_preds, struct_probs = self.session(preprocessed_images)

    batch_size = preprocessed_images.shape[0]
    needs_rescale = self.cfg["model_type"] == ModelType.SLANETPLUS

    decoded = []
    for i in range(batch_size):
        # Post-process one sample at a time, keeping the leading batch axis
        # (slice i:i+1) that the decoder is called with.
        single_bbox_preds = bbox_preds[i:i + 1]
        single_struct_probs = struct_probs[i:i + 1]
        single_shape_list = np.array([shape_lists[i]])

        post_result = self.postprocess_op(single_bbox_preds, single_struct_probs, [single_shape_list])

        table_struct_str = get_struct_str(post_result["structure_batch_list"][0][0])
        cell_bboxes = post_result["bbox_batch_list"][0]

        if needs_rescale:
            cell_bboxes = self.rescale_cell_bboxes(img_list[i], cell_bboxes)

        cell_bboxes = self.filter_blank_bbox(cell_bboxes)
        decoded.append((table_struct_str, cell_bboxes))

    per_image_elapse = (time.perf_counter() - starttime) / batch_size
    return [
        (table_struct_str, cell_bboxes, per_image_elapse)
        for table_struct_str, cell_bboxes in decoded
    ]
105+
58106
def rescale_cell_bboxes(
59-
self, img: np.ndarray, cell_bboxes: np.ndarray
107+
self, img: np.ndarray, cell_bboxes: np.ndarray
60108
) -> np.ndarray:
61109
h, w = img.shape[:2]
62110
resized = 488

rapid_table/table_structure/pp_structure/pre_process.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def normalize(self, img: np.ndarray) -> np.ndarray:
3737
return (img.astype("float32") * self.scale - self.mean) / self.std
3838

3939
def pad_img(
40-
self, img: np.ndarray, shape: List[float]
40+
self, img: np.ndarray, shape: List[float]
4141
) -> Tuple[np.ndarray, List[float]]:
4242
padding_img = np.zeros((self.max_len, self.max_len, 3), dtype=np.float32)
4343
h, w = img.shape[:2]
@@ -47,3 +47,39 @@ def pad_img(
4747

4848
def to_chw(self, img: np.ndarray) -> np.ndarray:
    """Return ``img`` with its last axis moved to the front (axes 2, 0, 1)."""
    return np.transpose(img, (2, 0, 1))
50+
51+
52+
class BatchTablePreprocess:
    """Apply the single-image table preprocessing steps to a list of images."""

    def __init__(self):
        self.preprocess = TablePreprocess()

    def __call__(self, img_list: List[np.ndarray]) -> Tuple[List[np.ndarray], List[List[float]]]:
        """Preprocess every image in ``img_list``.

        Each image goes through resize, normalize, pad and channel-first
        conversion, in that order.

        Args:
            img_list: Images to preprocess. Entries must not be ``None``.

        Returns:
            Two lists of equal length: the preprocessed images and their
            shape information. Entry ``i`` of both lists belongs to
            ``img_list[i]``. Empty input yields two empty lists.

        Raises:
            ValueError: If an entry is ``None``. Dropping it silently would
                shift every later output relative to its input image.
        """
        if not img_list:
            # Match the declared return type rather than returning None.
            return [], []

        processed_imgs = []
        shape_lists = []

        for idx, img in enumerate(img_list):
            if img is None:
                raise ValueError(f"img_list[{idx}] is None")

            img_processed, shape_list = self.preprocess.resize_image(img)
            img_processed = self.preprocess.normalize(img_processed)
            img_processed, shape_list = self.preprocess.pad_img(img_processed, shape_list)
            img_processed = self.preprocess.to_chw(img_processed)

            processed_imgs.append(img_processed)
            shape_lists.append(shape_list)

        return processed_imgs, shape_lists

0 commit comments

Comments
 (0)