Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gocarx #18

Closed
wants to merge 97 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
97 commits
Select commit Hold shift + click to select a range
47f6dcb
Add tricycle model and a test demo
wu-hy Jul 9, 2024
ce52d8e
Add two agents demo
wu-hy Jul 10, 2024
ed32c8c
Add new state yaw_rate
wu-hy Jul 17, 2024
d8fc83f
Add observe and metric function
wu-hy Jul 25, 2024
285b8bd
cleanup
wu-hy Jul 30, 2024
6cbf6fc
Adaptation of PPO
wu-hy Aug 8, 2024
47b521a
runnable PPO, bugs to be fixed
wu-hy Aug 26, 2024
e77efe5
add some debug module
wu-hy Sep 4, 2024
19621d2
add wandb video
wu-hy Sep 18, 2024
4bf2dee
debug log
wu-hy Sep 20, 2024
0732e79
fix tangent normalization bug
wu-hy Sep 20, 2024
f0775d6
refactor az
alezana Sep 20, 2024
2788c80
fix merge
alezana Sep 20, 2024
6e11db3
init
alezana Sep 23, 2024
94a28be
merge
wu-hy Sep 24, 2024
e0c140d
comment out tanh at the end of pi network
wu-hy Sep 26, 2024
e16244c
fix progression reward bug
wu-hy Sep 30, 2024
cadb452
merge branch 'rl-gokart-hanyu' into 'rl-tian'
jubilantrou Sep 30, 2024
46b669d
minors config
alezana Sep 30, 2024
bf965dd
minors
alezana Sep 30, 2024
54a786c
init
alezana Sep 23, 2024
f00c141
minors
alezana Oct 1, 2024
4c6f222
modifying code
jubilantrou Oct 1, 2024
c02d330
merge
wu-hy Oct 1, 2024
29267fd
reward log + rand init position
wu-hy Oct 3, 2024
ac57255
debug error on gpu
alezana Oct 4, 2024
94d7cb4
merge
alezana Oct 4, 2024
7f7fc57
runnable codes for waymax env
jubilantrou Oct 7, 2024
114e780
merge
wu-hy Oct 7, 2024
2c3fff2
use factory functions
jubilantrou Oct 8, 2024
78827e1
Merge branch 'rl-gokart-hanyu' into rl-tian
jubilantrou Oct 9, 2024
7cdfa37
rand init pos
wu-hy Oct 10, 2024
3a6c715
merge
wu-hy Oct 10, 2024
f3bb2c8
Merge branch 'rl-gokart-hanyu' into rl-tian
alezana Oct 10, 2024
64cc9c6
Merge pull request #2 from idsc-frazzoli/rl-tian
alezana Oct 10, 2024
309da47
deleting rl folder for transition to gocarx
alezana Oct 10, 2024
528d5cf
Update setup.py
alezana Oct 11, 2024
0a266a1
controllable_fields
alezana Oct 15, 2024
c9350a6
moved scripts
alezana Oct 15, 2024
d4ea084
Merge remote-tracking branch 'origin/rl-gokart-hanyu' into rl-gokart-…
alezana Oct 15, 2024
170a15b
minor update
wu-hy Oct 23, 2024
649f5d1
metrics in separate files 1
alezana Oct 25, 2024
d1d9282
modified some obs funtions and viz functions for env_waymax
jubilantrou Oct 29, 2024
b61242c
gokart metrics
wu-hy Oct 30, 2024
b4a4858
Merge branch 'rl-gokart-hanyu' of github.com:idsc-frazzoli/waymax int…
wu-hy Oct 30, 2024
0879740
typo
wu-hy Oct 30, 2024
432a977
add test cases
wu-hy Oct 30, 2024
d5bec87
add gokart offroad
wu-hy Oct 30, 2024
74b6dfb
refactored
alezana Oct 31, 2024
a1da7c7
minors
alezana Oct 31, 2024
7617306
use linear combination for reward
wu-hy Oct 31, 2024
1bf10c2
Merge branch 'rl-gokart-hanyu' of github.com:idsc-frazzoli/waymax int…
wu-hy Oct 31, 2024
80ee6f0
wip
alezana Oct 31, 2024
c96809d
eval rollout running
alezana Nov 1, 2024
97d443d
wip
alezana Nov 2, 2024
5b21322
better rollout with module's state + wip
alezana Nov 3, 2024
55fda8e
still debugging
wu-hy Nov 4, 2024
53ac0b0
added some metrics
alezana Nov 6, 2024
d35756b
minor
wu-hy Nov 8, 2024
5ead191
adapted waymaxenv to new wrapper
jubilantrou Nov 8, 2024
b1a6417
fix bug in orient reward
wu-hy Nov 12, 2024
f2959d3
dev
alezana Nov 12, 2024
922d7ef
get original limits
nicolaloi Nov 4, 2024
74d1aaa
dynamics from forces solver
nicolaloi Nov 12, 2024
f59081b
merged branch 'init/hanyu'
jubilantrou Nov 13, 2024
f70edd9
wrap velocity in orient reward
wu-hy Nov 13, 2024
9582b07
merge
wu-hy Nov 13, 2024
6246fc7
fix bug in orient reward
wu-hy Nov 13, 2024
f811a00
Merge pull request #5 from idsc-frazzoli/init/hanyu
alezana Nov 13, 2024
d2396b6
coherency fix for metric dimensions
alezana Nov 13, 2024
e6d5c40
minors
alezana Nov 14, 2024
a794f32
Merge pull request #6 from idsc-frazzoli/init/tian
alezana Nov 14, 2024
21844e4
new dynamics
nicolaloi Nov 12, 2024
22689fd
fixed dimension mismatch for multiplayer metrics and gokart ones,
alezana Nov 14, 2024
77592fa
fixed minibatch shuffling
alezana Nov 14, 2024
0c3c14a
review
nicolaloi Nov 15, 2024
af31000
az dynamic model PR review
alezana Nov 16, 2024
098a981
minor
alezana Nov 16, 2024
fa38d5c
minors
alezana Nov 17, 2024
5becb34
fix to make code run
nicolaloi Nov 18, 2024
2956053
model factory (original, forces, ignition)
nicolaloi Nov 20, 2024
9f0958d
test
wu-hy Nov 20, 2024
cee6a2b
Merge branch 'init/az_refactoring' into init/nicola
alezana Nov 20, 2024
cbeccfc
Merge pull request #7 from idsc-frazzoli/init/nicola
alezana Nov 21, 2024
4f318fe
faster_rendering (#10)
nicolaloi Nov 21, 2024
7233254
move waymax to gocarx
alezana Nov 22, 2024
3b1f527
Merge pull request #11 from idsc-frazzoli/init/waymax_move
alezana Nov 23, 2024
1971714
minor
alezana Nov 25, 2024
a0fa2c1
minor
alezana Dec 1, 2024
1e81e79
added a fixme for manually coded parameter
matteop65 Dec 4, 2024
80abca5
Nicola/rewards (#12)
nicolaloi Dec 16, 2024
6ba6282
merging
alezana Dec 16, 2024
0bf9af3
fix
alezana Dec 16, 2024
ebec15e
clean up and few bug fixes
alezana Dec 20, 2024
d214fca
viz is broken
alezana Dec 20, 2024
e233485
Merge pull request #4 from idsc-frazzoli/init/az_refactoring
alezana Dec 20, 2024
a387615
Az/patch001 (#17)
alezana Jan 2, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ name: CI

on:
push:
branches: [ "main", "staging" ]
branches: [ "main", "gocarx" ]
pull_request:
branches: [ "main", "staging" ]
branches: [ "main", "gocarx" ]

permissions:
contents: read
Expand Down Expand Up @@ -36,5 +36,5 @@ jobs:
pip install -e .
- name: Test with pytest
run: |
pytest
pytest --capture=no -v waymax

9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
dataset/training/training_tfexample.tfrecord-00000-of-01000
dataset/training/training_tfexample.tfrecord-00001-of-01000
/.vscode
docs/
wandb/
logs/
out/
__pycache__
*.egg-info
3 changes: 3 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
FROM ghcr.io/nvidia/jax:jax

CMD ["bash"]
21 changes: 21 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

cover_packages=waymax

out=out
tr=$(out)/test-results

junit=--junitxml=$(tr)/junit.xml
parallel=-n auto --dist=loadfile
extra=--capture=no -v

clean-test:
poetry run coverage erase
rm -rf $(tr) $(tr)

test: clean-test
mkdir -p $(tr)
poetry run pytest $(extra) $(junit) waymax

test-parallel: clean-test
mkdir -p $(tr)
poetry run pytest $(extra) $(junit) $(parallel) waymax
65 changes: 34 additions & 31 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,41 +13,44 @@
# limitations under the License.

"""setup.py file for Waymax."""
from setuptools import find_packages
from setuptools import setup

from setuptools import find_packages, setup

__version__ = '0.1.0'


with open('README.md', encoding='utf-8') as f:
_long_description = f.read()
_long_description = f.read()

setup(
name='waymo-waymax',
version=__version__,
description='Waymo simulator for autonomous driving',
long_description=_long_description,
long_description_content_type='text/markdown',
author='Waymax team',
author_email='[email protected]',
python_requires='>=3.10',
packages=find_packages(),
install_requires=[
'numpy>=1.20',
'jax>=0.4.6',
'tensorflow>=2.11.0',
'chex>=0.1.6',
'dm_env>=1.6',
'flax>=0.6.7',
'matplotlib>=3.7.1',
'dm-tree>=0.1.8',
'immutabledict>=2.2.3',
'Pillow>=9.4.0',
'mediapy>=1.1.6',
'tqdm>=4.65.0',
'absl-py>=1.4.0',
],
url='https://github.com/waymo-research/waymax',
license='Apache-2.0',
name='waymax',
version=__version__,
description='Waymo simulator for autonomous driving',
long_description=_long_description,
long_description_content_type='text/markdown',
author='Waymax team',
author_email='[email protected]',
python_requires='>=3.10',
packages=find_packages(),
install_requires=[
'numpy>=1.20',
'jax>=0.4.6',
'jaxtyping',
'chex>=0.1.6',
'distrax>=0.1.5',
'tf-keras', # needed for distrax
'dm_env>=1.6',
'flax>=0.6.7',
'matplotlib<3.10',
'dm-tree>=0.1.8',
'immutabledict>=2.2.3',
'Pillow>=9.4.0',
'mediapy>=1.1.6',
'moviepy',
'imageio',
'tqdm>=4.65.0',
'absl-py>=1.4.0',
'pytest>=6.2.4',
"beartype"
],
url='https://github.com/waymo-research/waymax',
license='Apache-2.0',
)
22 changes: 20 additions & 2 deletions waymax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""Configs for Waymax Environments."""
import dataclasses
import enum
from typing import Optional, Sequence
from typing import Optional, Sequence, Callable

import jax


class CoordinateFrame(enum.Enum):
Expand Down Expand Up @@ -136,9 +138,25 @@ class LinearCombinationRewardConfig:
rewards: Dictionary of metric names to floats indicating the weight of each
metric to create a reward of a linear combination.
"""

rewards: dict[str, float]

@classmethod
def default_gokart(cls) -> 'LinearCombinationRewardConfig':
return cls(
rewards={"gokart_offroad":-4.0, "gokart_progress": 1.0},
)

@dataclasses.dataclass(frozen=True)
class LinearTransformedRewardConfig(LinearCombinationRewardConfig):
"""Config listing all metrics and their corresponding transform.

Attributes:
rewards: Dictionary of metric names to floats indicating the weight of each
metric to create a reward of a linear combination.
transform: Dictionary of metric names to functions that apply an additional transform to the metric
"""
transform: dict[str, Callable[[jax.Array], jax.Array]]


class ObjectType(enum.Enum):
"""Types of objects that can be controlled by Waymax."""
Expand Down
5 changes: 5 additions & 0 deletions waymax/datatypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from waymax.datatypes.action import Action
from waymax.datatypes.action import TrajectoryUpdate
from waymax.datatypes.action import GoKartTrajectoryUpdate
from waymax.datatypes.action import GokartAction
from waymax.datatypes.array import MaskedArray
from waymax.datatypes.array import PyTree
from waymax.datatypes.constant import TIME_INTERVAL
Expand All @@ -24,6 +26,7 @@
from waymax.datatypes.object_state import ObjectMetadata
from waymax.datatypes.object_state import ObjectTypeIds
from waymax.datatypes.object_state import Trajectory
from waymax.datatypes.object_state import GokartTrajectory
from waymax.datatypes.observation import ObjectPose2D
from waymax.datatypes.observation import Observation
from waymax.datatypes.observation import observation_from_state
Expand All @@ -46,8 +49,10 @@
from waymax.datatypes.roadgraph import MapElementIds
from waymax.datatypes.roadgraph import RoadgraphPoints
from waymax.datatypes.route import Paths
from waymax.datatypes.route import GoKartPaths
from waymax.datatypes.simulator_state import get_control_mask
from waymax.datatypes.simulator_state import SimulatorState
from waymax.datatypes.simulator_state import GoKartSimState
from waymax.datatypes.simulator_state import update_state_by_log
from waymax.datatypes.traffic_lights import TrafficLights
from waymax.datatypes.traffic_lights import TrafficLightStates
134 changes: 133 additions & 1 deletion waymax/datatypes/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.

"""Dataclass definitions for dynamics models."""
from typing import Any
from typing import Any, Sequence

import chex
import jax
import jax.numpy as jnp

from waymax.datatypes import operations
from waymax.utils.classproperty import classproperty


# TODO: make Actions inherit from datatypes.MaskedArray.
Expand Down Expand Up @@ -112,4 +113,135 @@ def as_action(self) -> Action:
)
return Action(data=action, valid=self.valid)

@chex.dataclass
class GoKartTrajectoryUpdate(TrajectoryUpdate):
yaw_rate: jax.Array # (..., num_objects, 1)
acc_x: jax.Array # (..., num_objects, 1)
acc_y: jax.Array # (..., num_objects, 1)

def validate(self) -> None:
"""Validates shape and type."""
# Verifies that each element has the same dimensions.
chex.assert_equal_shape(
[self.x, self.y, self.yaw, self.vel_x, self.vel_y, self.yaw_rate, self.acc_x, self.acc_y, self.valid],
)
chex.assert_type(
[self.x, self.y, self.yaw, self.vel_x, self.vel_y, self.yaw_rate, self.acc_x, self.acc_y, self.valid],
[
jnp.float32,
jnp.float32,
jnp.float32,
jnp.float32,
jnp.float32,
jnp.float32,
jnp.float32,
jnp.float32,
jnp.bool_,
],
)

def as_action(self) -> Action:
"""Returns this trajectory update as a 8D Action for StateDynamics.

Returns:
An action data structure with data of shape (..., 8) containing
x, y, yaw, vel_x, vel_y, yaw_rate, acc_x and acc_y.
"""
action = jnp.concatenate(
[self.x, self.y, self.yaw, self.vel_x, self.vel_y, self.yaw_rate, self.acc_x, self.acc_y], axis=-1
)
return Action(data=action, valid=self.valid)


@chex.dataclass
class GokartAction:
"""
Data structure representing the action history of a gokart.

Attributes:

steering_angle: The steering angle of the gokart column data type float32.
acc_left: The left acceleration of the left wheel of the gokart at each time step of data type float32.
acc_right: The right acceleration of the left wheel of the gokart at each time step of data type float32.
"""

steering_angle: jax.Array # (..., num_objects, 1)
acc_left: jax.Array # (..., num_objects, 1)
acc_right: jax.Array # (..., num_objects, 1)

@property
def shape(self) -> tuple[int, ...]:
"""The Array shape of this history."""
return self.steering_angle.shape

@property
def num_actions(self) -> int:
"""The number of actions."""
return self.shape[-2]

@property
def num_timesteps(self) -> int:
"""The length of this history."""
return self.shape[-1]

@property
def acc_leftright(self) -> jax.Array:
"""Stacked AB action"""
return jnp.stack([self.acc_left, self.acc_right], axis=-1)

@property
def torque_vectoring(self) -> jax.Array:
"""Stacked Torque Vectoring indirect action (acc_right - acc_left)"""
return self.acc_right - self.acc_left

@classproperty
def action_fields(self) -> Sequence[str]:
"""Returns the action fields."""
#todo there are better ways to do this
return ["steering_angle", "acc_left", "acc_right"]

@classmethod
def zeros(cls, shape: Sequence[int]) -> "GokartAction":
"""Creates a Trajectory containing zeros of the specified shape."""
return cls(
steering_angle=jnp.zeros(shape, jnp.float32),
acc_left=jnp.zeros(shape, jnp.float32),
acc_right=jnp.zeros(shape, jnp.float32),
)

def __eq__(self, other: Any) -> bool:
return operations.compare_all_leaf_nodes(self, other)

def stack_fields(self, field_names: Sequence[str]) -> jax.Array:
"""Returns a concatenated version of a set of field names for Trajectory."""
return jnp.stack([getattr(self, field_name) for field_name in field_names], axis=-1)

def set_actions(self, action: Action, timestep: jax.typing.ArrayLike) -> "GokartAction":
"""Return a new action history updated at timestep with the new action."""
return self.replace(
steering_angle=self.steering_angle.at[..., timestep].set(action.data[0]),
acc_left=self.acc_left.at[..., timestep].set(action.data[1]),
acc_right=self.acc_right.at[..., timestep].set(action.data[2]),
)

def validate(self):
"""Validates shape and type."""
chex.assert_equal_shape(
[
self.steering_angle,
self.acc_left,
self.acc_right,
]
)
chex.assert_type(
[
self.steering_angle,
self.acc_left,
self.acc_right,
],
[
jnp.float32,
jnp.float32,
jnp.float32,
],
)
Loading
Loading