Skip to content

Commit

Permalink
Make sure opened files are closed
Browse files Browse the repository at this point in the history
  • Loading branch information
CompilerCrash committed Apr 20, 2023
1 parent 2c4b4c6 commit 24769c3
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 9 deletions.
12 changes: 8 additions & 4 deletions datasets/apps_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,17 @@ def initialize(self):
answer_type = "\nUse Standard Input format\n"
starter_code = ""

sols_str_list = json.load(open(sols_fname, 'r'))
with open(sols_fname, 'r') as f:
sols_str_list = json.load(f)
gt_samples = self.load_gt_samples(sols_str_list, answer_type, starter_code, question_str)
all_samples += gt_samples

# Read all the solutions
if self.tuning_mode in ['critic']:
for fname in gen_sols_fname:
if os.path.exists(fname):
gen_sols = json.load(open(fname, 'r'))
with open(fname, 'r') as f:
gen_sols = json.load(f)
samples, info = self.load_gen_samples(gen_sols, answer_type, starter_code, question_str)
self.update_error_stat(info)
gen_samples += samples
Expand All @@ -158,14 +160,16 @@ def initialize(self):
elif self.tuning_mode in ['rl']:

if self.relative_returns:
baseline_sample = json.load(open(baseline_fname, 'r'))
with open(baseline_fname, 'r') as f:
baseline_sample = json.load(f)
baseline_error_type = self.get_baseline_error_type(baseline_sample)
else:
baseline_error_type = -1

for fname in gen_sols_fname:
if os.path.exists(fname):
gen_sols = pkl.load(open(fname, 'rb'))
with open(fname, 'rb') as f:
gen_sols = pkl.load(f)
samples, info = self.load_rl_samples(gen_sols, baseline_error_type)
self.update_error_stat_rl(info)
gen_samples += samples
Expand Down
6 changes: 4 additions & 2 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def generate_critic_inputs(args, test_case_path, prompt_path, solutions_path, to
in_tokens[:len(q_tokens)] = q_tokens
in_tokens = in_tokens[:args.source_len]

solutions = json.load(open(solutions_path, 'r'))
with open(solutions_path, 'r') as f:
solutions = json.load(f)

all_texts = []
gt_errors = []
Expand Down Expand Up @@ -183,7 +184,8 @@ def main(args):
else:
scores_loc = os.path.join(prob_path, "gen_solutions_critic_scores.pkl")

pkl.dump(saved_critic_scores, open(scores_loc, 'wb'))
with open(scores_loc, 'wb') as f:
pkl.dump(saved_critic_scores, f)

else:
input_ids = torch.LongTensor(tokenizer.encode(input_text,
Expand Down
6 changes: 4 additions & 2 deletions test_one_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def eval_and_save_problems(args):
gen_codes = json.load(file)[str(real_index)]['code']

test_file = os.path.join(problem, "input_output.json")
tests = json.load(open(test_file, 'r'))
with open(test_file, 'r') as file:
tests = json.load(file)
nb_tests = len(tests['inputs'])
if args.max_tests != -1 and nb_tests > args.max_tests:
print(f"{test_file} contains more tests than --max_tests")
Expand Down Expand Up @@ -102,7 +103,8 @@ def eval_and_save_problems(args):
'''

save_results = {real_index: {'results': all_results, 'errors': all_errors, 'sols': all_sols}}
pkl.dump(save_results, open(outputs_loc, "wb"))
with open(outputs_loc, "wb") as file:
pkl.dump(save_results, file)


def main(args):
Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ def main(args):
train_data = get_dataset(args)

# Save args to file
json.dump(argsdict, open(os.path.join(args.save_dir, "args.json"), 'w'))
with open(os.path.join(args.save_dir, "args.json"), 'w') as file:
json.dump(argsdict, file)

# Load and train model; save model checkpoints
run_training(args, train_data)
Expand Down

0 comments on commit 24769c3

Please sign in to comment.