Skip to content

Commit aa0fd63

Browse files
authored
Merge pull request #73 from RapidAI/develop
chore: optimize code
2 parents 9e3816e + aad22d9 commit aa0fd63

12 files changed

Lines changed: 370 additions & 250 deletions

File tree

demo.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from rapid_table import RapidTable, RapidTableInput, VisTable
99

10-
if __name__ == '__main__':
10+
if __name__ == "__main__":
1111
# Init
1212
ocr_engine = RapidOCR()
1313
vis_ocr = VisRes()
@@ -16,32 +16,45 @@
1616
table_engine = RapidTable(input_args)
1717
viser = VisTable()
1818

19-
img_path = "tests/test_files/table.jpg"
19+
img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
2020

2121
# OCR
22-
23-
rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
22+
rapid_ocr_output = ocr_engine(img_path)
2423
ocr_result = list(
2524
zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
2625
)
2726
table_results = table_engine(img_path, ocr_result)
27+
2828
# Use single-character recognition
2929
# word_results = rapid_ocr_output.word_results
3030
# ocr_result = [
3131
# [word_result[2], word_result[0], word_result[1]] for word_result in word_results
3232
# ]
3333
# table_results = table_engine(img_path, ocr_result)
3434

35-
table_html_str, table_cell_bboxes = table_results.pred_html, table_results.cell_bboxes
35+
table_html_str, table_cell_bboxes = (
36+
table_results.pred_html,
37+
table_results.cell_bboxes,
38+
)
3639
# Save
3740
save_dir = Path("outputs")
3841
save_dir.mkdir(parents=True, exist_ok=True)
3942

4043
save_html_path = save_dir / f"{Path(img_path).stem}.html"
41-
save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
42-
save_logic_points_path = save_dir / f"{Path(img_path).stem}_table_col_row_vis{Path(img_path).suffix}"
44+
save_drawed_path = (
45+
save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
46+
)
47+
save_logic_points_path = (
48+
save_dir / f"{Path(img_path).stem}_table_col_row_vis{Path(img_path).suffix}"
49+
)
4350

4451
# Visualize table rec result
45-
vis_imged = viser(img_path, table_results, save_html_path, save_drawed_path, save_logic_points_path)
52+
vis_imged = viser(
53+
img_path,
54+
table_results,
55+
save_html_path,
56+
save_drawed_path,
57+
save_logic_points_path,
58+
)
4659

4760
print(f"The results has been saved {save_dir}")

rapid_table/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# @Author: SWHL
33
# @Contact: liekkaskono@163.com
44
from .main import RapidTable, RapidTableInput
5-
from .utils.utils import VisTable
5+
from .utils import VisTable

rapid_table/main.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,12 @@
1313
import cv2
1414
import numpy as np
1515

16-
from rapid_table.utils.download_model import DownloadModel
17-
from rapid_table.utils.logger import get_logger
18-
from rapid_table.utils.utils import LoadImage, VisTable
16+
from rapid_table.utils import DownloadModel, LoadImage, Logger, VisTable
1917

2018
from .table_matcher import TableMatch
2119
from .table_structure import TableStructurer, TableStructureUnitable
2220

23-
logger = get_logger("main")
21+
logger = Logger(logger_name=__name__).get_log()
2422
root_dir = Path(__file__).resolve().parent
2523

2624

@@ -78,7 +76,7 @@ def __init__(self, config: RapidTableInput):
7876
self.table_matcher = TableMatch()
7977

8078
try:
81-
self.ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
79+
self.ocr_engine = importlib.import_module("rapidocr").RapidOCR()
8280
except ModuleNotFoundError:
8381
self.ocr_engine = None
8482

@@ -91,7 +89,7 @@ def __call__(
9189
) -> RapidTableOutput:
9290
if self.ocr_engine is None and ocr_result is None:
9391
raise ValueError(
94-
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
92+
"One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
9593
)
9694

9795
img = self.load_img(img_content)
@@ -100,7 +98,14 @@ def __call__(
10098
h, w = img.shape[:2]
10199

102100
if ocr_result is None:
103-
ocr_result, _ = self.ocr_engine(img)
101+
ocr_result = self.ocr_engine(img)
102+
ocr_result = list(
103+
zip(
104+
ocr_result.boxes,
105+
ocr_result.txts,
106+
ocr_result.scores,
107+
)
108+
)
104109
dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
105110

106111
pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img))
@@ -197,18 +202,21 @@ def main(arg_list: Optional[List[str]] = None):
197202
args = parse_args(arg_list)
198203

199204
try:
200-
ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
205+
ocr_engine = importlib.import_module("rapidocr").RapidOCR()
201206
except ModuleNotFoundError as exc:
202207
raise ModuleNotFoundError(
203-
"Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
208+
"Please install the rapidocr by pip install rapidocr"
204209
) from exc
205210

