"""
Fork of https://github.com/astier/model-free-episodic-control (MIT-Licensed)
"""
import random
import argparse
import time
import matplotlib.pyplot as plt
import os
import shutil
import gym
from mfec.agent import MFECAgent
from mfec.utils import Utils

parser = argparse.ArgumentParser()
parser.add_argument('--knn', type=int, default=9,
                    help="Number of nearest neighbors taken into account (default: 9)")
parser.add_argument('--dim', type=int, default=32,
                    help="Number of random features extracted by the OPU (default: 32)")
parser.add_argument('--discount', type=float, default=0.99,
                    help="Discount rate for the return (default: 0.99)")
parser.add_argument('--epsilon', type=float, default=0.001,
                    help="Probability of random exploration (default: 0.001)")
parser.add_argument('--env', type=str, default="MsPacman-v0",
                    help="Game to play (more games at https://gym.openai.com/envs/#atari)")
parser.add_argument('-v', '--volatile', action="store_true",
                    help="Prevent frames from being recorded")
parser.add_argument('-s', '--save', action="store_true",
                    help="Save the trained agent for later use")
args = parser.parse_args()

ENVIRONMENT = args.env
RENDER = not args.volatile
SAVE = args.save
EPOCHS = 4
FRAMES_PER_EPOCH = 100000
ACTION_BUFFER_SIZE = 200000 # Number of states that can be stored for each action
K = args.knn
DISCOUNT = args.discount
EPSILON = args.epsilon
FRAMESKIP = 3 # Default gym-setting is (2, 5), see notes in the README
REPEAT_ACTION_PROB = 0.0 # Default gym-setting is .25
SCALE_DIMS = None  # e.g. (58, 40): dimensions to rescale the inputs to; None means no rescaling
STATE_DIMENSION = args.dim
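
# Minimal usage sketch (assuming the mfec package from this repository and
# gym's Atari environments are installed); the flags map to the argparse
# options above:
#   python main.py --env MsPacman-v0 --knn 9 --dim 32 -s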


def main():
    """Learns to play ENVIRONMENT: initializes the environment and the agent,
    runs the training loop, then a final exploitation run.
    """
    random.seed(None)
    # Create a fresh folder to store some of the frames
    try:
        shutil.rmtree("videos")
    except FileNotFoundError:
        pass
    os.mkdir("videos")
    # Initialize utils, environment and agent
    utils = Utils(FRAMES_PER_EPOCH, EPOCHS * FRAMES_PER_EPOCH)
    env = gym.make(ENVIRONMENT)
    try:
        env.frameskip = FRAMESKIP
        env.ale.setFloat("repeat_action_probability", REPEAT_ACTION_PROB)
        agent = MFECAgent(
            ACTION_BUFFER_SIZE,
            K,
            DISCOUNT,
            EPSILON,
            SCALE_DIMS,
            STATE_DIMENSION,
            range(env.action_space.n)
        )
        run_algorithm(agent, env, utils)
        exploit_score = exploit(agent, env)
        if SAVE:
            import pickle
            with open("agent.pkl", 'wb') as file:
                pickle.dump(agent, file, protocol=2)
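            # To reuse the agent later (a sketch; unpickling needs the same
            # mfec package importable):
            #   with open("agent.pkl", "rb") as f:
            #       agent = pickle.load(f)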
        if RENDER:  # Create the video of the best recorded run
            # https://askubuntu.com/questions/610903/how-can-i-create-a-video-file-from-a-set-of-jpg-images
            if bestrun[2] > exploit_score:
                os.system(
                    'ffmpeg -framerate 25 -i videos/{}-{}-%00000d.png -c:v libx264 '
                    '-profile:v high -crf 20 -pix_fmt yuv420p bestrun.mp4'.format(
                        bestrun[0], bestrun[1]))
            else:
                os.system(
                    'ffmpeg -framerate 25 -i videos/exploit-%00000d.png -c:v libx264 '
                    '-profile:v high -crf 20 -pix_fmt yuv420p bestrun.mp4')
            print('\n\nBest:', max(bestrun[2], exploit_score), '\n')
    finally:
        utils.close()
        env.close()


def run_algorithm(agent, env, utils):
    """Runs the training loop: EPOCHS epochs of FRAMES_PER_EPOCH frames each.
    Stops early once six successive episodes exceed the reward threshold.
    """
    global epi, epo, bestrun  # Used to label the recorded frames and track the best run
    frames_left = 0
    successive_wins = 0
    for _ in range(EPOCHS):
        frames_left += FRAMES_PER_EPOCH
        epi = 1
        while frames_left > 0:
            # if os.path.exists('terminate_cron.txt'):  # Workaround to terminate the process at a certain time
            #     os.remove('terminate_cron.txt')
            #     successive_wins = 6
            #     break
            episode_frames, episode_reward = run_episode(agent, env)
            if (epi == 1 or epi % 10 == 0) and episode_reward > bestrun[2]:
                bestrun = (epo, epi, episode_reward)
            epi += 1
            frames_left -= episode_frames
            utils.end_episode(episode_frames, episode_reward)  # Console display
            if episode_reward > threshold:  # Allows stopping before all epochs are completed
                successive_wins += 1
            else:
                successive_wins = 0
            if successive_wins > 5:
                break
        if successive_wins > 5:
            print('Solved!\n\n')
            break
        utils.end_epoch()  # Console display
        epo += 1


def run_episode(agent, env):
    """Plays one episode: chooses an action for the observed state of the
    ENVIRONMENT and sends it to the ENVIRONMENT until a 'done' signal is
    returned.
    """
    episode_frames = 0
    episode_reward = 0
    env.seed(random.randint(0, 1000000))
    observation = env.reset()
    max_lives = env.ale.lives()
    done = False
    frame = 0  # Used if RENDER is True
    while not done:  # While not game over
        # if RENDER:  # Live display
        #     env.render()
        #     time.sleep(RENDER_SPEED)
        if RENDER and (epi == 1 or epi % 10 == 0):
            plt.imsave(os.path.join('videos', '-'.join((str(epo), str(epi), str(frame))) + '.png'),
                       env.render(mode='rgb_array'))
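        # One interaction step: the agent picks an action for the current
        # observation and is handed the resulting reward; agent.train() below
        # updates the agent once the episode has ended.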
        action = agent.choose_action(observation)
        observation, reward, done, info = env.step(action)
        if info['ale.lives'] < max_lives:  # No revive
            done = True
        frame += 1
        agent.receive_reward(reward)
        episode_reward += reward
        episode_frames += FRAMESKIP
    agent.train()
    return episode_frames, episode_reward


def exploit(agent, env):
    """Same as run_episode, but with EPSILON treated as 0, i.e. no exploration.
    """
    episode_reward = 0
    env.seed(random.randint(0, 1000000))
    observation = env.reset()
    max_lives = env.ale.lives()
    frame = 0
    done = False
    while not done:
        if RENDER:
            plt.imsave(os.path.join('videos', '-'.join(('exploit', str(frame))) + '.png'),
                       env.render(mode='rgb_array'))
        action = agent.choose_action(observation, explore=False)
        observation, reward, done, info = env.step(action)
        if info['ale.lives'] < max_lives:  # No revive
            done = True
        frame += 1
        episode_reward += reward
    print('\nExploitation run: score', episode_reward, '\n')
    return episode_reward


if __name__ == "__main__":
    threshold = 10000  # The script stops early once an episode's reward exceeds this threshold 6 times in a row
    # The next three variables are used if RENDER is True
    epo = 1
    epi = 1
    bestrun = (1, 1, 0)
    main()