1+ # LeanEval/runner/local_runner.py
import concurrent.futures
import json
import os
from pathlib import Path
from time import time, strftime
from typing import Any, Dict, List, Optional, Tuple

import accelerate
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader
from tqdm import tqdm

from LeanEval.datasets import JsonDataset, LeanItem
from LeanEval.models import ModelRegistry, HuggingFaceModelConfig, HuggingFaceModel
from LeanEval.prompt import get_builder, PromptBuilder
from LeanEval.utils import extract_lean_block
from LeanEval.validator.proof_validator import ProofValidator
20+
class LocalHuggingFaceRunner:
    """Run local Hugging Face model inference for Lean proofs with multi-GPU
    support via ``accelerate``, then validate the generated proofs.

    Pipeline: build prompts from a JSON dataset -> batched generation on each
    process -> gather results on the main process -> write ``.lean`` files ->
    validate them with ``ProofValidator`` -> dump a JSON report.
    """

    def __init__(
        self,
        model_id: str,
        dataset_path: str,
        output_dir_base: str = "./outputs_local_runner",
        prompt_template: str = "Complete the Lean 4 proof below. Only output the Lean code for the complete proof\n```lean\n{code_block_statement} := by\n```",
        per_device_batch_size: int = 1,
        dataloader_num_workers: int = 2,
        max_new_tokens: int = 1024,
        temperature: float = 0.1,
        mixed_precision: str = 'fp16',  # 'no', 'fp16', 'bf16'
        validation_timeout: int = 60,
        hf_config_overrides: Optional[Dict[str, Any]] = None,  # overrides merged into the HF config
    ):
        """Store configuration, create the Accelerator, and set up output paths.

        Args:
            model_id: HF hub id or local path of the model.
            dataset_path: Path to the JSON dataset of Lean items.
            output_dir_base: Root directory for per-run output folders.
            prompt_template: Template with a ``{code_block_statement}`` slot.
            per_device_batch_size: Batch size per process/GPU.
            dataloader_num_workers: Workers for the DataLoader.
            max_new_tokens: Generation length limit.
            temperature: Sampling temperature.
            mixed_precision: Accelerate mixed-precision mode; forced to 'no'
                when CUDA is unavailable.
            validation_timeout: Per-proof Lean validation timeout (seconds).
            hf_config_overrides: Optional overrides for the model config;
                ``generation_kwargs`` inside it is merged key-by-key.
        """
        self.model_id = model_id
        self.dataset_path = dataset_path
        self.output_dir_base = output_dir_base
        self.prompt_template = prompt_template
        self.per_device_batch_size = per_device_batch_size
        self.dataloader_num_workers = dataloader_num_workers
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        # Mixed precision only makes sense on GPU; fall back to 'no' on CPU.
        self.mixed_precision = mixed_precision if torch.cuda.is_available() else 'no'
        self.validation_timeout = validation_timeout
        self.hf_config_overrides = hf_config_overrides or {}

        self.accelerator = Accelerator(mixed_precision=self.mixed_precision)
        self.device = self.accelerator.device

        # Per-run output directory: <base>/<model-short-name>_<timestamp>
        model_short_name = self.model_id.split('/')[-1]
        self.output_dir = Path(self.output_dir_base) / f"{model_short_name}_{strftime('%Y%m%d-%H%M%S')}"
        self.proof_save_dir = self.output_dir / "proofs"
        self.results_log_file = self.output_dir / "results.json"  # final JSON report

        # Only the main process creates directories, to avoid races across ranks.
        if self.accelerator.is_main_process:
            self.proof_save_dir.mkdir(parents=True, exist_ok=True)
            print(f"Runner initialized. Outputs will be saved to: {self.output_dir.resolve()}")

    def _setup_hf_config(self) -> HuggingFaceModelConfig:
        """Build the HuggingFace model config, applying user overrides.

        Top-level keys in ``hf_config_overrides`` replace the defaults, except
        ``generation_kwargs`` which is merged key-by-key so that defaults the
        caller did not override (e.g. ``top_p``) are preserved.
        """
        base_config = {
            "model_name": self.model_id,
            "device": str(self.device),  # accelerator handles placement, but initial load may need it
            "torch_dtype": "bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
                           else ("float16" if torch.cuda.is_available() else "auto"),
            "trust_remote_code": True,
            "use_fast_tokenizer": True,
            "generation_kwargs": {
                "max_new_tokens": self.max_new_tokens,
                "temperature": self.temperature,
                "top_p": 0.95,
                "do_sample": True,
            },
        }
        # BUGFIX: previously `base_config.update(overrides)` replaced the whole
        # generation_kwargs dict and the follow-up update was a no-op against
        # itself, dropping default generation keys. Pop it out and merge it.
        overrides = dict(self.hf_config_overrides)
        gen_overrides = overrides.pop("generation_kwargs", {})
        base_config.update(overrides)
        base_config["generation_kwargs"].update(gen_overrides)

        return HuggingFaceModelConfig(**base_config)

    def _setup_dataloader(self) -> DataLoader:
        """Create the dataset and a DataLoader that yields prompt batches."""
        dataset = JsonDataset(self.dataset_path)
        prompt_builder = get_builder("simple", template=self.prompt_template)

        def custom_collate_fn(batch_items: List[LeanItem]) -> Dict[str, Any]:
            # Turn a batch of LeanItem into model prompts plus the metadata
            # needed later to reassemble and save each proof.
            prompts_for_model = []
            original_items_metadata = []
            for item in batch_items:
                code_block_stmt = item.prompt_ready_stmt
                prompt_str = prompt_builder.template.format(code_block_statement=code_block_stmt)
                prompts_for_model.append(prompt_str)
                original_items_metadata.append({
                    "id": item.id,
                    "prompt_ready_stmt": item.prompt_ready_stmt,
                    "imports_txt": item.imports_txt,
                })
            return {
                "prompts_for_model": prompts_for_model,
                "original_items_metadata": original_items_metadata,
            }

        return DataLoader(
            dataset,
            batch_size=self.per_device_batch_size,
            shuffle=False,  # keep dataset order deterministic across processes
            num_workers=self.dataloader_num_workers,
            collate_fn=custom_collate_fn,
            pin_memory=True,
        )

    def _run_validation(self, saved_files: List[Path]) -> Dict[Path, Dict[str, str]]:
        """Validate saved proof files concurrently.

        Returns:
            Mapping from proof file path to a result dict with keys
            ``status`` ("Passed"/"Failed") and ``log`` (validator message).
            (Previously the empty case returned a list while the non-empty
            case returned a dict; now it is consistently a dict.)
        """
        if not saved_files:
            return {}

        self.accelerator.print(f"Starting validation for {len(saved_files)} proofs...")
        validator = ProofValidator(timeout=self.validation_timeout)

        passed_files, failed_files_with_msg = [], []

        def validate_task(filepath):
            success, msg = validator.validate_file(filepath)
            return filepath, success, msg

        # Validation is subprocess/I/O-bound, so threads are sufficient.
        # Call os.cpu_count() once; it may return None on some platforms.
        cpu_count = os.cpu_count()
        num_workers = max(1, cpu_count // 2) if cpu_count else 4
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_filepath = {
                executor.submit(validate_task, filepath): filepath
                for filepath in saved_files
            }
            for future in tqdm(
                concurrent.futures.as_completed(future_to_filepath),
                total=len(saved_files),
                desc="Validating",
                disable=not self.accelerator.is_main_process,
            ):
                filepath, success, msg = future.result()
                if success:
                    passed_files.append(filepath)
                else:
                    failed_files_with_msg.append((filepath, msg))

        self.accelerator.print(f"Validation: Passed={len(passed_files)}, Failed={len(failed_files_with_msg)}")

        results_map = {f: {"status": "Passed", "log": ""} for f in passed_files}
        results_map.update({f: {"status": "Failed", "log": msg} for f, msg in failed_files_with_msg})

        return results_map

    def run(self):
        """Execute the full inference and validation pipeline."""
        start_time = time()

        # 1. Setup
        hf_config = self._setup_hf_config()
        eval_dataloader = self._setup_dataloader()

        # 2. Load model
        with ModelRegistry.create("huggingface", **hf_config.model_dump()) as hf_model:
            # Some models ship without a pad token; reuse EOS so batched
            # generation can pad.
            if hf_model.tokenizer.pad_token_id is None:
                hf_model.tokenizer.pad_token_id = hf_model.tokenizer.eos_token_id
                hf_model.model.config.pad_token_id = hf_model.tokenizer.eos_token_id

            # 3. Accelerator prepare (shards dataloader, wraps model for multi-GPU)
            prepared_model, prepared_dataloader = self.accelerator.prepare(
                hf_model.model, eval_dataloader
            )
            hf_model.model = prepared_model  # update wrapper's model reference

            # 4. Inference
            current_process_outputs = []
            if hasattr(hf_model.model, 'eval'):
                hf_model.model.eval()

            progress_bar = tqdm(
                prepared_dataloader,
                desc=f"Inference (Proc {self.accelerator.process_index})",
                disable=not self.accelerator.is_local_main_process,
            )

            with torch.no_grad():
                for batch_data in progress_bar:
                    prompts = batch_data["prompts_for_model"]
                    metadata = batch_data["original_items_metadata"]

                    generated_batch_texts = hf_model.batch_predict(
                        prompts,
                        batch_size=len(prompts),
                    )

                    for i, gen_text in enumerate(generated_batch_texts):
                        current_process_outputs.append({
                            "id": metadata[i]["id"],
                            "generated_proof_part": gen_text.strip(),
                            "prompt_ready_stmt": metadata[i]["prompt_ready_stmt"],
                            "imports_txt": metadata[i]["imports_txt"],
                            "process_index": self.accelerator.process_index,
                        })

            self.accelerator.wait_for_everyone()

            # 5. Gather results from all ranks; save proofs on the main process
            gathered_outputs = accelerate.utils.gather_object(current_process_outputs)
            saved_files_map = {}  # id -> Path
            if self.accelerator.is_main_process:
                print(f"Gathered {len(gathered_outputs)} results. Saving proofs...")
                for output_data in tqdm(gathered_outputs, desc="Saving proofs"):
                    item_id = output_data["id"]
                    proof_part = output_data["generated_proof_part"]
                    imports = output_data["imports_txt"]
                    # Prefer the fenced ```lean ...``` block; fall back to raw output.
                    extracted_code = extract_lean_block(proof_part) or proof_part
                    full_code = f"{imports}\n{extracted_code}"

                    safe_id = str(item_id).replace("/", "_").replace("\\", "_")
                    proof_file = self.proof_save_dir / f"{safe_id}.lean"
                    proof_file.write_text(full_code, encoding="utf-8")
                    saved_files_map[item_id] = proof_file

                # 6. Validation (main process only)
                validation_results = self._run_validation(list(saved_files_map.values()))

                # 7. Combine final results
                final_results = []
                for res in gathered_outputs:
                    item_id = res['id']
                    path = saved_files_map.get(item_id)
                    val_res = validation_results.get(
                        path, {"status": "Unknown", "log": "File not found in validation map"}
                    )
                    res.update({
                        "proof_file": str(path) if path is not None else None,
                        "validation_status": val_res["status"],
                        "validation_log": val_res["log"],
                    })
                    final_results.append(res)

                # 8. Persist the JSON report
                with open(self.results_log_file, "w", encoding="utf-8") as f:
                    json.dump(final_results, f, indent=2, ensure_ascii=False)

                print(f"Runner finished. Total time: {time() - start_time:.2f}s. Results saved to {self.results_log_file}")

        self.accelerator.wait_for_everyone()