-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathhalf_cheetah.py
132 lines (108 loc) · 4.44 KB
/
half_cheetah.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
from typing import Tuple
from brax import base
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
import jax
from jax import numpy as jnp
# This is based on original Half Cheetah environment from Brax
# https://github.com/google/brax/blob/main/brax/envs/half_cheetah.py
class Halfcheetah(PipelineEnv):
    """Goal-conditioned Half Cheetah environment.

    Adapted from the original Brax Half Cheetah:
    https://github.com/google/brax/blob/main/brax/envs/half_cheetah.py

    The last q/qd slot of the loaded MuJoCo model stores a fixed goal
    x-position; the agent succeeds when its torso x-position comes within
    ``goal_reach_thresh`` of that goal.
    """

    def __init__(
        self,
        forward_reward_weight=1.0,
        ctrl_cost_weight=0.1,
        reset_noise_scale=0.1,
        exclude_current_positions_from_observation=False,
        backend="mjx",
        dense_reward: bool = False,
        **kwargs
    ):
        """Loads the MuJoCo model and configures the pipeline backend.

        Args:
            forward_reward_weight: scale on forward velocity (metrics only).
            ctrl_cost_weight: scale on the squared-action control cost.
            reset_noise_scale: magnitude of initial q/qd randomization.
            exclude_current_positions_from_observation: drop the root
                x-position from the observation when True.
            backend: Brax physics backend name.
            dense_reward: use a distance-based reward instead of the sparse
                success indicator.
            **kwargs: forwarded to ``PipelineEnv``; ``n_frames`` may override
                the backend default chosen here.
        """
        xml_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), 'assets', "half_cheetah.xml"
        )
        sys = mjcf.load(xml_path)

        frames_per_step = 5
        if backend in ["spring", "positional"]:
            # These backends need a finer timestep, more substeps, and
            # retuned actuator gears to remain stable.
            sys = sys.tree_replace({"opt.timestep": 0.003125})
            frames_per_step = 16
            gear = jnp.array([120, 90, 60, 120, 100, 100])
            sys = sys.replace(actuator=sys.actuator.replace(gear=gear))

        kwargs["n_frames"] = kwargs.get("n_frames", frames_per_step)
        super().__init__(sys=sys, backend=backend, **kwargs)

        self._forward_reward_weight = forward_reward_weight
        self._ctrl_cost_weight = ctrl_cost_weight
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = (
            exclude_current_positions_from_observation
        )
        self.dense_reward = dense_reward
        # Size of the q/qd portion of the observation (goal slot excluded
        # from q/qd but appended separately in _get_obs).
        self.state_dim = 18
        # Observation index of the achieved goal — presumably the torso
        # x-position at obs[0]; verify against consumers of this attribute.
        self.goal_indices = jnp.array([0])
        self.goal_reach_thresh = 0.5

    def reset(self, rng: jax.Array) -> State:
        """Resets the environment to a randomized initial state."""
        rng, key_pos, key_vel = jax.random.split(rng, 3)

        noise = self._reset_noise_scale
        qpos = self.sys.init_q + jax.random.uniform(
            key_pos, (self.sys.q_size(),), minval=-noise, maxval=noise
        )
        qvel = noise * jax.random.normal(key_vel, (self.sys.qd_size(),))

        # The trailing q/qd entry stores the goal; pin it to the target
        # with zero velocity so it stays put.
        _, target = self._random_target(rng)
        qpos = qpos.at[-1:].set(target)
        qvel = qvel.at[-1:].set(0)

        pipeline_state = self.pipeline_init(qpos, qvel)
        obs = self._get_obs(pipeline_state)

        reward, done, zero = jnp.zeros(3)
        metrics = {
            "x_position": zero,
            "x_velocity": zero,
            "reward_ctrl": zero,
            "reward_run": zero,
            "dist": zero,
            "success": zero,
            "success_easy": zero
        }
        return State(pipeline_state, obs, reward, done, metrics)

    def step(self, state: State, action: jax.Array) -> State:
        """Runs one timestep of the environment's dynamics."""
        prev = state.pipeline_state
        pipeline_state = self.pipeline_step(prev, action)

        # Torso forward velocity via finite differences over one step.
        x_velocity = (pipeline_state.x.pos[0, 0] - prev.x.pos[0, 0]) / self.dt
        forward_reward = self._forward_reward_weight * x_velocity
        ctrl_cost = self._ctrl_cost_weight * jnp.sum(jnp.square(action))

        obs = self._get_obs(pipeline_state)
        # Distance between current x-position (obs[0]) and the goal,
        # which _get_obs appends as the final observation entry.
        dist = jnp.linalg.norm(obs[:1] - obs[-1:])
        success = jnp.array(dist < self.goal_reach_thresh, dtype=float)
        success_easy = jnp.array(dist < 2., dtype=float)

        # NOTE(review): the dense reward ADDS the control cost rather than
        # subtracting it — confirm this sign is intentional.
        reward = ctrl_cost - dist if self.dense_reward else success

        state.metrics.update(
            x_position=pipeline_state.x.pos[0, 0],
            x_velocity=x_velocity,
            reward_run=forward_reward,
            reward_ctrl=-ctrl_cost,
            dist=dist,
            success=success,
            success_easy=success_easy
        )
        # `done` is intentionally left unchanged: this environment does not
        # terminate episodes on its own.
        return state.replace(
            pipeline_state=pipeline_state, obs=obs, reward=reward
        )

    def _get_obs(self, pipeline_state: base.State) -> jax.Array:
        """Returns the observation: q/qd (goal slot removed) plus goal x."""
        position = pipeline_state.q[:-1]
        velocity = pipeline_state.qd[:-1]
        if self._exclude_current_positions_from_observation:
            position = position[1:]
        # x-coordinate of the last body, which carries the goal marker.
        target_pos = pipeline_state.x.pos[-1][:1]
        return jnp.concatenate((position, velocity, target_pos))

    def _random_target(self, rng: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """Returns a target location in a random circle slightly above xy plane."""
        rng, _ = jax.random.split(rng, 2)
        # Despite the name, the target is currently deterministic: x = 5.
        return rng, jnp.array([5])