206211
input_args = RapidTableInput(model_type=args.model_type)
207212
table_engine = RapidTable(input_args)
208213

209214
img = cv2.imread(args.img_path)
210215

211-
ocr_result, _ = ocr_engine(img)
216+
rapid_ocr_output = ocr_engine(img)
217+
ocr_result = list(
218+
zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
219+
)
212220
table_results = table_engine(img, ocr_result)
213221
print(table_results.pred_html)
214222

rapid_table/table_structure/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
get_device,
3232
)
3333

34-
from rapid_table.utils.logger import get_logger
34+
from rapid_table.utils import Logger
3535

3636

3737
class EP(Enum):
@@ -42,7 +42,7 @@ class EP(Enum):
4242

4343
class OrtInferSession:
4444
def __init__(self, config: Dict[str, Any]):
45-
self.logger = get_logger("OrtInferSession")
45+
self.logger = Logger(logger_name=__name__).get_log()
4646

4747
model_path = config.get("model_path", None)
4848
self._verify_model(model_path)

rapid_table/utils/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
# -*- encoding: utf-8 -*-
22
# @Author: SWHL
33
# @Contact: liekkaskono@163.com
4+
from .download_model import DownloadModel
5+
from .load_image import LoadImage
6+
from .logger import Logger
7+
from .utils import is_url
8+
from .vis import VisTable

rapid_table/utils/download_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
import requests
66
from tqdm import tqdm
77

8-
from .logger import get_logger
9-
10-
logger = get_logger("DownloadModel")
8+
from .logger import Logger
119

1210
PROJECT_DIR = Path(__file__).resolve().parent.parent
1311
DEFAULT_MODEL_DIR = PROJECT_DIR / "models"
1412

1513

1614
class DownloadModel:
15+
logger = Logger(logger_name=__name__).get_log()
16+
1717
@classmethod
1818
def download(
1919
cls,
@@ -31,11 +31,11 @@ def download(
3131

3232
save_file_path = save_dir / save_model_name
3333
if save_file_path.exists():
34-
logger.debug("%s already exists", save_file_path)
34+
cls.logger.info("%s already exists", save_file_path)
3535
return str(save_file_path)
3636

3737
try:
38-
logger.info("Download %s to %s", model_full_url, save_dir)
38+
cls.logger.info("Download %s to %s", model_full_url, save_dir)
3939
file = cls.download_as_bytes_with_progress(model_full_url, save_model_name)
4040
cls.save_file(save_file_path, file)
4141
except Exception as exc:

rapid_table/utils/load_image.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# -*- encoding: utf-8 -*-
2+
# @Author: SWHL
3+
# @Contact: liekkaskono@163.com
4+
from io import BytesIO
5+
from pathlib import Path
6+
from typing import Any, Union
7+
8+
import cv2
9+
import numpy as np
10+
import requests
11+
from PIL import Image, UnidentifiedImageError
12+
13+
from .utils import is_url
14+
15+
root_dir = Path(__file__).resolve().parent
16+
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
17+
18+
19+
class LoadImage:
    """Load an image from several input kinds and return a 3-channel array.

    Accepted inputs are the members of ``InputType``: a local path or URL
    (``str`` / ``Path``), encoded image ``bytes``, a ``numpy.ndarray`` or a
    ``PIL.Image.Image``. Calling the instance loads the image and then
    normalises its channel layout via :meth:`convert_img`.
    """

    def __init__(self):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        """Load ``img`` and convert it to a 3-channel array.

        Args:
            img: Any value whose type is listed in ``InputType``.

        Returns:
            The image as produced by :meth:`convert_img`.

        Raises:
            LoadImageError: If the type of ``img`` is not supported, the file
                does not exist, the image cannot be identified, or the array
                has an unsupported number of dimensions/channels.
        """
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} does not in {InputType.__args__}"
            )

        # Remember the caller's type: convert_img uses it to decide whether a
        # 3-channel array still needs an RGB -> BGR swap.
        origin_img_type = type(img)
        img = self.load_img(img)
        img = self.convert_img(img, origin_img_type)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        """Turn any supported input into a ``numpy.ndarray`` without colour conversion.

        Args:
            img: Path/URL, encoded bytes, ndarray or PIL image.

        Returns:
            The decoded pixel data; ndarray inputs are returned as-is.

        Raises:
            LoadImageError: If a local path does not exist, the path/URL
                content is not an image PIL can identify, or the type is
                unsupported.
            requests.RequestException: If downloading a URL fails or the
                server answers with an HTTP error status.
        """
        if isinstance(img, (str, Path)):
            if is_url(img):
                response = requests.get(img, timeout=60)
                # Fail on 4xx/5xx instead of trying to decode an error page.
                response.raise_for_status()
                source = BytesIO(response.content)
            else:
                self.verify_exist(img)
                source = img

            # Image.open is the call that raises UnidentifiedImageError, so it
            # must live inside the try. ``img`` still holds the original
            # path/URL here, which keeps the error message meaningful.
            try:
                with Image.open(source) as pil_img:
                    return self.img_to_ndarray(pil_img)
            except UnidentifiedImageError as e:
                raise LoadImageError(f"cannot identify image file {img}") from e

        if isinstance(img, bytes):
            return self.img_to_ndarray(Image.open(BytesIO(img)))

        if isinstance(img, np.ndarray):
            return img

        if isinstance(img, Image.Image):
            return self.img_to_ndarray(img)

        raise LoadImageError(f"{type(img)} is not supported!")

    def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
        """Convert a PIL image to an array, widening mode ``"1"`` images to ``"L"`` first."""
        if img.mode == "1":
            img = img.convert("L")
        return np.array(img)

    def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
        """Normalise ``img`` to three channels.

        Args:
            img: 2-D array, or 3-D array with 1, 2, 3 or 4 channels.
            origin_img_type: Type of the value originally passed by the caller.

        Returns:
            A 3-channel array. 3-channel inputs that originated from
            ``str``/``Path``/``bytes``/PIL images are swapped RGB -> BGR;
            3-channel ndarray inputs are returned unchanged.

        Raises:
            LoadImageError: If ``img`` is not 2-D/3-D or has an unsupported
                channel count.
        """
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim != 3:
            raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

        channel = img.shape[2]
        if channel == 1:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if channel == 2:
            return self.cvt_two_to_three(img)

        if channel == 3:
            # Inputs decoded through PIL are converted RGB -> BGR; arrays
            # supplied directly by the caller are left untouched.
            if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
                return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            return img

        if channel == 4:
            return self.cvt_four_to_three(img)

        raise LoadImageError(
            f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
        )

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """gray + alpha → BGR"""
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

        img_alpha = img[..., 1]
        # Inverted alpha expanded to 3 channels: bright where alpha is low.
        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        # Keep the grey content where the alpha mask is set, then add the
        # inverted alpha on top.
        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA → BGR"""
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(new_img, new_img, mask=a)

        # If nothing but zeros survives the alpha mask, add the inverted alpha;
        # otherwise the masked image is inverted.
        # NOTE(review): the inversion branch is kept exactly as in the original
        # - confirm it is the intended treatment for coloured RGBA input.
        mean_color = np.mean(new_img)
        if mean_color <= 0.0:
            new_img = cv2.add(new_img, not_a)
        else:
            new_img = cv2.bitwise_not(new_img)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise ``LoadImageError`` if ``file_path`` does not exist."""
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")
128+
129+
130+
class LoadImageError(Exception):
    """Raised by LoadImage when an input cannot be loaded or has an unsupported layout."""

