Update ppo script to reproduce the tutorial result
cr-xu committed Feb 2, 2024
commit 600de2c (parent 2c5deb3)
Showing 1 changed file with 6 additions and 4 deletions.
meta-rl/ppo.py (6 additions, 4 deletions)
@@ -3,10 +3,9 @@
 
 import numpy as np
 import torch
+from stable_baselines3 import PPO
 
 from maml_rl.envs.awake_steering_simulated import AwakeSteering as awake_env
 from policy_test import verify_external_policy_on_specific_env
-from stable_baselines3 import PPO
-
 
 def main(args):
@@ -38,14 +37,17 @@ def main(args):
         model = PPO(
             "MlpPolicy", env, verbose=1, seed=seed, tensorboard_log="./logs/ppo/"
         )
-        model.learn(total_timesteps=args.steps)
+        model.set_random_seed(seed)
+        if args.steps > model.n_steps:
+            model.learn(total_timesteps=args.steps)
         model.save(args.output_file)
     else:
         print("Loading model...")
         model = PPO.load(args.output_file)
 
     def get_deterministic_policy(x):
-        return model.predict(x, deterministic=True)[0]
+        return model.predict(x)[0]
+        # return model.action_space.sample()
 
     policy = get_deterministic_policy
 
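In summary, the commit changes the training path in three ways: it re-seeds the model explicitly with model.set_random_seed(seed), it only calls model.learn() when the requested budget exceeds one PPO rollout (model.n_steps, 2048 by default in stable-baselines3), and it otherwise leaves the previously saved model in place to be loaded.

A minimal runnable sketch of that train-or-load pattern follows. It is an illustration, not the repository's code: train_or_load is a hypothetical helper, CartPole stands in for the AwakeSteering environment, and both the train flag and the placement of model.save() relative to the step-count guard are assumptions, since the flattened diff does not show the surrounding branch.

import gymnasium as gym
from stable_baselines3 import PPO


def train_or_load(train: bool, steps: int, output_file: str, seed: int = 0):
    # CartPole is a stand-in for the AwakeSteering env used by the real script.
    env = gym.make("CartPole-v1")
    if train:
        model = PPO("MlpPolicy", env, verbose=1, seed=seed)
        model.set_random_seed(seed)  # re-seed numpy/torch/env RNGs before learning
        if steps > model.n_steps:  # train only if the budget exceeds one rollout
            model.learn(total_timesteps=steps)
        model.save(output_file)  # placement relative to the guard is an assumption
    else:
        print("Loading model...")
        model = PPO.load(output_file, env=env)
    return model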

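One behavioral note on get_deterministic_policy: stable-baselines3's predict() defaults to deterministic=False, so dropping the keyword makes the helper sample actions from the policy distribution rather than return its mode, despite the function's name. A small standalone illustration, again with CartPole as a stand-in environment:

import gymnasium as gym
from stable_baselines3 import PPO

env = gym.make("CartPole-v1")  # stand-in environment
model = PPO("MlpPolicy", env, seed=0)

obs, _ = env.reset(seed=0)
stochastic_action = model.predict(obs)[0]  # samples from the action distribution
deterministic_action = model.predict(obs, deterministic=True)[0]  # mode of the distribution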