From be400fad592e84852eec74f1ab69bd627a6c9cdf Mon Sep 17 00:00:00 2001 From: Amit Moryossef Date: Sat, 30 Mar 2024 15:38:33 +0100 Subject: [PATCH] fix(generic): support openpose --- src/python/pose_format/bin/pose_visualizer.py | 16 ++++++++++++-- src/python/pose_format/utils/generic.py | 22 +++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/python/pose_format/bin/pose_visualizer.py b/src/python/pose_format/bin/pose_visualizer.py index 7faae41..c108a02 100644 --- a/src/python/pose_format/bin/pose_visualizer.py +++ b/src/python/pose_format/bin/pose_visualizer.py @@ -3,14 +3,25 @@ import argparse import os +import numpy as np from pose_format import Pose from pose_format.pose_visualizer import PoseVisualizer +from pose_format.utils.generic import pose_normalization_info -def visualize_pose(pose_path: str, video_path: str): +def visualize_pose(pose_path: str, video_path: str, normalize=False): with open(pose_path, "rb") as f: pose = Pose.read(f.read()) + if normalize: + pose = pose.normalize(pose_normalization_info(pose.header)) + + new_width = 500 + shift = 1.25 + shift_vec = np.full(shape=(pose.body.data.shape[-1]), fill_value=shift, dtype=np.float32) + pose.body.data = (pose.body.data + shift_vec) * new_width + pose.header.dimensions.height = pose.header.dimensions.width = int(new_width * shift * 2) + v = PoseVisualizer(pose) v.save_video(video_path, v.draw()) @@ -20,10 +31,11 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument('-i', required=True, type=str, help='path to input pose file') parser.add_argument('-o', required=True, type=str, help='path to output video file') + parser.add_argument('--normalize', action='store_true', help='Normalize pose before visualization') args = parser.parse_args() if not os.path.exists(args.i): raise FileNotFoundError(f"Pose file {args.i} not found") - visualize_pose(args.i, args.o) + visualize_pose(args.i, args.o, args.normalize) diff --git a/src/python/pose_format/utils/generic.py b/src/python/pose_format/utils/generic.py index edce8e4..e30c936 100644 --- a/src/python/pose_format/utils/generic.py +++ b/src/python/pose_format/utils/generic.py @@ -113,12 +113,30 @@ def fake_pose(num_frames: int, fps=25, dims=2, components=OpenPose_Components): return Pose(header, body) +def get_hand_wrist_index(pose: Pose, hand: str): + if pose.header.components[0].name == "POSE_LANDMARKS": + return pose.header._get_point_index(f'{hand.upper()}_HAND_LANDMARKS', 'WRIST') + elif pose.header.components[0].name == "pose_keypoints_2d": + return pose.header._get_point_index(f'hand_{hand.lower()}_keypoints_2d', 'BASE') + else: + raise ValueError("Unknown pose header schema for get_hand_wrist_index") + + +def get_body_hand_wrist_index(pose: Pose, hand: str): + if pose.header.components[0].name == "POSE_LANDMARKS": + return pose.header._get_point_index('POSE_LANDMARKS', f'{hand.upper()}_WRIST') + elif pose.header.components[0].name == "pose_keypoints_2d": + return pose.header._get_point_index('pose_keypoints_2d', f'{hand.upper()[0]}Wrist') + else: + raise ValueError("Unknown pose header schema for get_hand_wrist_index") + + def correct_wrist(pose: Pose, hand: str) -> Pose: - wrist_index = pose.header._get_point_index(f'{hand}_HAND_LANDMARKS', 'WRIST') + wrist_index = get_hand_wrist_index(pose, hand) wrist = pose.body.data[:, :, wrist_index] wrist_conf = pose.body.confidence[:, :, wrist_index] - body_wrist_index = pose.header._get_point_index('POSE_LANDMARKS', f'{hand}_WRIST') + body_wrist_index = get_body_hand_wrist_index(pose, hand) body_wrist = pose.body.data[:, :, body_wrist_index] body_wrist_conf = pose.body.confidence[:, :, body_wrist_index]