Skip to content

Commit ca0066c

Browse files
committed
add runner module
1 parent 5fa70cd commit ca0066c

2 files changed

Lines changed: 250 additions & 0 deletions

File tree

LeanEval/runner/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .local_runner import LocalHuggingFaceRunner

LeanEval/runner/local_runner.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
# LeanEval/runner/local_runner.py
2+
import concurrent.futures
import json
import os
from pathlib import Path
from time import strftime, time
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

import accelerate
from accelerate import Accelerator

from LeanEval.datasets import JsonDataset, LeanItem
from LeanEval.models import HuggingFaceModel, HuggingFaceModelConfig, ModelRegistry
from LeanEval.prompt import PromptBuilder, get_builder
from LeanEval.utils import extract_lean_block
from LeanEval.validator.proof_validator import ProofValidator
20+
21+
class LocalHuggingFaceRunner:
    """Run local multi-GPU inference for a Hugging Face model via Accelerate.

    Workflow: load a JSON dataset of Lean statements, build prompts, generate
    candidate proofs on every process, gather the outputs on the main process,
    validate the saved ``.lean`` files with the Lean checker, and write a JSON
    results log.
    """

    def __init__(
        self,
        model_id: str,
        dataset_path: str,
        output_dir_base: str = "./outputs_local_runner",
        prompt_template: str = "Complete the Lean 4 proof below. Only output the Lean code for the complete proof\n```lean\n{code_block_statement} := by\n```",
        per_device_batch_size: int = 1,
        dataloader_num_workers: int = 2,
        max_new_tokens: int = 1024,
        temperature: float = 0.1,
        mixed_precision: str = 'fp16',  # 'no', 'fp16', 'bf16'
        validation_timeout: int = 60,
        hf_config_overrides: Optional[Dict[str, Any]] = None,  # overrides applied on top of the base HF config
    ):
        """Store configuration and prepare the per-run output directories.

        Args:
            model_id: Hugging Face model identifier (e.g. ``"org/name"``).
            dataset_path: Path to the JSON dataset consumed by ``JsonDataset``.
            output_dir_base: Root directory under which a timestamped run
                directory is created.
            prompt_template: Template with a ``{code_block_statement}`` slot.
            per_device_batch_size: Batch size per process/device.
            dataloader_num_workers: Worker count for the ``DataLoader``.
            max_new_tokens: Generation budget per sample.
            temperature: Sampling temperature.
            mixed_precision: One of ``'no'``, ``'fp16'``, ``'bf16'``; forced to
                ``'no'`` when CUDA is unavailable.
            validation_timeout: Per-file timeout (seconds) for the Lean checker.
            hf_config_overrides: Optional overrides merged into the base
                ``HuggingFaceModelConfig``; a nested ``generation_kwargs`` dict
                is merged key-by-key rather than replacing the defaults.
        """
        self.model_id = model_id
        self.dataset_path = dataset_path
        self.output_dir_base = output_dir_base
        self.prompt_template = prompt_template
        self.per_device_batch_size = per_device_batch_size
        self.dataloader_num_workers = dataloader_num_workers
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        # Mixed precision is meaningless on CPU; fall back to full precision.
        self.mixed_precision = mixed_precision if torch.cuda.is_available() else 'no'
        self.validation_timeout = validation_timeout
        self.hf_config_overrides = hf_config_overrides or {}

        self.accelerator = Accelerator(mixed_precision=self.mixed_precision)
        self.device = self.accelerator.device

        # Timestamped output layout: proofs/ for .lean files, results.json for the log.
        model_short_name = self.model_id.split('/')[-1]
        self.output_dir = Path(self.output_dir_base) / f"{model_short_name}_{strftime('%Y%m%d-%H%M%S')}"
        self.proof_save_dir = self.output_dir / "proofs"
        self.results_log_file = self.output_dir / "results.json"

        # Only the main process touches the filesystem during setup.
        if self.accelerator.is_main_process:
            self.proof_save_dir.mkdir(parents=True, exist_ok=True)
            print(f"Runner initialized. Outputs will be saved to: {self.output_dir.resolve()}")

    def _setup_hf_config(self) -> HuggingFaceModelConfig:
        """Build the HuggingFace model config, applying user overrides.

        A nested ``generation_kwargs`` override is merged into the defaults
        key-by-key. (The original replaced the whole dict via ``update()`` and
        then updated the replaced dict with itself — a no-op that silently
        dropped defaults such as ``max_new_tokens``.)
        """
        if torch.cuda.is_available():
            torch_dtype = "bfloat16" if torch.cuda.is_bf16_supported() else "float16"
        else:
            torch_dtype = "auto"

        generation_kwargs: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "top_p": 0.95,
            "do_sample": True,
        }

        base_config: Dict[str, Any] = {
            "model_name": self.model_id,
            # Accelerate re-places the model later, but the initial load may need a device.
            "device": str(self.device),
            "torch_dtype": torch_dtype,
            "trust_remote_code": True,
            "use_fast_tokenizer": True,
            "generation_kwargs": generation_kwargs,
        }

        # Apply overrides; pull generation_kwargs out first so it is merged
        # into the defaults instead of replacing them wholesale.
        overrides = dict(self.hf_config_overrides)
        gen_overrides = overrides.pop("generation_kwargs", None)
        base_config.update(overrides)
        if gen_overrides:
            generation_kwargs.update(gen_overrides)

        return HuggingFaceModelConfig(**base_config)

    def _setup_dataloader(self) -> DataLoader:
        """Create the evaluation DataLoader with a prompt-building collate_fn."""
        dataset = JsonDataset(self.dataset_path)
        prompt_builder = get_builder("simple", template=self.prompt_template)

        def custom_collate_fn(batch_items: List[LeanItem]) -> Dict[str, Any]:
            # Turn LeanItems into model prompts plus the metadata needed later
            # to reassemble and save the generated proofs.
            prompts_for_model: List[str] = []
            original_items_metadata: List[Dict[str, Any]] = []
            for item in batch_items:
                prompt_str = prompt_builder.template.format(
                    code_block_statement=item.prompt_ready_stmt
                )
                prompts_for_model.append(prompt_str)
                original_items_metadata.append({
                    "id": item.id,
                    "prompt_ready_stmt": item.prompt_ready_stmt,
                    "imports_txt": item.imports_txt,
                })
            return {
                "prompts_for_model": prompts_for_model,
                "original_items_metadata": original_items_metadata,
            }

        return DataLoader(
            dataset,
            batch_size=self.per_device_batch_size,
            shuffle=False,  # keep dataset order deterministic across processes
            num_workers=self.dataloader_num_workers,
            collate_fn=custom_collate_fn,
            pin_memory=True,
        )

    def _run_validation(self, saved_files: List[Path]) -> Dict[Path, Dict[str, str]]:
        """Validate saved proof files concurrently with the Lean checker.

        Args:
            saved_files: Paths of the ``.lean`` files to validate.

        Returns:
            Mapping from proof file path to ``{"status": ..., "log": ...}``.
            An empty dict when there is nothing to validate. (The original
            returned a list here, which broke the ``.get`` lookups downstream.)
        """
        if not saved_files:
            return {}

        self.accelerator.print(f"Starting validation for {len(saved_files)} proofs...")
        validator = ProofValidator(timeout=self.validation_timeout)

        passed_files: List[Path] = []
        failed_files_with_msg: List[Tuple[Path, str]] = []

        def validate_task(filepath: Path) -> Tuple[Path, bool, str]:
            success, msg = validator.validate_file(filepath)
            return filepath, success, msg

        cpu_count = os.cpu_count()
        num_workers = max(1, cpu_count // 2) if cpu_count else 4
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_filepath = {
                executor.submit(validate_task, filepath): filepath
                for filepath in saved_files
            }
            for future in tqdm(
                concurrent.futures.as_completed(future_to_filepath),
                total=len(saved_files),
                desc="Validating",
                disable=not self.accelerator.is_main_process,
            ):
                filepath, success, msg = future.result()
                if success:
                    passed_files.append(filepath)
                else:
                    failed_files_with_msg.append((filepath, msg))

        self.accelerator.print(
            f"Validation: Passed={len(passed_files)}, Failed={len(failed_files_with_msg)}"
        )

        results_map: Dict[Path, Dict[str, str]] = {
            f: {"status": "Passed", "log": ""} for f in passed_files
        }
        results_map.update(
            {f: {"status": "Failed", "log": msg} for f, msg in failed_files_with_msg}
        )
        return results_map

    def run(self) -> None:
        """Execute the full inference and validation pipeline."""
        start_time = time()

        # 1. Setup
        hf_config = self._setup_hf_config()
        eval_dataloader = self._setup_dataloader()

        # 2. Load the model (context manager releases resources on exit).
        with ModelRegistry.create("huggingface", **hf_config.model_dump()) as hf_model:
            # Some tokenizers ship without a pad token; reuse EOS so batched
            # generation can pad.
            if hf_model.tokenizer.pad_token_id is None:
                hf_model.tokenizer.pad_token_id = hf_model.tokenizer.eos_token_id
                hf_model.model.config.pad_token_id = hf_model.tokenizer.eos_token_id

            # 3. Let Accelerate place the model and shard the dataloader.
            prepared_model, prepared_dataloader = self.accelerator.prepare(
                hf_model.model, eval_dataloader
            )
            hf_model.model = prepared_model  # keep the wrapper pointing at the prepared model

            # 4. Inference
            current_process_outputs: List[Dict[str, Any]] = []
            if hasattr(hf_model.model, 'eval'):
                hf_model.model.eval()

            progress_bar = tqdm(
                prepared_dataloader,
                desc=f"Inference (Proc {self.accelerator.process_index})",
                disable=not self.accelerator.is_local_main_process,
            )

            with torch.no_grad():
                for batch_data in progress_bar:
                    prompts = batch_data["prompts_for_model"]
                    metadata = batch_data["original_items_metadata"]

                    generated_batch_texts = hf_model.batch_predict(
                        prompts,
                        batch_size=len(prompts),
                    )

                    for i, gen_text in enumerate(generated_batch_texts):
                        current_process_outputs.append({
                            "id": metadata[i]["id"],
                            "generated_proof_part": gen_text.strip(),
                            "prompt_ready_stmt": metadata[i]["prompt_ready_stmt"],
                            "imports_txt": metadata[i]["imports_txt"],
                            "process_index": self.accelerator.process_index,
                        })

            self.accelerator.wait_for_everyone()

            # 5. Gather results from all processes. Only the main process
            # saves files, validates, and writes the log — the original ran
            # steps 6-8 on every rank, crashing non-main ranks (list returned
            # by _run_validation has no .get) and racing on results.json.
            gathered_outputs = accelerate.utils.gather_object(current_process_outputs)

            if self.accelerator.is_main_process:
                print(f"Gathered {len(gathered_outputs)} results. Saving proofs...")
                saved_files_map: Dict[Any, Path] = {}  # item id -> proof file path
                for output_data in tqdm(gathered_outputs, desc="Saving proofs"):
                    item_id = output_data["id"]
                    proof_part = output_data["generated_proof_part"]
                    imports = output_data["imports_txt"]
                    # Prefer the fenced ```lean ...``` block; fall back to the raw output.
                    extracted_code = extract_lean_block(proof_part) or proof_part
                    full_code = f"{imports}\n{extracted_code}"

                    # Item ids may contain path separators; sanitize for filenames.
                    safe_id = str(item_id).replace("/", "_").replace("\\", "_")
                    proof_file = self.proof_save_dir / f"{safe_id}.lean"
                    proof_file.write_text(full_code, encoding="utf-8")
                    saved_files_map[item_id] = proof_file

                # 6. Validation (main process only)
                validation_results = self._run_validation(list(saved_files_map.values()))

                # 7. Combine the generation records with their validation outcome.
                final_results = []
                for res in gathered_outputs:
                    path = saved_files_map.get(res["id"])
                    val_res = validation_results.get(
                        path, {"status": "Unknown", "log": "File not found in validation map"}
                    )
                    res.update({
                        "proof_file": str(path) if path is not None else "",
                        "validation_status": val_res["status"],
                        "validation_log": val_res["log"],
                    })
                    final_results.append(res)

                # 8. Persist the JSON results log.
                with open(self.results_log_file, "w", encoding="utf-8") as f:
                    json.dump(final_results, f, indent=2, ensure_ascii=False)

                print(
                    f"Runner finished. Total time: {time() - start_time:.2f}s. "
                    f"Results saved to {self.results_log_file}"
                )

        self.accelerator.wait_for_everyone()

0 commit comments

Comments
 (0)