|
import json |
|
import os |
|
import tqdm |
|
# Compile-result files to scan for passing formalizations, one per source
# dataset; every file shares the same run name.
input_list = [
    f"compile_result/{dataset}/lean4_random_15kpass10.jsonl"
    for dataset in ("gsm8k_train", "math_train")
]
|
def get_statement_proof(text):
    """Split a question string into its natural-language statement and proof.

    The text is expected to contain a ``statement:`` header line, the
    statement body, a blank line, then a ``proof:`` header line followed by
    the proof body::

        statement:
        <statement text>

        proof:
        <proof text>

    Args:
        text: The raw question string.

    Returns:
        A ``(statement, proof)`` tuple of whitespace-stripped strings.

    Raises:
        ValueError: If the statement section or the proof section is missing.
    """
    import re

    # The statement runs (non-greedily, across newlines) up to the first
    # blank line that is immediately followed by the "proof:" header.
    statement_match = re.search(
        r"statement:\n(.*?)(?=\n\nproof:)", text, re.DOTALL
    )
    if statement_match is None:
        raise ValueError(
            "missing 'statement:' section followed by 'proof:' in question: "
            f"{text[:80]!r}"
        )

    # Only look for the proof header *after* the statement, so a literal
    # "proof:\n" inside the statement text is not mistaken for the header.
    proof_match = re.compile(r"proof:\n(.*)", re.DOTALL).search(
        text, statement_match.end()
    )
    if proof_match is None:
        raise ValueError(
            f"missing 'proof:' section in question: {text[:80]!r}"
        )

    statement_content = statement_match.group(1).strip()
    proof_content = proof_match.group(1).strip()
    return statement_content, proof_content
|
|
|
def _unique_passing_outputs(outputs, results):
    """Return the distinct outputs whose first occurrence has status 'pass'.

    *outputs* and *results* are walked in lockstep; for a repeated output
    only its first result is considered, matching first-occurrence dedup.
    """
    seen = set()
    passed = []
    for output, result in zip(outputs, results):
        if output in seen:
            continue
        seen.add(output)
        if result.get("status") == "pass":
            passed.append(output)
    return passed


def save_passed_results(input_list):
    """Write every unique, compile-passing formal output to a training file.

    Each path in *input_list* is parsed with ``json.load`` as one JSON
    document containing a ``"results"`` list. Every item in that list
    provides a ``"question"`` string (split by ``get_statement_proof``), a
    ``"total output"`` list of candidate outputs, and a parallel
    ``"results"`` list of per-candidate result dicts. Candidates are
    de-duplicated (first occurrence wins) and those whose result has
    ``"status" == "pass"`` are written as one JSON object per line to
    ``pass_for_train.jsonl`` in the same directory as the input file.

    Args:
        input_list: Iterable of input file paths.
    """
    for input_file in input_list:
        save_file = os.path.join(
            os.path.dirname(input_file), "pass_for_train.jsonl"
        )

        # NOTE(review): despite the .jsonl extension the input is read as a
        # single JSON document, not line-delimited JSON — confirm with the
        # producer of these files.
        with open(input_file, "r", encoding="utf-8") as file:
            data = json.load(file)

        with open(save_file, "w", encoding="utf-8") as output_file:
            for item in tqdm.tqdm(data["results"]):
                passed_outputs = _unique_passing_outputs(
                    item["total output"], item["results"]
                )
                if not passed_outputs:
                    # Nothing to save for this item; skip parsing the
                    # question so a malformed one cannot abort the run.
                    continue

                statement, proof = get_statement_proof(item["question"])
                for formal in passed_outputs:
                    result_dict = {
                        "nl_problem": statement,
                        "nl_proof": proof,
                        "formal": formal,
                    }
                    output_file.write(json.dumps(result_dict) + "\n")
|
|
|
|
|
if __name__ == "__main__":
    # Run only when executed as a script, so importing this module (e.g. to
    # reuse get_statement_proof) does not start processing files. The
    # module-level ``input_list`` defined near the top of the file is reused
    # instead of redefining an identical list here.
    save_passed_results(input_list)
|
|