Skip to content

Commit

Permalink
Fix test.py to show proper verification with task policy
Browse files Browse the repository at this point in the history
  • Loading branch information
cr-xu committed Feb 12, 2024
1 parent 8d97bb4 commit 62a2424
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions meta-rl/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
import numpy as np
import torch
import yaml
from tqdm import trange

from maml_rl.baseline import LinearFeatureBaseline
from maml_rl.envs.awake_steering_simulated import AwakeSteering as awake_env
from maml_rl.samplers import MultiTaskSampler
from maml_rl.utils.helpers import get_input_size, get_policy_for_env
from maml_rl.utils.reinforcement_learning import get_episode_lengths, get_returns
from policy_test import _layout_verficication_plot, verify
from tqdm import trange


def save_progress(file_name, data, save_progress_data_dir):
Expand Down Expand Up @@ -98,11 +97,8 @@ def main(args):
policy.load_state_dict(state_dict)
use_task_policy = logging_path
meta_policy_location = args.policy
elif config["use_task_policy"]:
use_task_policy = logging_path
else:
use_task_policy = False
print("use_task_policy", use_task_policy)
use_task_policy = logging_path
policy.share_memory()

# Baseline
Expand Down Expand Up @@ -188,9 +184,6 @@ def main(args):
action="store_true",
help="use the pre-trained meta-policy",
)
parser.add_argument(
"--use-task-policy", action="store_true", help="use the pre-trained task-policy"
)
parser.add_argument(
"--evaluation-tasks",
type=str,
Expand Down

0 comments on commit 62a2424

Please sign in to comment.