-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_experiment.py
45 lines (34 loc) · 1.47 KB
/
run_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import argparse
import yaml
from environment import create_environment
from agent import setup_agents
from visualization import DataVisualization
def parse_args(argv=None):
    """Parse the experiment runner's command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse's default), which lets callers/tests pass a list.

    Returns:
        argparse.Namespace with:
          * ``load``   -- name of a saved experiment whose config is read from
            ``./data/<load>.yml``, or ``None`` to run a fresh experiment.
          * ``config`` -- name of the config ``./configs/<config>.yml`` used
            for a fresh run (default ``'mnist'``).
    """
    parser = argparse.ArgumentParser(
        description='Train agents (or reload a saved run) and visualize their behavior.')
    # The option carries a string value, so "not given" is expressed as None
    # rather than False; the truthiness check in main() treats both the same.
    parser.add_argument('--load', '-l', default=None,
                        help='name of a saved experiment in ./data to reload instead of training')
    parser.add_argument('--config', '-c', default='mnist',
                        help='name of the experiment config in ./configs (default: mnist)')
    return parser.parse_args(argv)


def load_config(path):
    """Read a YAML mapping from *path*, closing the file handle afterwards.

    Args:
        path: Path of the YAML file to read.

    Returns:
        The parsed mapping.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the document is not a YAML mapping (e.g. an empty
            file), instead of failing later with a TypeError on key access.
    """
    # safe_load: do not construct arbitrary Python objects from the file.
    with open(path, encoding='utf-8') as stream:
        configs = yaml.safe_load(stream)
    if not isinstance(configs, dict):
        raise ValueError(
            f'{path} must contain a YAML mapping, got {type(configs).__name__}')
    return configs


def _train_agents(env, agents):
    """Call ``train()`` on each agent yielded by ``agents.items()``, announcing it first."""
    for name, agent in agents.items():
        print(f'\nTraining {name} agent on {env.env_name} environment:')
        agent.train()


def main(argv=None):
    """Run (or reload) an experiment, then visualize the agents' behavior.

    Fresh run (no ``--load``): read ``./configs/<config>.yml``, build the
    environment and agents, train each agent, then call
    ``DataVisualization.serialize_data()``.
    Reload (``--load NAME``): read ``./data/<NAME>.yml``, rebuild the
    environment and agents from it, then call
    ``DataVisualization.load_data(NAME)``.
    NOTE(review): presumably serialize_data() is what writes the
    ``./data/*.yml`` file read in reload mode -- confirm in visualization.py.

    Config keys read: ``env_name``, ``seed``, ``params_exp``,
    ``params_agents`` (both modes) and ``exp_name`` (fresh runs only).

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)
    fresh_run = not args.load

    # obtain experiment configuration
    if fresh_run:
        configs = load_config(f'./configs/{args.config}.yml')
        print(f'\nRunning experiment {configs["exp_name"]}...')
    else:
        print(f'\nLoading experiment {args.load}...')
        configs = load_config(f'./data/{args.load}.yml')

    # environment and agents are built the same way in both modes
    env = create_environment(configs['env_name'], configs['seed'])
    agents = setup_agents(env, configs['params_exp'], configs['params_agents'], configs['seed'])

    if fresh_run:
        _train_agents(env, agents)
    viz = DataVisualization(env, agents, configs)
    if fresh_run:
        viz.serialize_data()
    else:
        viz.load_data(args.load)

    # visualize agents' behavior
    viz.visualize_agents()


if __name__ == '__main__':
    main()