Skip to content

Commit

Permalink
First version projecting trajectories on BEV map
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinSchmid7 committed Aug 18, 2023
1 parent 65dc87d commit 2afec5e
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 14 deletions.
46 changes: 46 additions & 0 deletions wild_visual_navigation/image_projector/image_projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,52 @@ def project_and_render(

return self.masks, image_overlay, projected_points, valid_points

def project_and_render_on_map(
self, pose_base_in_world: torch.tensor, points: torch.tensor, colors: torch.tensor, image: torch.tensor = None
):
"""Projects the points and returns an image with the projection
Args:
points: (torch.Tensor, dtype=torch.float32, shape=(B, N, 3)): B batches, of N input points in 3D space
colors: (torch.Tensor, rtype=torch.float32, shape=(B, 3))
Returns:
out_img (torch.tensor, dtype=torch.int64): Image with projected points
"""

# self.masks = self.masks * 0.0
B = self.camera.batch_size
C = 3 # RGB channel output
H = self.camera.height.item()
W = self.camera.width.item()
self.masks = torch.zeros((B, C, H, W), dtype=torch.float32, device=self.camera.camera_matrix.device)
image_overlay = image

T_BW = pose_base_in_world.inverse()
# Convert from fixed to base frame
points_B = transform_points(T_BW, points)

# Remove z dimension
# TODO: project footprint on gravity aligned plane
flat_points = points_B[:, :, :-1]

# Shift to grid map coordinates
flat_points = flat_points / 0.1 + 128

# Fill the mask
self.masks = draw_convex_polygon(self.masks, flat_points, colors)

# Draw on image (if applies)
if image is not None:
if len(image.shape) != 4:
image = image[None]
image_overlay = draw_convex_polygon(image, flat_points, colors)

# Return torch masks
self.masks[self.masks == 0.0] = torch.nan

return self.masks, image_overlay

def resize_image(self, image: torch.tensor):
return self.image_crop(image)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def add_mission_node(self, node: MissionNode, verbose: bool = False, update_feat

@accumulate_time
@torch.no_grad()
def add_proprio_node(self, pnode: ProprioceptionNode):
def add_proprio_node(self, pnode: ProprioceptionNode, projection_mode: str = "image"):
"""Adds a node to the proprioceptive graph to store proprioception
Args:
Expand Down Expand Up @@ -473,38 +473,45 @@ def add_proprio_node(self, pnode: ProprioceptionNode):
color = torch.ones((3,), device=self._device)

# New implementation
B = len(mission_nodes)
B = len(mission_nodes) # Number of mission nodes to project

# Prepare batches
K = torch.eye(4, device=self._device).repeat(B, 1, 1)
supervision_masks = torch.zeros(last_mission_node.supervision_mask.shape, device=self._device).repeat(
B, 1, 1, 1
)
pose_camera_in_world = torch.eye(4, device=self._device).repeat(B, 1, 1)
pose_base_in_world = torch.eye(4, device=self._device).repeat(B, 1, 1)

H = last_mission_node.image_projector.camera.height
W = last_mission_node.image_projector.camera.width
footprints = footprint.repeat(B, 1, 1)

for i, mnode in enumerate(mission_nodes):
K[i] = mnode.image_projector.camera.intrinsics
pose_camera_in_world[i] = mnode.pose_cam_in_world
pose_base_in_world[i] = mnode.pose_base_in_world

if (not hasattr(mnode, "supervision_mask")) or (mnode.supervision_mask is None):
continue
else:
supervision_masks[i] = mnode.supervision_mask
supervision_masks[i] = mnode.supervision_mask # Getting all the existing supervision masks

im = ImageProjector(K, H, W)
mask, _, _, _ = im.project_and_render(pose_camera_in_world, footprints, color)

if projection_mode == "image":
mask, _, _, _ = im.project_and_render(pose_camera_in_world, footprints, color) # Generating the new supervisiom mask to add
elif projection_mode == "map":
mask, _ = im.project_and_render_on_map(pose_base_in_world, footprints, color)

# Update traversability
mask = mask * pnode.traversability
supervision_masks = torch.fmin(supervision_masks, mask)
# mask = mask * pnode.traversability
supervision_masks = torch.fmin(supervision_masks, mask) # Adding the new mask to the supervision mask, using element-wise non-nan values

# Update supervision mask per node
for i, mnode in enumerate(mission_nodes):
mnode.supervision_mask = supervision_masks[i]
mnode.update_supervision_signal()
# mnode.update_supervision_signal() # Accumulate supervision signal, check if features are there

if self._mode == WVNMode.EXTRACT_LABELS:
p = os.path.join(
Expand All @@ -514,7 +521,6 @@ def add_proprio_node(self, pnode: ProprioceptionNode):
)
store = torch.nan_to_num(mnode.supervision_mask.nanmean(axis=0)) != 0
torch.save(store, p)

return True

def get_mission_nodes(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ robot_height: 0.3
robot_max_velocity: 0.8

# Traversability estimation params
traversability_radius: 5.0 # meters
traversability_radius: 10.0 # meters
image_graph_dist_thr: 0.2 # meters
proprio_graph_dist_thr: 0.1 # meters
network_input_image_height: 448 # 448
Expand Down Expand Up @@ -68,8 +68,8 @@ colormap: "RdYlBu"

print_image_callback_time: false
print_proprio_callback_time: false
log_time: true
log_memory: true
log_time: false
log_memory: false
log_confidence: true
verbose: false
debug_supervision_node_index_from_last: 10
Expand All @@ -78,3 +78,4 @@ use_debug_for_desired: true
extraction_store_folder: /home/rschmid/RosBags/output/6
exp: "nan"
use_binary_only: true
supervision_projection_mode: "map"
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
if torch.cuda.is_available():
torch.cuda.empty_cache()

torch.set_printoptions(edgeitems=200)


class WvnRosInterface:
def __init__(self):
Expand Down Expand Up @@ -277,6 +279,7 @@ def read_params(self):
# Select mode: # debug, online, extract_labels
self.use_debug_for_desired = rospy.get_param("~use_debug_for_desired") # Note: Unused parameter
self.use_binary_only = rospy.get_param("~use_binary_only") # Only extract binary labels, do not update traversability
self.supervision_projection_mode = rospy.get_param("~supervision_projection_mode")
self.mode = WVNMode.from_string(rospy.get_param("~mode", "debug"))
self.extraction_store_folder = rospy.get_param("~extraction_store_folder")

Expand Down Expand Up @@ -614,7 +617,7 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped):
)

# Add node to the graph
self.traversability_estimator.add_proprio_node(proprio_node)
self.traversability_estimator.add_proprio_node(proprio_node, projection_mode=self.supervision_projection_mode)

# if self.mode == WVNMode.DEBUG or self.mode == WVNMode.ONLINE:
# self.visualize_proprioception()
Expand Down Expand Up @@ -917,8 +920,6 @@ def visualize_proprioception(self):
print(f"number of points for footprint is {len(footprints_marker.points)}")
return

print("points", footprints_marker.points)

self.pub_graph_footprints.publish(footprints_marker)
self.pub_debug_proprio_graph.publish(proprio_graph_msg)

Expand Down

0 comments on commit 2afec5e

Please sign in to comment.