forked from salesforce/CodeRL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: test_one_solution.py
118 lines (92 loc) · 3.61 KB
/
test_one_solution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
import glob
import json
import os
import os.path
import pickle as pkl
import traceback
import numpy as np
from tqdm import tqdm
from utils.testing_util import run_test
def eval_and_save_problems(args):
    """Run the generated solutions for one problem against its unit tests
    and pickle the outcome.

    Args:
        args: namespace with attributes ``test_path``, ``code_path``,
            ``output_path``, ``index``, ``example_tests``, ``max_tests``,
            and ``debug``.

    Side effects:
        Creates ``args.output_path`` if missing and writes
        ``{real_index}.pkl`` there, containing per-solution results,
        errors, and compiled sources. Calls ``exit()`` (terminating the
        process) when the code file is missing, the problem has more
        tests than ``--max_tests``, or the output file already exists.
    """
    problems = sorted(glob.glob(args.test_path + '/*'))

    # Keep only the problem positions for which a generated-code file exists.
    test_indices = []
    for problem_idx, problem in enumerate(problems):
        problem_id = int(problem.split('/')[-1])
        code_file_path = os.path.join(args.code_path, '{}.json'.format(problem_id))
        if os.path.exists(code_file_path):
            test_indices.append(problem_idx)

    # Map the requested index onto the filtered list.
    real_index = test_indices[args.index]
    problem = problems[real_index]

    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)

    print('Testing sample {}'.format(problem))
    if args.example_tests:
        print("Using example tests")

    # NOTE(review): the existence check above keys code files by problem_id
    # (the directory name), but loading here keys by real_index (position in
    # the sorted listing). These agree only when directory names equal their
    # sorted positions — confirm against how the code files are produced.
    codes_loc = os.path.join(args.code_path, '{}.json'.format(real_index))
    if not os.path.isfile(codes_loc):
        print(f"{codes_loc} does not exist")
        exit()
    with open(codes_loc, "r") as file:
        gen_codes = json.load(file)[str(real_index)]['code']

    test_file = os.path.join(problem, "input_output.json")
    with open(test_file, 'r') as file:
        tests = json.load(file)
    nb_tests = len(tests['inputs'])
    if args.max_tests != -1 and nb_tests > args.max_tests:
        print(f"{test_file} contains more tests than --max_tests")
        exit()

    outputs_loc = os.path.join(args.output_path, '{}.pkl'.format(real_index))
    if os.path.isfile(outputs_loc):
        print(f"{outputs_loc} already exists")
        exit()
    print("Saving to {}".format(outputs_loc))

    all_results, all_errors, all_sols = [], [], []
    for o_idx, o in tqdm(enumerate(gen_codes), total=len(gen_codes), ncols=0, leave=False):
        curr_results = []
        curr_errors = []
        curr_sol = None
        try:
            curr_results, curr_errors, _, curr_sol = run_test(
                prob_path=problem, test=o, debug=args.debug,
                example_tests=args.example_tests)
            # Store (exception, formatted traceback) pairs so the errors
            # remain inspectable after pickling.
            curr_errors = [(e, traceback.format_tb(e.__traceback__)) if e is not None else e
                           for e in curr_errors]
            # Normalize numpy scalars/bools to plain Python values so the
            # pickle does not depend on numpy at load time.
            fixed = []
            for e in curr_results:
                if isinstance(e, np.ndarray):
                    e = e.item(0)
                if isinstance(e, np.bool_):
                    e = bool(e)
                fixed.append(e)
            curr_results = fixed
        except Exception as e:
            print(f"test framework exception = {repr(e)}{e}\n")
            break
        finally:
            # Partial results for the current solution are still recorded
            # even when the test framework raised.
            assert isinstance(curr_results, list)
            all_results.append(curr_results)
            all_errors.append(curr_errors)
            all_sols.append(curr_sol)

    '''
    How to read results:
    [-2] = compile error,
    [-1] = runtime error
    [False] = failed test case
    [True] = passed test case
    '''
    # Bug fix: the original pickled this identical payload twice in a row
    # (two back-to-back dumps to the same file); one dump is sufficient.
    save_results = {real_index: {'results': all_results, 'errors': all_errors, 'sols': all_sols}}
    with open(outputs_loc, "wb") as file:
        pkl.dump(save_results, file)
def main(args):
    """Entry point: evaluate the configured problem and persist results."""
    eval_and_save_problems(args)
if __name__ == "__main__":
    # The wildcard import is relied on to define `args` at module level.
    # NOTE(review): confirm configs.unit_test_configs builds a parsed
    # `args` namespace at import time — nothing here defines it otherwise.
    from configs.unit_test_configs import *
    main(args)