-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathbatch_run.py
89 lines (77 loc) · 3.34 KB
/
batch_run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
import multiprocessing
from multiprocessing import Pool
# from multiprocessing.pool import ThreadPool as Pool
from scripts import configs_cv as C
from scripts import ranges_cv as R
from scripts.rl_exp_cv import main as main_func
import itertools
import copy
from rl.tools.utils.misc_utils import zipsame
# Limit the number threads used by tensorflow
import tensorflow as tf
num_threads=8
tf.config.threading.set_inter_op_parallelism_threads(num_threads)
tf.config.threading.set_intra_op_parallelism_threads(num_threads)
def func(tp):
print(tp['general']['exp_name'], tp['general']['seed'])
def get_valcombs_and_keys(ranges):
keys = []
values = []
for r in ranges:
keys += r[::2]
values = [list(zipsame(*r[1::2])) for r in ranges]
cs = itertools.product(*values)
combs = []
for c in cs:
comb = []
for x in c:
comb += x
print(comb)
combs.append(comb)
return combs, keys
def main(env, configs_name, range_names, base_algorithms, n_processes):
# Set to the number of workers you want (it defaults to the cpu count of your machine)
if n_processes == -1:
n_processes = None
print('# of CPU (threads): {}'.format(multiprocessing.cpu_count()))
configs = getattr(C, 'configs_' + configs_name)
tps = []
for range_name in range_names:
ranges = R.get_ranges(env, range_name, base_algorithms)
combs, keys = get_valcombs_and_keys(ranges)
print('Total number of combinations: {}'.format(len(combs)))
for i, comb in enumerate(combs):
tp = copy.deepcopy(configs)
value_strs = [tp['general']['exp_name']] # the description string start from the the exp name
for (value, key) in zip(comb, keys):
entry = tp
for k in key[:-1]: # walk down the configs tree
entry = entry[k]
# Make sure the key is indeed included in the configs, so that we set the desired flag.
assert key[-1] in entry, 'newly added key: {}'.format(key[-1])
entry[key[-1]] = value
# We do not include seed number.
if len(key) == 2 and key[0] == 'general' and key[1] == 'seed':
continue
else:
if value is True:
value = 'T'
if value is False:
value = 'F'
value_strs.append(str(value).split('/')[0]) # in case of experts/cartpole/final.ckpt....
tp['general']['exp_name'] = '-'.join(value_strs)
tps.append(tp)
with Pool(processes=n_processes, maxtasksperchild=1) as p: # None for using all the cpus available
p.map(main_func, tps, chunksize=1)
# p.map(func, tps, chunksize=1)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Change this to 'cp', 'hopper', snake', 'walker3d', or 'default', to use the stepsize setting for your env.
parser.add_argument('env')
parser.add_argument('configs_name')
parser.add_argument('-r', '--range_names', nargs='+')
parser.add_argument('-a', '--base_algorithms', nargs='+')
parser.add_argument('--n_processes', type=int, default=-1)
args = parser.parse_args()
main(args.env, args.configs_name, args.range_names, args.base_algorithms, args.n_processes)