@@ -64,27 +64,71 @@ def _init_table_structer(self):
6464
6565 return PPTableStructurer (asdict (self .cfg ))
6666
67+ def _batch_process (
68+ self ,
69+ img_contents : List [Union [str , np .ndarray , bytes , Path ]],
70+ ocr_results : Optional [List ] = None ,
71+ batch_size : int = 4 ,
72+ ) -> List [RapidTableOutput ]:
73+ """批量处理图像"""
74+ s = time .perf_counter ()
75+
76+ images = []
77+ for img_content in img_contents :
78+ img = self .load_img (img_content )
79+ images .append (img )
80+
81+ batch_dt_boxes = []
82+ batch_rec_res = []
83+
84+ for i , img in enumerate (images ):
85+ dt_boxes , rec_res = get_boxes_recs (ocr_results [i ], img .shape [:2 ])
86+ batch_dt_boxes .append (dt_boxes )
87+ batch_rec_res .append (rec_res )
88+
89+ # 批量表格结构识别
90+ batch_results = self .table_structure (images , batch_size )
91+
92+ output_results = []
93+ for i , (img , (pred_structures , cell_bboxes , _ )) in enumerate (zip (images , batch_results )):
94+ logic_points = self .table_matcher .decode_logic_points (pred_structures )
95+ pred_html = self .get_table_matcher (
96+ pred_structures , cell_bboxes , batch_dt_boxes [i ], batch_rec_res [i ]
97+ )
98+ result = RapidTableOutput (img , pred_html , cell_bboxes , logic_points , 0 )
99+ output_results .append (result )
100+
101+ total_elapse = time .perf_counter () - s
102+ for result in output_results :
103+ result .elapse = total_elapse / len (output_results )
104+
105+ return output_results
106+
67107 def __call__ (
68- self ,
69- img_content : Union [str , np .ndarray , bytes , Path ],
70- ocr_results : Optional [Tuple [np .ndarray , Tuple [str ], Tuple [float ]]] = None ,
108+ self ,
109+ img_content : Union [str , np .ndarray , bytes , Path ],
110+ ocr_results : Optional [Tuple [np .ndarray , Tuple [str ], Tuple [float ]]] = None ,
111+ batch_size : int = 1 ,
71112 ) -> RapidTableOutput :
72- s = time .perf_counter ()
113+ if batch_size > 1 :
114+ return self ._batch_process (img_content , ocr_results )
115+ else :
116+ s = time .perf_counter ()
73117
74- img = self .load_img (img_content )
118+ img = self .load_img (img_content )
75119
76- dt_boxes , rec_res = self .get_ocr_results (img , ocr_results )
77- pred_structures , cell_bboxes , logic_points = self .get_table_rec_results (img )
120+ dt_boxes , rec_res = self .get_ocr_results (img , ocr_results )
121+ pred_structures , cell_bboxes , logic_points = self .get_table_rec_results (img )
78122
79- pred_html = self .get_table_matcher (
80- pred_structures , cell_bboxes , dt_boxes , rec_res
81- )
123+ pred_html = self .get_table_matcher (
124+ pred_structures , cell_bboxes , dt_boxes , rec_res
125+ )
82126
83- elapse = time .perf_counter () - s
84- return RapidTableOutput (img , pred_html , cell_bboxes , logic_points , elapse )
127+ elapse = time .perf_counter () - s
128+ return RapidTableOutput (img , pred_html , cell_bboxes , logic_points , elapse )
85129
86130 def get_ocr_results (
87- self , img : np .ndarray , ocr_results : Tuple [np .ndarray , Tuple [str ], Tuple [float ]]
131+ self , img : np .ndarray , ocr_results : Tuple [np .ndarray , Tuple [str ], Tuple [float ]]
88132 ) -> Tuple [Optional [np .ndarray ], Optional [np .ndarray ]]:
89133 if ocr_results is not None :
90134 return get_boxes_recs (ocr_results , img .shape [:2 ])
0 commit comments