
Copybara import of the project:
--
7ee8348 by Mustafa Haiderbhai <[email protected]>:

Add brax training instructions, fix arg

COPYBARA_INTEGRATE_REVIEW=#31 from StafaH:brax_ppo_fix 7ee8348
PiperOrigin-RevId: 718036613
Change-Id: I775c1a1321c4c474089dfb5a627ad141180df7fd
StafaH authored and copybara-github committed Jan 21, 2025
1 parent ca3de07 commit 7dc538b
Showing 2 changed files with 20 additions and 5 deletions.
17 changes: 16 additions & 1 deletion learning/README.md
@@ -10,6 +10,21 @@ For more detailed tutorials on using MuJoCo Playground for RL, see:
4. Training CartPole from Vision [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_1.ipynb)
5. Robotic Manipulation from Vision [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_2.ipynb)

## Training with brax PPO

To train with brax PPO, use the `train_jax_ppo.py` script, which trains an agent on a given environment with brax's PPO algorithm.

```bash
python train_jax_ppo.py --env_name=CartpoleBalance
```

To train a vision-based policy using pixel observations:
```bash
python train_jax_ppo.py --env_name=CartpoleBalance --vision
```

Use `python train_jax_ppo.py --help` to see possible options and usage. Logs and checkpoints are saved in the `logs` directory.
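For illustration, hyperparameter flags can be combined with the basic invocation. The flag names below are inferred from the flag handling in `train_jax_ppo.py` touched elsewhere in this commit; the values are placeholders, not recommended settings:

```bash
# Hypothetical invocation: override network sizes and the PPO clipping
# epsilon via flags defined in train_jax_ppo.py (values are examples).
python train_jax_ppo.py \
  --env_name=CartpoleBalance \
  --policy_hidden_layer_sizes=64,64,64 \
  --value_hidden_layer_sizes=256,256 \
  --clipping_epsilon=0.2
```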

## Training with RSL-RL

To train with RSL-RL, use the `train_rsl_rl.py` script, which trains an agent on a given environment with the RSL-RL algorithm.
@@ -18,7 +33,7 @@ To train with RSL-RL, you can use the `train_rsl_rl.py` script. This script uses
python train_rsl_rl.py --env_name=LeapCubeReorient
```

-to render the behaviour from the resulting policy:
+To render the behaviour from the resulting policy:
```bash
python learning/train_rsl_rl.py --env_name LeapCubeReorient --play_only --load_run_name <run_name>
```
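Putting the two commands from this section together, a hypothetical end-to-end flow (the run name is a placeholder; substitute the name produced during training):

```bash
# Train, then replay the learned policy (commands as documented above;
# <run_name> must be replaced with the actual run identifier).
python learning/train_rsl_rl.py --env_name LeapCubeReorient
python learning/train_rsl_rl.py --env_name LeapCubeReorient --play_only --load_run_name <run_name>
```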
8 changes: 4 additions & 4 deletions learning/train_jax_ppo.py
@@ -198,12 +198,12 @@ def main(argv):
   if _CLIPPING_EPSILON.present:
     ppo_params.clipping_epsilon = _CLIPPING_EPSILON.value
   if _POLICY_HIDDEN_LAYER_SIZES.present:
-    ppo_params.network_factory.policy_hidden_layer_sizes = tuple(
-        _POLICY_HIDDEN_LAYER_SIZES.value
+    ppo_params.network_factory.policy_hidden_layer_sizes = list(
+        map(int, _POLICY_HIDDEN_LAYER_SIZES.value)
     )
   if _VALUE_HIDDEN_LAYER_SIZES.present:
-    ppo_params.network_factory.value_hidden_layer_sizes = tuple(
-        _VALUE_HIDDEN_LAYER_SIZES.value
+    ppo_params.network_factory.value_hidden_layer_sizes = list(
+        map(int, _VALUE_HIDDEN_LAYER_SIZES.value)
     )
   if _POLICY_OBS_KEY.present:
     ppo_params.network_factory.policy_obs_key = _POLICY_OBS_KEY.value
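The substance of this change: list-valued command-line flags (absl `DEFINE_list`-style) arrive as strings, so building the layer-size tuple directly would hand string sizes to the network factory. A minimal sketch of the distinction, with hypothetical values not taken from the commit:

```python
# What a DEFINE_list-style flag yields for "--policy_hidden_layer_sizes=64,64":
raw_value = ["64", "64"]

# Old behaviour: sizes stay strings.
before = tuple(raw_value)

# New behaviour: each entry is cast to int before reaching the network factory.
after = list(map(int, raw_value))

assert before == ("64", "64")
assert after == [64, 64]
```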
