Skip to content

Commit

Permalink
Merge pull request #9 from samholt:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718086218
Change-Id: I89e1489f67ae634e87ea082583841bee86fdb1c8
  • Loading branch information
copybara-github committed Jan 21, 2025
2 parents 2034ac0 + 0c6812c commit cb7e9b7
Show file tree
Hide file tree
Showing 22 changed files with 154 additions and 120 deletions.
26 changes: 12 additions & 14 deletions learning/train_jax_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,12 @@
# ==============================================================================
"""Train a PPO agent using JAX on the specified environment."""

import os

xla_flags = os.environ.get("XLA_FLAGS", "")
xla_flags += " --xla_gpu_triton_gemm_any=True"
os.environ["XLA_FLAGS"] = xla_flags
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["MUJOCO_GL"] = "egl"

from datetime import datetime
import functools
import json
import os
import time
import warnings

from absl import app
from absl import flags
Expand All @@ -39,7 +33,6 @@
import jax.numpy as jp
import mediapy as media
from ml_collections import config_dict
from ml_collections import config_flags
import mujoco
from orbax import checkpoint as ocp
from tensorboardX import SummaryWriter
Expand All @@ -52,11 +45,16 @@
from mujoco_playground.config import locomotion_params
from mujoco_playground.config import manipulation_params

xla_flags = os.environ.get("XLA_FLAGS", "")
xla_flags += " --xla_gpu_triton_gemm_any=True"
os.environ["XLA_FLAGS"] = xla_flags
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["MUJOCO_GL"] = "egl"

# Ignore the info logs from brax
logging.set_verbosity(logging.WARNING)

# Suppress warnings
import warnings

# Suppress RuntimeWarnings from JAX
warnings.filterwarnings("ignore", category=RuntimeWarning, module="jax")
Expand Down Expand Up @@ -267,11 +265,11 @@ def main(argv):
print(f"Checkpoint path: {ckpt_path}")

# Save environment configuration
with open(ckpt_path / "config.json", "w") as fp:
with open(ckpt_path / "config.json", "w", encoding="utf-8") as fp:
json.dump(env_cfg.to_json(), fp, indent=4)

# Define policy parameters function for saving checkpoints
def policy_params_fn(current_step, make_policy, params):
def policy_params_fn(current_step, make_policy, params): # pylint: disable=unused-argument
orbax_checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(params)
path = ckpt_path / f"{current_step}"
Expand Down Expand Up @@ -352,7 +350,7 @@ def progress(num_steps, metrics):
)

