Skip to content

Commit

Permalink
common logger using mcap
Browse files Browse the repository at this point in the history
  • Loading branch information
pierfabre committed Jan 24, 2025
1 parent 458b426 commit 4e2ecc2
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 127 deletions.
81 changes: 55 additions & 26 deletions furuta/logger.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,64 @@
import typing as tp
from pathlib import Path
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from mcap_protobuf.reader import read_protobuf_messages
from mcap_protobuf.writer import Writer

STATE = ["phi", "theta", "phi_dot", "theta_dot"]
from furuta.logging.protobuf.pendulum_state_pb2 import PendulumState
from furuta.utils import STATE, State


class SimpleLogger:
def __init__(self):
self.times: tp.List[float] = []
self.states: tp.List[np.ndarray] = []

def update(self, time: float, state: np.ndarray):
self.times.append(time)
self.states.append(state)

def save(self, directory: Path):
np.save(directory / "times.npy", self.times)
np.save(directory / "states.npy", np.array(self.states))

def load(self, directory: Path):
self.states = np.load(directory / "states.npy")
self.times = np.load(directory / "times.npy")

def plot(self):
plt.figure(1)
for i in range(len(STATE)):
plt.subplot(2, 2, i + 1)
plt.plot(self.times, np.array(self.states)[:, i])
plt.title(STATE[i])

def show(self):
def __init__(self, log_path: (str | Path)):
self.log_path = log_path

def start(self):
self.output_file = open(self.log_path, "wb")
self.mcap_writer = Writer(self.output_file)

def stop(self):
self.mcap_writer.finish()
self.output_file.close()

def update(self, time_ns: int, state: State):
self.mcap_writer.write_message(
topic="/pendulum_state",
message=PendulumState(
motor_angle=state.motor_angle,
pendulum_angle=state.pendulum_angle,
motor_angle_velocity=state.motor_angle_velocity,
pendulum_angle_velocity=state.pendulum_angle_velocity,
reward=state.reward,
action=state.action,
),
log_time=time_ns,
publish_time=time_ns,
)

def load(self) -> tuple[np.ndarray, np.ndarray]:
times: List[float] = list()
states: List[np.ndarray] = list()
for msg in read_protobuf_messages(self.log_path, log_time_order=True):
p = msg.proto_msg
state = np.array(
[
p.motor_angle,
p.pendulum_angle,
p.motor_angle_velocity,
p.pendulum_angle_velocity,
p.reward,
p.action,
]
)
times.append(float(msg.log_time_ns * 1e-9))
states.append(state)
return np.array(times), np.array(states)

def plot(self, times: List[float], states: List[np.ndarray]):
for title, idx in STATE.items():
plt.figure(idx + 1)
plt.plot(times, states[:, idx])
plt.title(title)
plt.show()
47 changes: 19 additions & 28 deletions furuta/rl/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
import gymnasium as gym
import hydra
import numpy as np
import wandb
from gymnasium.spaces import Box
from mcap_protobuf.writer import Writer

import wandb
from furuta.logging.protobuf.pendulum_state_pb2 import PendulumState
from furuta.utils import ALPHA, ALPHA_DOT, THETA, THETA_DOT
from furuta.logger import SimpleLogger
from furuta.utils import ALPHA, ALPHA_DOT, THETA, THETA_DOT, State


