diff --git a/mujoco_playground/_src/mjx_env.py b/mujoco_playground/_src/mjx_env.py index 96a0225..1a4de43 100644 --- a/mujoco_playground/_src/mjx_env.py +++ b/mujoco_playground/_src/mjx_env.py @@ -148,9 +148,9 @@ def init( if act is not None: data = data.replace(act=act) if mocap_pos is not None: - data = data.replace(mocap_pos=mocap_pos.reshape(len(data.mocap_pos), -1)) + data = data.replace(mocap_pos=mocap_pos.reshape(model.nmocap, -1)) if mocap_quat is not None: - data = data.replace(mocap_quat=mocap_quat.reshape(len(data.mocap_quat), -1)) + data = data.replace(mocap_quat=mocap_quat.reshape(model.nmocap, -1)) data = mjx.forward(model, data) return data