rapid_table/utils/logger.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,36 @@
22
# @Author: Jocker1212
33
# @Contact: xinyijianggo@gmail.com
44
import logging
5-
from functools import lru_cache
65

6+
import colorlog
77

8-
@lru_cache(maxsize=32)
9-
def get_logger(name: str) -> logging.Logger:
10-
logger = logging.getLogger(name)
11-
logger.setLevel(logging.DEBUG)
128

13-
fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s"
14-
format_str = logging.Formatter(fmt)
9+
class Logger:
    """Thin wrapper that configures a named ``logging.Logger`` with coloured console output.

    The underlying logger is obtained through ``logging.getLogger``, so
    creating several ``Logger`` objects with the same name configures the
    same logger. A console handler is attached only when the logger has
    no handlers yet, so log lines are not duplicated.
    """

    def __init__(self, log_level=logging.DEBUG, logger_name=None):
        """Configure the logger.

        Args:
            log_level: Level applied to the logger and to a newly created
                console handler.
            logger_name: Name passed to ``logging.getLogger``.
        """
        self.logger = logging.getLogger(logger_name)
        self.logger.setLevel(log_level)
        self.logger.propagate = False

        # Already configured (e.g. by an earlier Logger with the same name):
        # reuse the existing handlers. The original code attempted to remove
        # handlers inside the "no handlers" branch, which could never run.
        if self.logger.handlers:
            return

        formatter = colorlog.ColoredFormatter(
            "%(log_color)s[%(levelname)s] %(asctime)s [RapidTable] %(filename)s:%(lineno)d: %(message)s",
            log_colors={
                "DEBUG": "cyan",
                "INFO": "green",
                "WARNING": "yellow",
                "ERROR": "red",
                "CRITICAL": "red,bg_white",
            },
        )

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        console_handler.setLevel(log_level)
        self.logger.addHandler(console_handler)

    def get_log(self):
        """Return the configured ``logging.Logger``."""
        return self.logger

0 commit comments

Comments
 (0)