class GentlyTerminating(gym.Wrapper):
Expand Down Expand Up @@ -72,7 +71,7 @@ def __init__(
self.use_sim_time = use_sim_time

self.episodes = 0
self.mcap_writer = None
self.logger = None
self.episodic = episodic

def step(self, action):
Expand All @@ -85,19 +84,15 @@ def step(self, action):
else:
time_to_log = time.time_ns()

self.mcap_writer.write_message(
topic="/pendulum_state",
message=PendulumState(
motor_angle=self.unwrapped._state[THETA],
pendulum_angle=self.unwrapped._state[ALPHA],
motor_angle_velocity=self.unwrapped._state[THETA_DOT],
pendulum_angle_velocity=self.unwrapped._state[ALPHA_DOT],
reward=reward,
action=float(action[0]),
),
log_time=time_to_log,
publish_time=time_to_log,
state = State(
motor_angle=self.unwrapped._state[THETA],
pendulum_angle=self.unwrapped._state[ALPHA],
motor_angle_velocity=self.unwrapped._state[THETA_DOT],
pendulum_angle_velocity=self.unwrapped._state[ALPHA_DOT],
reward=reward,
action=float(action[0]),
)
self.logger.update(time_to_log, state)

return observation, reward, terminated, truncated, info

Expand All @@ -112,15 +107,16 @@ def reset(

if self.episodic:
# close previous log file
self.close_mcap_writer()
if self.logger is not None:
self.logger.stop()
fname = f"ep{self.episodes}_{time.strftime('%Y%m%d-%H%M%S')}.mcap"
else:
fname = f"{time.strftime('%Y%m%d-%H%M%S')}.mcap"

if self.mcap_writer is None or self.episodic:
# instantiate a new MCAP writer
self.output_file = open(self.log_dir / fname, "wb")
self.mcap_writer = Writer(self.output_file)
if self.logger is None or self.episodic:
# instantiate a new MCAP logger
self.logger = SimpleLogger(self.log_dir / fname)
self.logger.start()

# TODO add metadata?
# date, control frequency, wandb run id, sim parameters, robot parameters, etc.
Expand All @@ -133,14 +129,9 @@ def reset(
return self.env.reset(seed=seed, options=options)

def close(self):
self.close_mcap_writer()
self.logger.stop()
return self.env.close()

def close_mcap_writer(self):
if self.mcap_writer is not None:
self.mcap_writer.finish()
self.output_file.close()


class ControlFrequency(gym.Wrapper):
"""Enforce a sleeping time (dt) between each step."""
Expand Down
22 changes: 21 additions & 1 deletion furuta/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
# https://git.ias.informatik.tu-darmstadt.de/quanser/clients/-/blob/v0.1.1/quanser_robots/common.py
from dataclasses import dataclass

import numpy as np
from gymnasium import spaces
from scipy import signal

THETA = 0
ALPHA = 1
THETA_DOT = 2
ALPHA_DOT = 3

STATE = {
"motor_angle": 0,
"pendulum_angle": 1,
"motor_angle_velocity": 2,
"pendulum_angle_velocity": 3,
"reward": 4,
"action": 5,
}


@dataclass
class State:
motor_angle: float = 0.0
pendulum_angle: float = 0.0
motor_angle_velocity: float = 0.0
pendulum_angle_velocity: float = 0.0
reward: float = 0.0
action: float = 0.0


class VelocityFilter:
"""Discrete velocity filter derived from a continuous one."""
Expand Down
31 changes: 13 additions & 18 deletions furuta/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from panda3d_viewer import Viewer
from pinocchio.visualize.panda3d_visualizer import Panda3dVisualizer

from furuta.logger import SimpleLogger
from furuta.utils import ALPHA, THETA


Expand All @@ -26,6 +25,18 @@ def display(cls, state: np.ndarray) -> (np.ndarray | None):
def close(cls):
pass

def animate(cls, times: np.ndarray, states: np.ndarray):
# Initial state
q = states[:, :2]
cls.display(q[0])
time.sleep(1.0)
tic = time.time()
for i in range(1, len(times)):
toc = time.time()
time.sleep(max(0, times[i] - times[i - 1] - (toc - tic)))
tic = time.time()
cls.display(q[i])


class Viewer3D(AbstractViewer):
def __init__(cls, robot: pin.RobotWrapper = None):
Expand All @@ -43,25 +54,9 @@ def display(cls, state: np.ndarray) -> np.ndarray:
def close(cls):
cls.viewer.close()

def animate(cls, times: np.ndarray, states: np.ndarray):
# Initial state
q = np.array(states)[:, :2]
cls.display(q[0])
time.sleep(1.0)
tic = time.time()
for i in range(1, len(times)):
toc = time.time()
time.sleep(max(0, times[i] - times[i - 1] - (toc - tic)))
tic = time.time()
if i % 10 == 0:
cls.display(q[i])

def animate_log(cls, log: SimpleLogger):
cls.animate(log.times, log.states)


class Viewer2D(AbstractViewer):
def __init__(cls, render_fps: int = 50, render_mode: str = None):
def __init__(cls, render_fps: int = 30, render_mode: str = "human"):
cls.render_fps = render_fps
cls.render_mode = render_mode

Expand Down
27 changes: 27 additions & 0 deletions scripts/replay_log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import argparse

from furuta.logger import SimpleLogger

# from furuta.robot import RobotModel
from furuta.viewer import Viewer2D # , Viewer3D

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-t", "--type", required=False, default="2D", choices=("2D", "3D"))
parser.add_argument("-f", "--file_path", required=True)
args = parser.parse_args()

logger = SimpleLogger(args.file_path)
times, states = logger.load()

if args.type == "2D":
viewer = Viewer2D()
else:
print("3D viewer is not supported yet")
assert False
# viewer = Viewer3D(RobotModel.robot)

viewer.animate(times, states)
viewer.close()

logger.plot(times, states)
31 changes: 0 additions & 31 deletions scripts/replay_mcap.py

This file was deleted.

Loading

0 comments on commit 4e2ecc2

Please sign in to comment.