# Train or load the model
make_inference_fn, params, _ = train_fn(
make_inference_fn, params, _ = train_fn( # pylint: disable=no-value-for-parameter
environment=env,
progress_fn=progress,
eval_env=None if _VISION.value else eval_env,
Expand Down Expand Up @@ -389,7 +387,7 @@ def progress(num_steps, metrics):
rollout = [state0]

# Run evaluation rollout
for i in range(env_cfg.episode_length):
for _ in range(env_cfg.episode_length):
act_rng, rng = jax.random.split(rng)
ctrl, _ = jit_inference_fn(state.obs, act_rng)
state = jit_step(state, ctrl)
Expand Down
15 changes: 10 additions & 5 deletions learning/train_rsl_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=wrong-import-position
"""Train a PPO agent using RSL-RL for the specified environment."""

import os
Expand All @@ -28,7 +29,6 @@
from absl import flags
from absl import logging
import jax
import jax.numpy as jp
import mediapy as media
from ml_collections import config_dict
import mujoco
Expand Down Expand Up @@ -136,7 +136,9 @@ def main(argv):
wandb.config.update({"env_name": _ENV_NAME.value})

# Save environment config to JSON
with open(os.path.join(ckpt_path, "config.json"), "w") as fp:
with open(
os.path.join(ckpt_path, "config.json"), "w", encoding="utf-8"
) as fp:
json.dump(env_cfg.to_json(), fp, indent=4)

# Domain randomization
Expand All @@ -146,7 +148,7 @@ def main(argv):
render_trajectory = []

# Callback to gather states for rendering
def render_callback(env, state):
def render_callback(_, state):
render_trajectory.append(state)

# Create the environment
Expand Down Expand Up @@ -231,8 +233,11 @@ def render_callback(env, state):
fps = 1.0 / base_env.dt / render_every
traj = rollout[::render_every]
frames = eval_env.render(
traj, camera=_CAMERA.value, height=480, width=640,
scene_option=scene_option
traj,
camera=_CAMERA.value,
height=480,
width=640,
scene_option=scene_option,
)
media.write_video("rollout.mp4", frames, fps=fps)
print("Rollout video saved as 'rollout.mp4'.")
Expand Down
1 change: 1 addition & 0 deletions mujoco_playground/_src/locomotion/g1/randomize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for randomization."""
import jax
from mujoco import mjx

Expand Down
1 change: 1 addition & 0 deletions mujoco_playground/_src/locomotion/locomotion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@


class TestSuite(parameterized.TestCase):
"""Tests for the locomotion environments."""

@parameterized.named_parameters(
{"testcase_name": f"test_can_create_{env_name}", "env_name": env_name}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import jax
import jax.numpy as jp
from ml_collections import config_dict
import mujoco
from mujoco import mjx
import mujoco # pylint: disable=unused-import
from mujoco.mjx._src import math

from mujoco_playground._src import collision
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class PandaPickCubeCartesian(pick.PandaPickCube):
"""Environment for training the Franka Panda robot to pick up a cube in
Cartesian space."""

def __init__(
def __init__( # pylint: disable=non-parent-init-called,super-init-not-called
self,
config=default_config(),
config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def domain_randomize(
) -> Tuple[mjx.Model, mjx.Model]:
"""Tile the necessary axes for the Madrona BatchRenderer."""
mj_model = pick_cartesian.PandaPickCubeCartesian().mj_model
FLOOR_GEOM_ID = mj_model.geom('floor').id
BOX_GEOM_ID = mj_model.geom('box').id
STRIP_GEOM_ID = mj_model.geom('init_space').id
floor_geom_id = mj_model.geom('floor').id
box_geom_id = mj_model.geom('box').id
strip_geom_id = mj_model.geom('init_space').id

in_axes = jax.tree_util.tree_map(lambda x: None, mjx_model)
in_axes = in_axes.tree_replace({
Expand Down Expand Up @@ -93,16 +93,16 @@ def rand(rng: jax.Array, light_position: jax.Array):
rgba = jp.array(
[jax.random.uniform(key_box, (), minval=0.5, maxval=1.0), 0.0, 0.0, 1.0]
)
geom_rgba = mjx_model.geom_rgba.at[BOX_GEOM_ID].set(rgba)
geom_rgba = mjx_model.geom_rgba.at[box_geom_id].set(rgba)

strip_white = jax.random.uniform(key_strip, (), minval=0.8, maxval=1.0)
geom_rgba = geom_rgba.at[STRIP_GEOM_ID].set(
geom_rgba = geom_rgba.at[strip_geom_id].set(
jp.array([strip_white, strip_white, strip_white, 1.0])
)

# Sample a shade of gray
gray_scale = jax.random.uniform(key_floor, (), minval=0.0, maxval=0.25)
geom_rgba = geom_rgba.at[FLOOR_GEOM_ID].set(
geom_rgba = geom_rgba.at[floor_geom_id].set(
jp.array([gray_scale, gray_scale, gray_scale, 1.0])
)

Expand All @@ -112,11 +112,11 @@ def rand(rng: jax.Array, light_position: jax.Array):
jax.random.randint(key_matid, shape=(num_geoms,), minval=0, maxval=10)
+ mat_offset
)
geom_matid = geom_matid.at[BOX_GEOM_ID].set(
geom_matid = geom_matid.at[box_geom_id].set(
-2
) # Use the above randomized colors
geom_matid = geom_matid.at[FLOOR_GEOM_ID].set(-2)
geom_matid = geom_matid.at[STRIP_GEOM_ID].set(-2)
geom_matid = geom_matid.at[floor_geom_id].set(-2)
geom_matid = geom_matid.at[strip_geom_id].set(-2)

#### Cameras ####
key_pos, key_ori, key = jax.random.split(key, 3)
Expand All @@ -134,7 +134,7 @@ def rand(rng: jax.Array, light_position: jax.Array):
assert (
nlight == 1
), f'Sim2Real was trained with a single light source, got {nlight}'
key_lsha, key_ldir, key_ldct, key = jax.random.split(key, 4)
key_lsha, key_ldir, key = jax.random.split(key, 3)

# Direction
shine_at = jp.array([0.661, -0.001, 0.179]) # Gripper starting position
Expand Down
3 changes: 1 addition & 2 deletions mujoco_playground/_src/manipulation/leap_hand/rotate_z.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import numpy as np

from mujoco_playground._src import mjx_env
from mujoco_playground._src import reward
from mujoco_playground._src.manipulation.leap_hand import base as leap_hand_base
from mujoco_playground._src.manipulation.leap_hand import leap_hand_constants as consts

Expand Down Expand Up @@ -145,7 +144,7 @@ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
rewards = {
k: v * self._config.reward_config.scales[k] for k, v in rewards.items()
}
reward = sum(rewards.values()) * self.dt
reward = sum(rewards.values()) * self.dt # pylint: disable=redefined-outer-name

state.info["last_last_act"] = state.info["last_act"]
state.info["last_act"] = action
Expand Down
1 change: 1 addition & 0 deletions mujoco_playground/_src/manipulation/manipulation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@


class TestSuite(parameterized.TestCase):
"""Tests for the manipulation environments."""

@parameterized.named_parameters(
{"testcase_name": f"test_can_create_{env_name}", "env_name": env_name}
Expand Down
1 change: 0 additions & 1 deletion mujoco_playground/_src/mjx_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import numpy as np
import tqdm


# Root path is used for loading XML strings directly using etils.epath.
ROOT_PATH = epath.Path(__file__).parent
# Base directory for external dependencies.
Expand Down
Loading

0 comments on commit cb7e9b7

Please sign in to comment.