chore: clean some trivial code (#214)
muchvo authored May 4, 2023
1 parent b821adf commit e0b1852
Showing 13 changed files with 136 additions and 252 deletions.
19 changes: 10 additions & 9 deletions examples/analyze_experiment_results.py
@@ -18,12 +18,13 @@


# Just fill in the path where the experiment grid runs are saved.
path = ''
st = StatisticsTools()
st.load_source(path)
# just fill in the name of the parameter of which value you want to compare.
# then you can specify the value of the parameter you want to compare,
# or you can just specify how many values you want to compare in single graph at most,
# and the function will automatically generate all possible combinations of the graph.
# but the two mode can not be used at the same time.
st.draw_graph(parameter='', values=None, compare_num=2, cost_limit=None)
PATH = ''
if __name__ == '__main__':
st = StatisticsTools()
st.load_source(PATH)
# Fill in the name of the parameter whose values you want to compare.
# You can either specify the parameter values to compare explicitly,
# or specify at most how many values to compare in a single graph,
# and the function will automatically generate all possible combinations of graphs.
# The two modes cannot be used at the same time.
st.draw_graph(parameter='', values=None, compare_num=2, cost_limit=None, show_image=True)
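
To make the two modes concrete, here are two hedged call sketches; the parameter name 'algo', its values, and the cost limit are illustrative placeholders rather than values taken from this commit:

# Mode 1: compare an explicit list of values for one parameter (hypothetical names).
st.draw_graph(parameter='algo', values=['PPOLag', 'TRPOLag'], compare_num=None, cost_limit=25.0, show_image=True)
# Mode 2: let the tool enumerate every combination of at most `compare_num` values.
st.draw_graph(parameter='algo', values=None, compare_num=2, cost_limit=25.0, show_image=True)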
49 changes: 2 additions & 47 deletions examples/benchmarks/run_experiment_grid.py
@@ -14,61 +14,16 @@
# ==============================================================================
"""Example of training a policy from exp-x config with OmniSafe."""

from __future__ import annotations

import os
import sys
import warnings
from typing import Any

import torch

import omnisafe
from omnisafe.common.experiment_grid import ExperimentGrid


def train(
exp_id: str,
algo: str,
env_id: str,
custom_cfgs: dict[str, Any],
) -> tuple[float, float, int]:
"""Train a policy from exp-x config with OmniSafe.
Args:
exp_id (str): Experiment ID.
algo (str): Algorithm to train.
env_id (str): The name of test environment.
custom_cfgs (dict): Custom configurations.
"""
terminal_log_name = 'terminal.log'
error_log_name = 'error.log'
if 'seed' in custom_cfgs:
terminal_log_name = f'seed{custom_cfgs["seed"]}_{terminal_log_name}'
error_log_name = f'seed{custom_cfgs["seed"]}_{error_log_name}'
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
print(f'exp-x: {exp_id} is training...')
if not os.path.exists(custom_cfgs['logger_cfgs']['log_dir']):
os.makedirs(custom_cfgs['logger_cfgs']['log_dir'], exist_ok=True)
# pylint: disable-next=consider-using-with
sys.stdout = open( # noqa: SIM115
os.path.join(f'{custom_cfgs["logger_cfgs"]["log_dir"]}', terminal_log_name),
'w',
encoding='utf-8',
)
# pylint: disable-next=consider-using-with
sys.stderr = open( # noqa: SIM115
os.path.join(f'{custom_cfgs["logger_cfgs"]["log_dir"]}', error_log_name),
'w',
encoding='utf-8',
)
agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs)
reward, cost, ep_len = agent.learn()
return reward, cost, ep_len
from omnisafe.utils.exp_grid_tools import train


if __name__ == '__main__':
eg = ExperimentGrid(exp_name='Safety_Gymnasium_Goal')
eg = ExperimentGrid(exp_name='Benchmark_Safety_Velocity')

# Set the algorithms.
base_policy = ['PolicyGradient', 'NaturalPG', 'TRPO', 'PPO']
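
For reference, the relocated helper presumably mirrors the body deleted above. A minimal sketch of what omnisafe/utils/exp_grid_tools.py likely contains, reconstructed from that removed code (the actual module may differ in details):

# Hypothetical reconstruction of omnisafe/utils/exp_grid_tools.py based on the code removed above.
from __future__ import annotations

import os
import sys
from typing import Any

import omnisafe


def train(
    exp_id: str,
    algo: str,
    env_id: str,
    custom_cfgs: dict[str, Any],
) -> tuple[float, float, int]:
    """Train a policy from an experiment-grid config with OmniSafe."""
    # Per-seed log file names keep parallel runs from overwriting each other's output.
    terminal_log_name = 'terminal.log'
    error_log_name = 'error.log'
    if 'seed' in custom_cfgs:
        terminal_log_name = f'seed{custom_cfgs["seed"]}_{terminal_log_name}'
        error_log_name = f'seed{custom_cfgs["seed"]}_{error_log_name}'
    # Restore the real streams to announce the run, then redirect stdout/stderr into
    # the experiment's log directory.
    sys.stdout = sys.__stdout__
    sys.stderr = sys.__stderr__
    print(f'exp-x: {exp_id} is training...')
    log_dir = custom_cfgs['logger_cfgs']['log_dir']
    os.makedirs(log_dir, exist_ok=True)
    sys.stdout = open(os.path.join(log_dir, terminal_log_name), 'w', encoding='utf-8')  # noqa: SIM115
    sys.stderr = open(os.path.join(log_dir, error_log_name), 'w', encoding='utf-8')  # noqa: SIM115
    agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs)
    reward, cost, ep_len = agent.learn()
    return reward, cost, ep_len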
48 changes: 24 additions & 24 deletions examples/plot.py
@@ -24,28 +24,28 @@
# python plot.py --logdir omnisafe/examples/runs/PPOLag-{SafetyAntVelocity-v1}
# after training the policy with the following command:
# python train_policy.py --algo PPOLag --env-id SafetyAntVelocity-v1
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--logdir', nargs='*')
parser.add_argument('--legend', '-l', nargs='*')
parser.add_argument('--xaxis', '-x', default='Steps')
parser.add_argument('--value', '-y', default='Rewards', nargs='*')
parser.add_argument('--count', action='store_true')
parser.add_argument('--smooth', '-s', type=int, default=1)
parser.add_argument('--select', nargs='*')
parser.add_argument('--exclude', nargs='*')
parser.add_argument('--estimator', default='mean')
args = parser.parse_args()

parser = argparse.ArgumentParser()
parser.add_argument('--logdir', nargs='*')
parser.add_argument('--legend', '-l', nargs='*')
parser.add_argument('--xaxis', '-x', default='Steps')
parser.add_argument('--value', '-y', default='Rewards', nargs='*')
parser.add_argument('--count', action='store_true')
parser.add_argument('--smooth', '-s', type=int, default=1)
parser.add_argument('--select', nargs='*')
parser.add_argument('--exclude', nargs='*')
parser.add_argument('--estimator', default='mean')
args = parser.parse_args()

plotter = Plotter()
plotter.make_plots(
args.logdir,
args.legend,
args.xaxis,
args.value,
args.count,
smooth=args.smooth,
select=args.select,
exclude=args.exclude,
estimator=args.estimator,
)
plotter = Plotter()
plotter.make_plots(
args.logdir,
args.legend,
args.xaxis,
args.value,
args.count,
smooth=args.smooth,
select=args.select,
exclude=args.exclude,
estimator=args.estimator,
)
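
A slightly fuller invocation than the one in the comments above, assuming the logger recorded both Rewards and Costs columns (the column names and run directory are assumptions, not taken from this commit):

python plot.py --logdir ./runs/PPOLag-{SafetyAntVelocity-v1} --value Rewards Costs --xaxis Steps --smooth 5 --estimator mean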
41 changes: 21 additions & 20 deletions examples/train_from_custom_dict.py
@@ -17,25 +17,26 @@
import omnisafe


env_id = 'SafetyPointGoal1-v0'
custom_cfgs = {
'train_cfgs': {
'total_steps': 1024000,
'vector_env_nums': 1,
'parallel': 1,
},
'algo_cfgs': {
'steps_per_epoch': 2048,
'update_iters': 1,
},
'logger_cfgs': {
'use_wandb': False,
},
}
if __name__ == '__main__':
env_id = 'SafetyPointGoal1-v0'
custom_cfgs = {
'train_cfgs': {
'total_steps': 1024000,
'vector_env_nums': 1,
'parallel': 1,
},
'algo_cfgs': {
'steps_per_epoch': 2048,
'update_iters': 1,
},
'logger_cfgs': {
'use_wandb': False,
},
}

agent = omnisafe.Agent('PPOLag', env_id, custom_cfgs=custom_cfgs)
agent.learn()
agent = omnisafe.Agent('PPOLag', env_id, custom_cfgs=custom_cfgs)
agent.learn()

agent.plot(smooth=1)
agent.render(num_episodes=1, render_mode='rgb_array', width=256, height=256)
agent.evaluate(num_episodes=1)
agent.plot(smooth=1)
agent.render(num_episodes=1, render_mode='rgb_array', width=256, height=256)
agent.evaluate(num_episodes=1)
13 changes: 7 additions & 6 deletions examples/train_from_yaml.py
@@ -16,11 +16,12 @@
import omnisafe


env_id = 'SafetyPointGoal1-v0'
if __name__ == '__main__':
env_id = 'SafetyPointGoal1-v0'

agent = omnisafe.Agent('PPOLag', env_id)
agent.learn()
agent = omnisafe.Agent('PPOLag', env_id)
agent.learn()

agent.plot(smooth=1)
agent.render(num_episodes=1, render_mode='rgb_array', width=256, height=256)
agent.evaluate(num_episodes=1)
agent.plot(smooth=1)
agent.render(num_episodes=1, render_mode='rgb_array', width=256, height=256)
agent.evaluate(num_episodes=1)
10 changes: 9 additions & 1 deletion omnisafe/common/experiment_grid.py
@@ -540,6 +540,7 @@ def analyze(
values: list[Any] | None = None,
compare_num: int | None = None,
cost_limit: float | None = None,
show_image: bool = False,
) -> None:
"""Analyze the experiment results.
@@ -554,10 +555,17 @@
will combine any potential combination to compare. Defaults to None.
cost_limit (float or None, optional): Value for a line shown on the graph to indicate
the cost limit. Defaults to None.
show_image (bool): Whether to show the graph image in a GUI window. Defaults to False.
"""
assert self._statistical_tools is not None, 'Please run run() first!'
self._statistical_tools.load_source(self.log_dir)
self._statistical_tools.draw_graph(parameter, values, compare_num, cost_limit)
self._statistical_tools.draw_graph(
parameter,
values,
compare_num,
cost_limit,
show_image=show_image,
)

def evaluate(self, num_episodes: int = 10, cost_criteria: float = 1.0) -> None:
"""Agent Evaluation.
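
For context, a hedged usage sketch of the extended analyze(); the parameter name and cost limit are placeholders, and the grid is assumed to have finished running, since the assert above requires run() to have been called first:

# Assuming `eg` is an ExperimentGrid whose run() has already completed (hypothetical values).
eg.analyze(parameter='algo', values=None, compare_num=2, cost_limit=25.0, show_image=True)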
6 changes: 5 additions & 1 deletion omnisafe/common/statistics_tools.py
@@ -90,6 +90,7 @@ def draw_graph(
compare_num: int | None = None,
cost_limit: float | None = None,
smooth: int = 1,
show_image: bool = False,
) -> None:
"""Draw graph.
@@ -100,6 +101,7 @@
compare_num (int or None, optional): The number of values to compare. Defaults to None.
cost_limit (float or None, optional): The cost limit of the experiment. Defaults to None.
smooth (int, optional): The smooth window size. Defaults to 1.
show_image (bool): Whether to show the graph image in a GUI window. Defaults to False.
.. note::
`values` and `compare_num` cannot be set at the same time.
@@ -158,11 +160,13 @@
None,
'mean',
save_name=save_name,
show_image=show_image,
)
except RuntimeError:
except Exception as exception: # noqa # pylint: disable=broad-except
print(
f'Cannot generate graph for {save_name[:5] + str(decompressed_img_name_cfgs)}',
)
print(exception)

def make_config_groups(
self,
62 changes: 8 additions & 54 deletions omnisafe/utils/command_app.py
@@ -16,9 +16,8 @@


import os
import sys
import warnings
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List

import numpy as np
import torch
@@ -29,6 +28,7 @@
import omnisafe
from omnisafe.common.experiment_grid import ExperimentGrid
from omnisafe.common.statistics_tools import StatisticsTools
from omnisafe.utils.exp_grid_tools import train as train_grid
from omnisafe.utils.tools import assert_with_exit, custom_cfgs_to_dict, update_dict


@@ -159,58 +159,6 @@ def train( # pylint: disable=too-many-arguments
console.print('failed to evaluate model', style='red bold')


def train_grid(
exp_id: str,
algo: str,
env_id: str,
custom_cfgs: Dict[str, Any],
) -> Tuple[float, float, int]:
r"""Train a policy from exp-x config with OmniSafe.
Examples:
.. code-block:: bash
python -m omnisafe train_grid --exp_id exp-1 --algo PPOLag --env_id SafetyPointGoal1-v0 \
--parallel 1 --total_steps 1000000 --device cpu --vector_env_nums 1
Args:
exp_id (str): Experiment ID.
algo (str): Algorithm to train.
env_id (str): The name of test environment.
custom_cfgs (dict[str, Any]): Custom configuration for training.
Returns:
ep_ret: Average episode return in final epoch.
ep_cost: Average episode cost in final epoch.
ep_len: Average episode length in final epoch.
"""
terminal_log_name = 'terminal.log'
error_log_name = 'error.log'
if 'seed' in custom_cfgs:
terminal_log_name = f'seed{custom_cfgs["seed"]}_{terminal_log_name}'
error_log_name = f'seed{custom_cfgs["seed"]}_{error_log_name}'
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
print(f'exp-x: {exp_id} is training...')
if not os.path.exists(custom_cfgs['logger_cfgs']['log_dir']):
os.makedirs(custom_cfgs['logger_cfgs']['log_dir'])
# pylint: disable-next=consider-using-with
sys.stdout = open( # noqa: SIM115
os.path.join(f'{custom_cfgs["logger_cfgs"]["log_dir"]}', terminal_log_name),
'w',
encoding='utf-8',
)
# pylint: disable-next=consider-using-with
sys.stderr = open( # noqa: SIM115
os.path.join(f'{custom_cfgs["logger_cfgs"]["log_dir"]}', error_log_name),
'w',
encoding='utf-8',
)
agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs)
ep_ret, ep_cost, ep_len = agent.learn()
return ep_ret, ep_cost, ep_len


@app.command()
def benchmark(
exp_name: str = typer.Argument(
@@ -467,6 +415,10 @@ def analyze_grid(
None,
help='the cost limit to show as a single line in the graphs',
),
show_image: bool = typer.Option(
False,
help='whether to show the images in GUI windows',
),
) -> None:
"""Statistics tools for experiment grid.
@@ -481,6 +433,7 @@
compare_num (int): Number of values to compare; if it is specified, every potential
combination will be compared.
cost_limit (float): The cost limit.
show_image (bool): Whether to show the images in GUI windows.
"""
tools = StatisticsTools()
tools.load_source(path)
@@ -490,6 +443,7 @@
values=None,
compare_num=compare_num,
cost_limit=cost_limit,
show_image=show_image,
)


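
A hypothetical command-line invocation of the extended analyze_grid command; the option spelling follows the underscore style of the train_grid docstring removed above, and the path, parameter name, and exact flag forms are assumptions:

python -m omnisafe analyze_grid ./examples/benchmarks/exp-x/Benchmark_Safety_Velocity --parameter algo --compare_num 2 --cost_limit 25 --show_image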