From 50e155783ce81f74552f0d94d8b430c274a34714 Mon Sep 17 00:00:00 2001 From: Daphne Cornelisse Date: Fri, 10 Jan 2025 09:14:03 -0500 Subject: [PATCH] Improved visualizer and gym environment (cloning all tensors) --- pygpudrive/datatypes/info.py | 33 +++ pygpudrive/datatypes/roadgraph.py | 4 +- pygpudrive/datatypes/trajectory.py | 31 ++- pygpudrive/env/base_env.py | 24 +-- pygpudrive/env/config.py | 49 ++--- pygpudrive/env/dataset.py | 135 ++++++++++++ pygpudrive/env/env_torch.py | 154 +++++++++++--- pygpudrive/env/viz.py | 2 +- pygpudrive/visualize/core.py | 331 +++++++++++++++++++---------- pygpudrive/visualize/utils.py | 181 ++++++++++++++++ 10 files changed, 735 insertions(+), 209 deletions(-) create mode 100644 pygpudrive/datatypes/info.py create mode 100644 pygpudrive/env/dataset.py diff --git a/pygpudrive/datatypes/info.py b/pygpudrive/datatypes/info.py new file mode 100644 index 00000000..a253b33e --- /dev/null +++ b/pygpudrive/datatypes/info.py @@ -0,0 +1,33 @@ +import torch +import gpudrive + + +class Info: + """A class to represent the information about the state of the environment. + Initialized from info_tensor (src/bindings) of shape (num_worlds, max_agents_in_scene, 5). + For details, see `Info` in src/types.hpp. 
+ """ + + def __init__(self, info_tensor: torch.Tensor): + """Initializes the info state from an info tensor.""" + self.off_road = info_tensor[:, :, 0] + self.collided = info_tensor[:, :, 1:3].sum(axis=2) + self.goal_achieved = info_tensor[:, :, 3] + + @classmethod + def from_tensor( + cls, + info_tensor: gpudrive.madrona.Tensor, + backend="torch", + device="cuda", + ): + """Creates an Info instance from the info_tensor.""" + if backend == "torch": + return cls(info_tensor.to_torch().clone().to(device)) + elif backend == "jax": + raise NotImplementedError("JAX backend not implemented yet.") + + @property + def shape(self): + """Returns the shape of the info tensor (num_worlds, max_agents_in_scene).""" + return self.off_road.shape diff --git a/pygpudrive/datatypes/roadgraph.py b/pygpudrive/datatypes/roadgraph.py index ec0d49a6..8f900daf 100644 --- a/pygpudrive/datatypes/roadgraph.py +++ b/pygpudrive/datatypes/roadgraph.py @@ -76,7 +76,7 @@ def from_tensor( ): """Creates a GlobalRoadGraphPoints instance from a tensor.""" if backend == "torch": - return cls(roadgraph_tensor.to_torch().to(device)) + return cls(roadgraph_tensor.to_torch().clone().to(device)) elif backend == "jax": raise NotImplementedError("JAX backend not implemented yet.") @@ -148,7 +148,7 @@ def from_tensor( ): """Creates a GlobalRoadGraphPoints instance from a tensor.""" if backend == "torch": - return cls(local_roadgraph_tensor.to_torch().to(device)) + return cls(local_roadgraph_tensor.to_torch().clone().to(device)) elif backend == "jax": raise NotImplementedError("JAX backend not implemented yet.") diff --git a/pygpudrive/datatypes/trajectory.py b/pygpudrive/datatypes/trajectory.py index 2fc6ce0f..defe6c64 100644 --- a/pygpudrive/datatypes/trajectory.py +++ b/pygpudrive/datatypes/trajectory.py @@ -4,6 +4,7 @@ TRAJ_LEN = 91 # Length of the logged trajectory + @dataclass class LogTrajectory: """A class to represent the logged human trajectories. 
Initialized from `expert_trajectory_tensor` (src/bindings.cpp). @@ -16,25 +17,35 @@ class LogTrajectory: actions: Expert actions performed by the agent(s) across the trajectory. """ - def __init__(self, raw_logs: torch.Tensor, num_worlds: int, max_agents: int): - """Initializes the expert trajectory with an observation tensor.""" + def __init__( + self, raw_logs: torch.Tensor, num_worlds: int, max_agents: int + ): + """Initializes the expert trajectory with an observation tensor.""" self.pos_xy = raw_logs[:, :, : 2 * TRAJ_LEN].view( num_worlds, max_agents, TRAJ_LEN, -1 ) - self.vel_xy = raw_logs[:, :, 2 * TRAJ_LEN: 4 * TRAJ_LEN].view( - num_worlds, max_agents, TRAJ_LEN, -1 - ) - self.yaw = raw_logs[:, :, 4 * TRAJ_LEN: 5 * TRAJ_LEN].view( + self.vel_xy = raw_logs[:, :, 2 * TRAJ_LEN : 4 * TRAJ_LEN].view( num_worlds, max_agents, TRAJ_LEN, -1 ) - self.inferred_actions = raw_logs[:, :, 6 * TRAJ_LEN: 16 * TRAJ_LEN].view( + self.yaw = raw_logs[:, :, 4 * TRAJ_LEN : 5 * TRAJ_LEN].view( num_worlds, max_agents, TRAJ_LEN, -1 ) + self.inferred_actions = raw_logs[ + :, :, 6 * TRAJ_LEN : 16 * TRAJ_LEN + ].view(num_worlds, max_agents, TRAJ_LEN, -1) @classmethod - def from_tensor(cls, expert_traj_tensor: gpudrive.madrona.Tensor, num_worlds: int, max_agents: int, backend="torch"): + def from_tensor( + cls, + expert_traj_tensor: gpudrive.madrona.Tensor, + num_worlds: int, + max_agents: int, + backend="torch", + ): """Creates an LogTrajectory from a tensor.""" if backend == "torch": - return cls(expert_traj_tensor.to_torch(), num_worlds, max_agents) # Pass the entire tensor + return cls( + expert_traj_tensor.to_torch().clone(), num_worlds, max_agents + ) # Pass the entire tensor elif backend == "jax": - raise NotImplementedError("JAX backend not implemented yet.") \ No newline at end of file + raise NotImplementedError("JAX backend not implemented yet.") diff --git a/pygpudrive/env/base_env.py b/pygpudrive/env/base_env.py index 84453b17..1721aff1 100755 --- 
a/pygpudrive/env/base_env.py +++ b/pygpudrive/env/base_env.py @@ -152,7 +152,7 @@ def _setup_environment_parameters(self): return params - def _initialize_simulator(self, params, scene_config): + def _initialize_simulator(self, params, data_batch): """Initializes the simulation with the specified parameters. Args: @@ -167,11 +167,10 @@ def _initialize_simulator(self, params, scene_config): else gpudrive.madrona.ExecMode.CUDA ) - self.dataset = select_scenes(scene_config) sim = gpudrive.SimManager( exec_mode=exec_mode, gpu_id=0, - scenes=self.dataset, + scenes=data_batch, params=params, enable_batch_renderer=self.render_config and self.render_config.render_mode @@ -235,25 +234,6 @@ def _set_collision_behavior(self, params): ) return params - def reinit_scenarios(self, dataset: List[str]): - """Resample the scenes. - Args: - dataset (List[str]): List of scene names to resample. - - Returns: - None - """ - - # Resample the scenes - self.sim.set_maps(dataset) - - # Re-initialize the controlled agents mask - self.cont_agent_mask = self.get_controlled_agents_mask() - self.max_agent_count = self.cont_agent_mask.shape[1] - self.num_valid_controlled_agents_across_worlds = ( - self.cont_agent_mask.sum().item() - ) - def close(self): """Destroy the simulator and visualizer.""" del self.sim diff --git a/pygpudrive/env/config.py b/pygpudrive/env/config.py index 00899a1d..0323d0b4 100755 --- a/pygpudrive/env/config.py +++ b/pygpudrive/env/config.py @@ -43,9 +43,9 @@ class EnvConfig: # Road observation algorithm settings road_obs_algorithm: str = "linear" # Algorithm for road observations - obs_radius: float = 100.0 # Radius for road observations + obs_radius: float = 50.0 # Radius for road observations polyline_reduction_threshold: float = ( - 1.0 # Threshold for polyline reduction + 0.1 # Threshold for polyline reduction ) # Dynamics model @@ -56,7 +56,7 @@ class EnvConfig: # Action space settings (if discretized) # Classic or Invertible Bicycle dynamics model steer_actions: 
torch.Tensor = torch.round( - torch.linspace(-torch.pi, torch.pi, 36), decimals=3 + torch.linspace(-torch.pi, torch.pi, 42), decimals=3 ) accel_actions: torch.Tensor = torch.round( torch.linspace(-4.0, 4.0, 16), decimals=3 @@ -93,7 +93,7 @@ class EnvConfig: reward_type: str = "sparse_on_goal_achieved" # Alternatively, "weighted_combination", "distance_to_logs" dist_to_goal_threshold: float = ( - 3.0 # Radius around goal considered as "goal achieved" + 2.0 # Radius around goal considered as "goal achieved" ) # C++ and Python shared settings (modifiable via C++ codebase) @@ -111,10 +111,10 @@ class EnvConfig: ) # Length of an episode in the simulator num_lidar_samples: int = gpudrive.numLidarSamples - - #Param to init all objects: + # Param to init all objects: init_all_objects: bool = False + class SelectionDiscipline(Enum): """Enum for selecting scenes discipline in dataset configuration.""" @@ -138,8 +138,10 @@ class SceneConfig: seed (Optional[int]): Seed for random scene selection. """ - path: str - num_scenes: int + batch_size: int # Number of scenes per batch (should be equal to number of worlds in the env). + dataset_size: int # Maximum number of files to include in the dataset. + path: str = None + num_scenes: int = None discipline: SelectionDiscipline = SelectionDiscipline.PAD_N k_unique_scenes: Optional[int] = None seed: Optional[int] = None @@ -148,8 +150,7 @@ class SceneConfig: class RenderMode(Enum): """Enum for specifying rendering mode.""" - PYGAME_ABSOLUTE = "pygame_absolute" - PYGAME_EGOCENTRIC = "pygame_egocentric" + MATPLOTLIB = "matplotlib" PYGAME_LIDAR = "pygame_lidar" MADRONA_RGB = "madrona_rgb" MADRONA_DEPTH = "madrona_depth" @@ -171,31 +172,15 @@ class MadronaOption(Enum): @dataclass class RenderConfig: - """Configuration settings for rendering the environment. - + """ + Configuration settings for rendering the environment. Attributes: - render_mode (RenderMode): The mode used for rendering the environment. 
- view_option (Enum): Rendering view option (e.g., RGB, human view). - resolution (Tuple[int, int]): Resolution of the rendered image. - line_thickness (int): Thickness of the road lines in the rendering. - draw_obj_idx (bool): Whether to draw object indices on objects. - obj_idx_font_size (int): Font size for object indices. - color_scheme (str): Color mode for the rendering ("light" or "dark"). + render_mode (RenderMode): The mode used for rendering the environment. Default is MATPLOTLIB. + view_option (MadronaOption): Rendering view option used for the Madrona viewer (e.g., agent or top-down view). """ - render_mode: RenderMode = RenderMode.PYGAME_ABSOLUTE - view_option: Enum = PygameOption.RGB + render_mode: RenderMode = RenderMode.MATPLOTLIB + view_option: Enum = None resolution: Tuple[int, int] = (1024, 1024) - line_thickness: int = 0.7 draw_obj_idx: bool = False obj_idx_font_size: int = 9 - color_scheme: str = "light" - - def __str__(self) -> str: - """Returns a string representation of the rendering configuration.""" - return ( - f"RenderMode: {self.render_mode.value}, ViewOption: {self.view_option.value}, " - f"Resolution: {self.resolution}, LineThickness: {self.line_thickness}, " - f"DrawObjectIdx: {self.draw_obj_idx}, ObjectIdxFontSize: {self.obj_idx_font_size}, " - f"ColorScheme: {self.color_scheme}" - ) diff --git a/pygpudrive/env/dataset.py b/pygpudrive/env/dataset.py new file mode 100644 index 00000000..39557d2a --- /dev/null +++ b/pygpudrive/env/dataset.py @@ -0,0 +1,135 @@ +from dataclasses import dataclass +from typing import Iterator, List +import os +import random + + +@dataclass +class SceneDataLoader: + root: str + batch_size: int + dataset_size: int + sample_with_replacement: bool = False + file_prefix: str = "tfrecord" + seed: int = 42 + shuffle: bool = False + + """ + A data loader for sampling batches of traffic scenarios from a directory of files. + + Attributes: + root (str): Path to the directory containing scene files. 
+ batch_size (int): Number of scenes per batch (usually equal to number of worlds in the env). + dataset_size (int): Maximum number of files to include in the dataset. + sample_with_replacement (bool): Whether to sample files with replacement. + file_prefix (str): Prefix for scene files to include in the dataset. + seed (int): Seed for random number generator to ensure reproducibility. + shuffle (bool): Whether to shuffle the dataset before batching. + """ + + def __post_init__(self): + # Validate the path + if not os.path.exists(self.root): + raise FileNotFoundError( + f"The specified path does not exist: {self.root}" + ) + + # Set the random seed for reproducibility + self.random_gen = random.Random(self.seed) + + # Create the dataset from valid files in the directory + self.dataset = [ + os.path.join(self.root, scene) + for scene in sorted(os.listdir(self.root)) + if scene.startswith(self.file_prefix) + ] + + # Adjust dataset size based on the provided dataset_size + self.dataset = self.dataset[ + : min(self.dataset_size, len(self.dataset)) + ] + + # If dataset_size < batch_size, repeat the dataset until it matches the batch size + if self.dataset_size < self.batch_size: + repeat_count = (self.batch_size // self.dataset_size) + 1 + self.dataset *= repeat_count + self.dataset = self.dataset[: self.batch_size] + + # Shuffle the dataset if required + if self.shuffle: + self.random_gen.shuffle(self.dataset) + + # Initialize state for iteration + self._reset_indices() + + def _reset_indices(self): + """Reset indices for sampling.""" + if self.sample_with_replacement: + self.indices = [ + self.random_gen.randint(0, len(self.dataset) - 1) + for _ in range(len(self.dataset)) + ] + else: + self.indices = list(range(len(self.dataset))) + self.current_index = 0 + + def __iter__(self) -> Iterator[List[str]]: + self._reset_indices() + return self + + def __len__(self): + """Get the number of batches in the dataloader.""" + return len(self.dataset) // self.batch_size + + def 
__next__(self) -> List[str]: + if self.sample_with_replacement: + # Get the next batch of "deterministic" random indices + batch_indices = self.indices[ + self.current_index : self.current_index + self.batch_size + ] + self.current_index += self.batch_size + + if self.current_index > len(self.indices): + raise StopIteration + + # Retrieve the corresponding scenes + batch = [self.dataset[i] for i in batch_indices] + else: + if self.current_index >= len(self.indices): + raise StopIteration + + # Get the next batch of indices + end_index = min( + self.current_index + self.batch_size, len(self.indices) + ) + batch_indices = self.indices[self.current_index : end_index] + self.current_index = end_index + + # Retrieve the corresponding scenes + batch = [self.dataset[i] for i in batch_indices] + + return batch + + +# Example usage +if __name__ == "__main__": + from pprint import pprint + + data_loader = SceneDataLoader( + root="data/processed/training", + batch_size=2, + dataset_size=2, + sample_with_replacement=True, # Sampling with replacement + shuffle=False, # Shuffle the dataset before batching + ) + + print("\nDataset") + pprint(data_loader.dataset[:5]) + + print("\nBatch 1") + batch = next(iter(data_loader)) + pprint(batch) + + print("\nBatch 2") + batch = next(iter(data_loader)) + pprint(batch) diff --git a/pygpudrive/env/env_torch.py b/pygpudrive/env/env_torch.py index af425970..4c66b7da 100755 --- a/pygpudrive/env/env_torch.py +++ b/pygpudrive/env/env_torch.py @@ -1,12 +1,9 @@ -"""Base Gym Environment that interfaces with the GPU Drive simulator.""" +"""Torch gymnasium environment that interfaces with the GPU Drive simulator.""" from gymnasium.spaces import Box, Discrete, Tuple import numpy as np import torch -import gpudrive -import imageio from itertools import product - from pygpudrive.env.config import EnvConfig, RenderConfig, SceneConfig from pygpudrive.env.base_env import GPUDriveGymEnv @@ -18,8 +15,10 @@ ) from pygpudrive.datatypes.trajectory import 
LogTrajectory from pygpudrive.datatypes.roadgraph import LocalRoadGraphPoints +from pygpudrive.datatypes.info import Info from pygpudrive.visualize.core import MatplotlibVisualizer +from pygpudrive.env.dataset import SceneDataLoader class GPUDriveTorchEnv(GPUDriveGymEnv): @@ -28,7 +27,7 @@ class GPUDriveTorchEnv(GPUDriveGymEnv): def __init__( self, config, - scene_config, + data_loader, max_cont_agents, device="cuda", action_type="discrete", @@ -37,8 +36,8 @@ def __init__( ): # Initialization of environment configurations self.config = config - self.scene_config = scene_config - self.num_worlds = scene_config.num_scenes + self.data_loader = data_loader + self.num_worlds = data_loader.batch_size self.max_cont_agents = max_cont_agents self.device = device self.render_config = render_config @@ -47,8 +46,14 @@ def __init__( # Environment parameter setup params = self._setup_environment_parameters() - # Initialize simulator with parameters - self.sim = self._initialize_simulator(params, scene_config) + # Initialize the iterator once + self.data_iterator = iter(self.data_loader) + + # Get the initial data batch + self.data_batch = next(self.data_iterator) + + # Initialize simulator + self.sim = self._initialize_simulator(params, self.data_batch) # Controlled agents setup self.cont_agent_mask = self.get_controlled_agents_mask() @@ -68,6 +73,7 @@ def __init__( # Rendering setup self.vis = MatplotlibVisualizer( sim_object=self.sim, + controlled_agent_mask=self.cont_agent_mask, goal_radius=self.config.dist_to_goal_threshold, backend=self.backend, num_worlds=self.num_worlds, @@ -81,15 +87,19 @@ def reset(self): return self.get_obs() def get_dones(self): - return self.sim.done_tensor().to_torch().squeeze(dim=2).to(torch.float) - - def get_infos(self): return ( - self.sim.info_tensor() + self.sim.done_tensor() .to_torch() + .clone() .squeeze(dim=2) .to(torch.float) - .to(self.device) + ) + + def get_infos(self): + return Info.from_tensor( + self.sim.info_tensor(), + 
backend=self.backend, + device=self.device, ) def get_rewards( @@ -109,14 +119,14 @@ def get_rewards( The importance of each component is determined by the weights. """ if self.config.reward_type == "sparse_on_goal_achieved": - return self.sim.reward_tensor().to_torch().squeeze(dim=2) + return self.sim.reward_tensor().to_torch().clone().squeeze(dim=2) elif self.config.reward_type == "weighted_combination": # Return the weighted combination of the reward components - info_tensor = self.sim.info_tensor().to_torch() + info_tensor = self.sim.info_tensor().to_torch().clone() off_road = info_tensor[:, :, 0].to(torch.float) - # True if the vehicle collided with another road object + # True if the vehicle is in collision with another road object # (i.e. a cyclist or pedestrian) collided = info_tensor[:, :, 1:3].to(torch.float).sum(axis=2) goal_achieved = info_tensor[:, :, 3].to(torch.float) @@ -133,7 +143,7 @@ def get_rewards( # Reward based on distance to logs and penalty for collision # Return the weighted combination of the reward components - info_tensor = self.sim.info_tensor().to_torch() + info_tensor = self.sim.info_tensor().to_torch().clone() off_road = info_tensor[:, :, 0].to(torch.float) # True if the vehicle collided with another road object @@ -473,10 +483,41 @@ def get_obs(self): def get_controlled_agents_mask(self): """Get the control mask.""" - return (self.sim.controlled_state_tensor().to_torch() == 1).squeeze( - axis=2 + return ( + self.sim.controlled_state_tensor().to_torch().clone() == 1 + ).squeeze(axis=2) + + def swap_data_batch(self, data_batch=None): + """ + Swap the current data batch in the simulator with a new one + and reinitialize dependent attributes. 
+ """ + + if data_batch is None: # Sample new data batch from the data loader + self.data_batch = next(self.data_iterator) + else: + self.data_batch = data_batch + + # Validate that the number of worlds (envs) matches the batch size + if len(self.data_batch) != self.num_worlds: + raise ValueError( + f"Data batch size ({len(self.data_batch)}) does not match " + f"the expected number of worlds ({self.num_worlds})." + ) + + # Update the simulator with the new data + self.sim.set_maps(self.data_batch) + + # Reinitialize the mask for controlled agents + self.cont_agent_mask = self.get_controlled_agents_mask() + self.max_agent_count = self.cont_agent_mask.shape[1] + self.num_valid_controlled_agents_across_worlds = ( + self.cont_agent_mask.sum().item() ) + # Reset static scenario data for the visualizer + self.vis.initialize_static_scenario_data(self.cont_agent_mask) + def get_expert_actions(self): """Get expert actions for the full trajectories across worlds. @@ -545,43 +586,86 @@ def get_expert_actions(self): if __name__ == "__main__": - # CONFIGURE - TOTAL_STEPS = 90 - MAX_CONTROLLED_AGENTS = 32 - NUM_WORLDS = 1 + from pygpudrive.visualize.utils import img_from_fig + import mediapy as media env_config = EnvConfig(dynamics_model="delta_local") render_config = RenderConfig() - scene_config = SceneConfig("data/processed/training", NUM_WORLDS) + data_config = SceneConfig(batch_size=2, dataset_size=1000) + + # Create data loader + train_loader = SceneDataLoader( + root="data/processed/training", + batch_size=data_config.batch_size, + dataset_size=data_config.dataset_size, + sample_with_replacement=True, + ) - # MAKE ENV + # Make env env = GPUDriveTorchEnv( config=env_config, - scene_config=scene_config, - max_cont_agents=MAX_CONTROLLED_AGENTS, # Number of agents to control + data_loader=train_loader, + max_cont_agents=128, # Number of agents to control device="cpu", - render_config=render_config, ) - # RUN + print(f"dataset: {env.data_batch}") + + # Rollout obs = 
env.reset() - frames = [] + + print(f"controlled agents mask: {env.cont_agent_mask.sum()}") + + sim_frames = [] + agent_obs_frames = [] expert_actions, _, _, _ = env.get_expert_actions() - for t in range(TOTAL_STEPS): + env_idx = 0 + + for t in range(10): print(f"Step: {t}") # Step the environment env.step_dynamics(expert_actions[:, :, t, :]) - frames.append(env.render()) + highlight_agent = torch.where(env.cont_agent_mask[env_idx, :])[0][ + -1 + ].item() + + # Make video + sim_states = env.vis.plot_simulator_state( + env_indices=[env_idx], + zoom_radius=50, + time_steps=[t], + center_agent_indices=[highlight_agent], + ) + + agent_obs = env.vis.plot_agent_observation( + env_idx=env_idx, + agent_idx=highlight_agent, + figsize=(10, 10), + ) + + # sim_states[0].savefig(f"sim_state.png") + # agent_obs.savefig(f"agent_obs.png") + + sim_frames.append(img_from_fig(sim_states[0])) + agent_obs_frames.append(img_from_fig(agent_obs)) obs = env.get_obs() reward = env.get_rewards() done = env.get_dones() + info = env.get_infos() - # import imageio - imageio.mimsave("world1.gif", np.array(frames)) + if done[0, highlight_agent].bool(): + break env.close() + + media.write_video( + "sim_video.gif", np.array(sim_frames), fps=10, codec="gif" + ) + media.write_video( + "obs_video.gif", np.array(agent_obs_frames), fps=10, codec="gif" + ) diff --git a/pygpudrive/env/viz.py b/pygpudrive/env/viz.py index 6fc16f6f..76595d03 100644 --- a/pygpudrive/env/viz.py +++ b/pygpudrive/env/viz.py @@ -4,7 +4,7 @@ import math import gpudrive -from pygpudrive.env.config import MadronaOption, PygameOption, RenderMode +from pygpudrive.env.config import MadronaOption, RenderMode # AGENT COLORS PINK = (255, 105, 180) diff --git a/pygpudrive/visualize/core.py b/pygpudrive/visualize/core.py index 37283db6..02fc2fff 100644 --- a/pygpudrive/visualize/core.py +++ b/pygpudrive/visualize/core.py @@ -1,11 +1,10 @@ -import os import torch -import math import matplotlib from typing import Tuple, Optional, List, Dict, 
Any, Union import matplotlib.pyplot as plt from matplotlib.patches import Circle import numpy as np + import gpudrive from pygpudrive.visualize import utils from pygpudrive.datatypes.roadgraph import ( @@ -33,6 +32,7 @@ class MatplotlibVisualizer: def __init__( self, sim_object, + controlled_agent_mask, goal_radius, backend: str, num_worlds: int, @@ -42,28 +42,57 @@ def __init__( self.sim_object = sim_object self.backend = backend self.device = "cpu" - self.controlled_agents = self.get_controlled_agents_mask() self.goal_radius = goal_radius self.num_worlds = num_worlds self.render_config = render_config + self.figsize = (10, 10) self.env_config = env_config + self.initialize_static_scenario_data(controlled_agent_mask) + + def initialize_static_scenario_data(self, controlled_agent_mask): + """ + Initialize key information for visualization based on the + current batch of scenarios. + """ + self.response_type = ResponseType.from_tensor( + tensor=self.sim_object.response_type_tensor(), + backend=self.backend, + device=self.device, + ) + self.global_roadgraph = GlobalRoadGraphPoints.from_tensor( + roadgraph_tensor=self.sim_object.map_observation_tensor(), + backend=self.backend, + device=self.device, + ) + self.controlled_agent_mask = controlled_agent_mask.to(self.device) - def get_controlled_agents_mask(self): - """Get the control mask.""" - return ( - (self.sim_object.controlled_state_tensor().to_torch() == 1) - .squeeze(axis=2) - .to(self.device) + self.log_trajectory = LogTrajectory.from_tensor( + self.sim_object.expert_trajectory_tensor(), + self.num_worlds, + self.controlled_agent_mask.shape[1], + backend=self.backend, ) + # Cache pre-rendered road graphs for all environments + # self.cached_roadgraphs = [] + # for env_idx in range(self.controlled_agent_mask.shape[0]): + # fig, ax = plt.subplots(figsize=self.figsize) + # self._plot_roadgraph( + # road_graph=self.global_roadgraph, + # env_idx=env_idx, + # ax=ax, + # line_width_scale=1.0, + # marker_size_scale=1.0, 
+ # ) + # self.cached_roadgraphs.append(fig) + # plt.close(fig) + def plot_simulator_state( self, env_indices: List[int], time_steps: Optional[List[int]] = None, center_agent_indices: Optional[List[int]] = None, - figsize: Tuple[int, int] = (15, 15), zoom_radius: int = 100, - return_single_figure: bool = False, plot_log_replay_trajectory: bool = False, ): """ @@ -88,73 +117,53 @@ def plot_simulator_state( env_indices ) # Default to None for all - # Extract data for all environments - global_roadgraph = GlobalRoadGraphPoints.from_tensor( - roadgraph_tensor=self.sim_object.map_observation_tensor(), - backend=self.backend, - device=self.device, - ) + # Changes at every time step global_agent_states = GlobalEgoState.from_tensor( self.sim_object.absolute_self_observation_tensor(), backend=self.backend, device=self.device, ) - response_type = ResponseType.from_tensor( - tensor=self.sim_object.response_type_tensor(), - backend=self.backend, - device=self.device, - ) - - agent_infos = self.sim_object.info_tensor().to_torch().to(self.device) - if plot_log_replay_trajectory: - log_trajectory = LogTrajectory.from_tensor( - self.sim_object.expert_trajectory_tensor(), - self.num_worlds, - self.controlled_agents.shape[1], - backend=self.backend, - ) - - figs = [] # Store all figures if returning multiple - - if return_single_figure: - # Calculate rows and columns for square layout - num_envs = len(env_indices) - num_rows = math.ceil(math.sqrt(num_envs)) - num_cols = math.ceil(num_envs / num_rows) + agent_infos = ( + self.sim_object.info_tensor().to_torch().clone().to(self.device) + ) - total_figsize = (figsize[0] * num_cols, figsize[1] * num_rows) - fig, axes = plt.subplots( - nrows=num_rows, - ncols=num_cols, - figsize=total_figsize, - squeeze=False, - ) - axes = axes.flatten() - else: - axes = [None] * len(env_indices) + figs = [] # Calculate scale factors based on figure size - max_fig_size = max(figsize) - marker_scale = max_fig_size / 15 # Adjust this factor as needed - 
line_width_scale = max_fig_size / 15 # Adjust this factor as needed + marker_scale = max(self.figsize) / 15 + line_width_scale = max(self.figsize) / 15 # Iterate over each environment index for idx, (env_idx, time_step, center_agent_idx) in enumerate( zip(env_indices, time_steps, center_agent_indices) ): - if return_single_figure: - ax = axes[idx] - ax.clear() # Clear any previous plots - ax.set_aspect("equal", adjustable="box") - else: - fig, ax = plt.subplots(figsize=figsize) - ax.set_aspect("equal", adjustable="box") - ax.clear() - figs.append(fig) + + # Initialize figure and axes from cached road graph + fig, ax = plt.subplots(figsize=self.figsize) + ax.clear() # Clear any existing content + ax.set_aspect("equal", adjustable="box") + figs.append(fig) # Add the new figure + plt.close(fig) # Close the figure to prevent carryover + + # Render the pre-cached road graph for the current environment + # cached_roadgraph_array = utils.bg_img_from_fig(self.cached_roadgraphs[env_idx]) + # ax.imshow( + # cached_roadgraph_array, + # origin="upper", + # extent=(-100, 100, -100, 100), # Stretch to full plot + # zorder=0, # Draw as background + # ) + + # Explicitly set the axis limits to match your coordinates + # cached_ax.set_xlim(-100, 100) + # cached_ax.set_ylim(-100, 100) + + # Remove axes + # cached_ax.axis('off') # Get control mask and omit out-of-bound agents (dead agents) - controlled = self.controlled_agents[env_idx, :] + controlled = self.controlled_agent_mask[env_idx, :] controlled_live = controlled & ( torch.abs(global_agent_states.pos_x[env_idx, :]) < 1_000 ) @@ -167,7 +176,7 @@ def plot_simulator_state( # Draw the road graph self._plot_roadgraph( - road_graph=global_roadgraph, + road_graph=self.global_roadgraph, env_idx=env_idx, ax=ax, line_width_scale=line_width_scale, @@ -179,7 +188,7 @@ def plot_simulator_state( ax=ax, control_mask=controlled_live, env_idx=env_idx, - log_trajectory=log_trajectory, + log_trajectory=self.log_trajectory, 
line_width_scale=line_width_scale, ) @@ -191,7 +200,7 @@ def plot_simulator_state( is_ok_mask=is_ok, is_offroad_mask=is_offroad, is_collided_mask=is_collided, - response_type=response_type, + response_type=self.response_type, alpha=1.0, line_width_scale=line_width_scale, marker_size_scale=marker_scale, @@ -232,16 +241,7 @@ def plot_simulator_state( ax.set_xlim(center_x - zoom_radius, center_x + zoom_radius) ax.set_ylim(center_y - zoom_radius, center_y + zoom_radius) - ax.set_xticks([]) - ax.set_yticks([]) - - if return_single_figure: - for ax in axes[len(env_indices) :]: - ax.axis("off") # Hide unused subplots - plt.tight_layout() - return fig - else: - return figs + return figs def _plot_log_replay_trajectory( self, @@ -288,6 +288,9 @@ def _plot_roadgraph( road_point_type == int(gpudrive.EntityType.RoadEdge) or road_point_type == int(gpudrive.EntityType.RoadLine) or road_point_type == int(gpudrive.EntityType.RoadLane) + or road_point_type == int(gpudrive.EntityType.SpeedBump) + or road_point_type == int(gpudrive.EntityType.StopSign) + or road_point_type == int(gpudrive.EntityType.CrossWalk) ): # Get coordinates and metadata x_coords = road_graph.x[env_idx, road_mask].tolist() @@ -295,35 +298,84 @@ def _plot_roadgraph( segment_lengths = road_graph.segment_length[ env_idx, road_mask ].tolist() + segment_widths = road_graph.segment_width[ + env_idx, road_mask + ].tolist() segment_orientations = road_graph.orientation[ env_idx, road_mask ].tolist() - # Compute and draw road edges using start and end points - for x, y, length, orientation in zip( - x_coords, - y_coords, - segment_lengths, - segment_orientations, + if ( + road_point_type == int(gpudrive.EntityType.RoadEdge) + or road_point_type == int(gpudrive.EntityType.RoadLine) + or road_point_type == int(gpudrive.EntityType.RoadLane) ): - start, end = self._get_endpoints( - x, y, length, orientation + # Compute and draw road edges using start and end points + for x, y, length, orientation in zip( + x_coords, + 
y_coords, + segment_lengths, + segment_orientations, + ): + start, end = self._get_endpoints( + x, y, length, orientation + ) + + # Plot the road edge as a line + if road_point_type == int( + gpudrive.EntityType.RoadEdge + ): + line_width = 1.1 * line_width_scale + else: + line_width = 0.75 * line_width_scale + + ax.plot( + [start[0], end[0]], + [start[1], end[1]], + color=ROAD_GRAPH_COLORS[road_point_type], + linewidth=line_width, + ) + + elif road_point_type == int(gpudrive.EntityType.SpeedBump): + utils.plot_speed_bumps( + x_coords, + y_coords, + segment_lengths, + segment_widths, + segment_orientations, + ax, ) - if road_point_type == int( - gpudrive.EntityType.RoadEdge + elif road_point_type == int(gpudrive.EntityType.StopSign): + for x, y in zip(x_coords, y_coords): + point = np.array([x, y]) + utils.plot_stop_sign( + point=point, + ax=ax, + radius=1.5, + facecolor="xkcd:red", + edgecolor="none", + linewidth=3.0, + alpha=0.8, + ) + elif road_point_type == int(gpudrive.EntityType.CrossWalk): + for x, y, length, width, orientation in zip( + x_coords, + y_coords, + segment_lengths, + segment_widths, + segment_orientations, ): - line_width = 1.1 * line_width_scale - - else: - line_width = 0.75 * line_width_scale - - ax.plot( - [start[0], end[0]], - [start[1], end[1]], - color=ROAD_GRAPH_COLORS[road_point_type], - linewidth=line_width, - ) + points = self._get_corners_polygon( + x, y, length, width, orientation + ) + utils.plot_crosswalk( + points=points, + ax=ax, + facecolor="none", + edgecolor="xkcd:bluish grey", + alpha=0.4, + ) else: # Dots for other road point types @@ -498,7 +550,7 @@ def _plot_filtered_agent_bounding_boxes( # Plot human_replay agents (those that are static or expert-controlled) log_replay = ( response_type.static[env_idx, :] | response_type.moving[env_idx, :] - ) & ~self.controlled_agents[env_idx, :] + ) & ~self.controlled_agent_mask[env_idx, :] pos_x = agent_states.pos_x[env_idx, log_replay] pos_y = agent_states.pos_y[env_idx, log_replay] 
@@ -510,8 +562,12 @@ def _plot_filtered_agent_bounding_boxes(
         valid_mask = (
             (torch.abs(pos_x) < OUT_OF_BOUNDS)
             & (torch.abs(pos_y) < OUT_OF_BOUNDS)
-            & (vehicle_length < 15)
-            & (vehicle_width < 10)
+            & (
+                (vehicle_length > 0.5)
+                & (vehicle_length < 15)
+                & (vehicle_width > 0.5)
+                & (vehicle_width < 15)
+            )
         )
 
         # Filter valid static agent attributes
@@ -575,7 +631,6 @@ def plot_agent_observation(
         fig, ax = plt.subplots(figsize=figsize)
         ax.clear()  # Clear any previous plots
         ax.set_aspect("equal", adjustable="box")
-        ax.set_title(f"Observation agent: {agent_idx}", y=1.05)
 
         # Plot roadgraph if provided
         if observation_roadgraph is not None:
@@ -584,14 +639,76 @@ def plot_agent_observation(
                     observation_roadgraph.type[env_idx, agent_idx, :]
                     == road_type
                 )
+
+                # Extract relevant roadgraph data for plotting
+                x_points = observation_roadgraph.x[env_idx, agent_idx, mask]
+                y_points = observation_roadgraph.y[env_idx, agent_idx, mask]
+                orientations = observation_roadgraph.orientation[
+                    env_idx, agent_idx, mask
+                ]
+                segment_lengths = observation_roadgraph.segment_length[
+                    env_idx, agent_idx, mask
+                ]
+                widths = observation_roadgraph.segment_width[
+                    env_idx, agent_idx, mask
+                ]
+
+                # Scatter plot for the points
                 ax.scatter(
-                    observation_roadgraph.x[env_idx, agent_idx, mask],
-                    observation_roadgraph.y[env_idx, agent_idx, mask],
+                    x_points,
+                    y_points,
                     c=[ROAD_GRAPH_COLORS[road_type]],
-                    s=7,
+                    s=8,
                     label=type_name,
                 )
 
+                # Plot lines for road edges
+                for x, y, orientation, segment_length, width in zip(
+                    x_points, y_points, orientations, segment_lengths, widths
+                ):
+                    dx = segment_length * 0.5 * np.cos(orientation)
+                    dy = segment_length * 0.5 * np.sin(orientation)
+
+                    # Calculate line endpoints for the road edge
+                    x_start = x - dx
+                    y_start = y - dy
+                    x_end = x + dx
+                    y_end = y + dy
+
+                    # Add width as a perpendicular offset
+                    width_dx = width * 0.5 * np.sin(orientation)
+                    width_dy = -width * 0.5 * np.cos(orientation)
+
+                    # Draw the road edge as a polygon (line with width)
+                    ax.plot(
+                        [x_start - width_dx, x_end - width_dx],
+                        [y_start - width_dy, y_end - width_dy],
+                        color=ROAD_GRAPH_COLORS[road_type],
+                        alpha=0.5,
+                        linewidth=1.0,
+                    )
+                    ax.plot(
+                        [x_start + width_dx, x_end + width_dx],
+                        [y_start + width_dy, y_end + width_dy],
+                        color=ROAD_GRAPH_COLORS[road_type],
+                        alpha=0.5,
+                        linewidth=1.0,
+                    )
+                    ax.plot(
+                        [x_start - width_dx, x_start + width_dx],
+                        [y_start - width_dy, y_start + width_dy],
+                        color=ROAD_GRAPH_COLORS[road_type],
+                        alpha=0.5,
+                        linewidth=1.0,
+                    )
+                    ax.plot(
+                        [x_end - width_dx, x_end + width_dx],
+                        [y_end - width_dy, y_end + width_dy],
+                        color=ROAD_GRAPH_COLORS[road_type],
+                        alpha=0.5,
+                        linewidth=1.0,
+                    )
+
         # Plot partner agents if provided
         if observation_partner is not None:
             partner_positions = torch.stack(
@@ -619,7 +736,7 @@ def plot_agent_observation(
                     env_idx, agent_idx, :, :
                 ].squeeze(),
                 color=REL_OBS_OBJ_COLORS["other_agents"],
-                alpha=0.9,
+                alpha=1.0,
             )
 
         if observation_ego is not None:
@@ -680,8 +797,8 @@ def plot_agent_observation(
         observation_radius = Circle(
             (0, 0),
             radius=self.env_config.obs_radius,
-            color="#d9d9d9",
-            linewidth=1.5,
+            color="#000000",
+            linewidth=0.8,
             fill=False,
             linestyle="-",
         )
@@ -693,4 +810,4 @@ def plot_agent_observation(
         ax.set_xticks([])
         ax.set_yticks([])
 
-        return fig, ax
+        return fig
diff --git a/pygpudrive/visualize/utils.py b/pygpudrive/visualize/utils.py
index 5624268f..3d816fa9 100644
--- a/pygpudrive/visualize/utils.py
+++ b/pygpudrive/visualize/utils.py
@@ -10,10 +10,26 @@
 import torch
 import matplotlib
 from typing import Tuple, Optional, List, Dict, Any, Union
+from matplotlib.patches import Circle, Polygon, RegularPolygon
 
 from pygpudrive.visualize.color import ROAD_GRAPH_COLORS, ROAD_GRAPH_TYPE_NAMES
 
 
+def bg_img_from_fig(fig: matplotlib.figure.Figure) -> np.ndarray:
+    """Returns a [H, W, 3] uint8 np image from fig.canvas.tostring_rgb()."""
+    fig.subplots_adjust(
+        left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=0.0, hspace=0.0
+    )
+    fig.canvas.draw()
+
+    # Extract image data
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    img = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+
+    plt.close(fig)  # Close the figure
+    return img
+
+
 def img_from_fig(fig: matplotlib.figure.Figure) -> np.ndarray:
     """Returns a [H, W, 3] uint8 np image from fig.canvas.tostring_rgb()."""
     # Display xticks and yticks and title.
@@ -213,6 +229,7 @@ def plot_bounding_box(
                 color=color,
                 alpha=alpha,
                 linestyle="-",
+                linewidth=2,
                 label=label if i == 0 else None,
             )
     else:
@@ -254,4 +271,168 @@ def plot_bounding_box(
             alpha=alpha,
             linestyle="-",
             label=label,
+            linewidth=2,
+        )
+
+
+def get_corners_polygon(x, y, length, width, orientation):
+    """Calculate the four corners of a speed bump (can be any) polygon."""
+    # Compute the direction vectors based on orientation
+    c = np.cos(orientation)
+    s = np.sin(orientation)
+    u = np.array((c, s))  # Unit vector along the orientation
+    ut = np.array((-s, c))  # Unit vector perpendicular to the orientation
+
+    # Center point of the speed bump
+    pt = np.array([x, y])
+
+    # corners
+    tl = pt + (length / 2) * u - (width / 2) * ut
+    tr = pt + (length / 2) * u + (width / 2) * ut
+    br = pt - (length / 2) * u + (width / 2) * ut
+    bl = pt - (length / 2) * u - (width / 2) * ut
+
+    return [tl.tolist(), tr.tolist(), br.tolist(), bl.tolist()]
+
+
+def get_stripe_polygon(
+    x: float,
+    y: float,
+    length: float,
+    width: float,
+    orientation: float,
+    index: int,
+    num_stripes: int,
+) -> np.ndarray:
+
+    """Calculate the corners of a stripe within the speed bump polygon."""
+
+    # Compute the direction vectors
+    c = np.cos(orientation)
+    s = np.sin(orientation)
+    u = np.array([c, s])  # Unit vector along the orientation (lengthwise)
+    ut = np.array([-s, c])  # Perpendicular unit vector (widthwise)
+
+    # Each stripe spans length / num_stripes along the length axis
+    stripe_width = length / num_stripes
+    half_length = length / 2
+    half_width = width / 2
+
+    # Offset for the current stripe
+    offset_start = -half_length + index * stripe_width
+    offset_end = offset_start + stripe_width
+
+    # Center of the speed bump
+    center = np.array([x, y])
+
+    # Calculate stripe corners
+    stripe_corners = [
+        center + u * offset_start + ut * half_width,  # Top-left
+        center + u * offset_start - ut * half_width,  # Bottom-left
+        center + u * offset_end - ut * half_width,  # Bottom-right
+        center + u * offset_end + ut * half_width,  # Top-right
+    ]
+
+    return np.array(stripe_corners)
+
+
+def plot_speed_bumps(
+    x_coords: Union[float, np.ndarray],
+    y_coords: Union[float, np.ndarray],
+    segment_lengths: Union[float, torch.Tensor],
+    segment_widths: Union[float, torch.Tensor],
+    segment_orientations: Union[float, torch.Tensor],
+    ax: matplotlib.axes.Axes,
+    facecolor: str = None,
+    edgecolor: str = None,
+    alpha: float = None,
+) -> None:
+    """Draw one hatched polygon per speed bump; None styles use defaults."""
+    facecolor = "xkcd:goldenrod" if facecolor is None else facecolor
+    edgecolor = "xkcd:black" if edgecolor is None else edgecolor
+    alpha = 0.5 if alpha is None else alpha
+    for x, y, length, width, orientation in zip(
+        x_coords,
+        y_coords,
+        segment_lengths,
+        segment_widths,
+        segment_orientations,
+    ):
+        # method1: from waymax using hatch as diagonals
+        points = get_corners_polygon(x, y, length, width, orientation)
+
+        p = Polygon(
+            points,
+            facecolor=facecolor,
+            edgecolor=edgecolor,
+            linewidth=0,
+            alpha=alpha,
+            hatch=r"//",
+            zorder=2,
         )
+
+        ax.add_patch(p)
+
+
+def plot_stop_sign(
+    point: np.ndarray,
+    ax: matplotlib.axes.Axes,
+    radius: float = None,
+    facecolor: str = None,
+    edgecolor: str = None,
+    linewidth: float = None,
+    alpha: float = None,
+) -> None:
+    """Add a 6-sided RegularPolygon marker for a stop sign at point."""
+    # Default configurations for the stop sign
+    facecolor = "red" if facecolor is None else facecolor
+    edgecolor = "white" if edgecolor is None else edgecolor
+    linewidth = 1.5 if linewidth is None else linewidth
+    radius = 1.0 if radius is None else radius
+    alpha = 1.0 if alpha is None else alpha
+
+    point = np.array(point).reshape(-1)
+
+    p = RegularPolygon(
+        point,
+        numVertices=6,  # For hexagonal stop sign
+        radius=radius,
+        facecolor=facecolor,
+        edgecolor=edgecolor,
+        linewidth=linewidth,
+        alpha=alpha,
+        zorder=2,
+    )
+    ax.add_patch(p)
+
+
+def plot_crosswalk(
+    points,
+    ax: plt.Axes = None,
+    facecolor: str = None,
+    edgecolor: str = None,
+    alpha: float = None,
+):
+    """Add a hatched polygon for a crosswalk to ax (current axes if None)."""
+    if ax is None:
+        ax = plt.gca()
+    # NOTE(review): crosswalk_config is not defined in this diff; verify.
+    facecolor = (
+        crosswalk_config["facecolor"] if facecolor is None else facecolor
+    )
+    edgecolor = (
+        crosswalk_config["edgecolor"] if edgecolor is None else edgecolor
+    )
+    alpha = crosswalk_config["alpha"] if alpha is None else alpha
+
+    p = Polygon(
+        points,
+        facecolor=facecolor,
+        edgecolor=edgecolor,
+        linewidth=2,
+        alpha=alpha,
+        hatch=r"//",
+        zorder=2,
+    )
+
+    ax.add_patch(p)