Skip to content

Commit

Permalink
fix(generic): support openpose
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitMY committed Mar 30, 2024
1 parent a376ff8 commit be400fa
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
16 changes: 14 additions & 2 deletions src/python/pose_format/bin/pose_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,25 @@
import argparse
import os

import numpy as np
from pose_format import Pose
from pose_format.pose_visualizer import PoseVisualizer
from pose_format.utils.generic import pose_normalization_info


def visualize_pose(pose_path: str, video_path: str):
def visualize_pose(pose_path: str, video_path: str, normalize=False):
with open(pose_path, "rb") as f:
pose = Pose.read(f.read())

if normalize:
pose = pose.normalize(pose_normalization_info(pose.header))

new_width = 500
shift = 1.25
shift_vec = np.full(shape=(pose.body.data.shape[-1]), fill_value=shift, dtype=np.float32)
pose.body.data = (pose.body.data + shift_vec) * new_width
pose.header.dimensions.height = pose.header.dimensions.width = int(new_width * shift * 2)

v = PoseVisualizer(pose)

v.save_video(video_path, v.draw())
Expand All @@ -20,10 +31,11 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', required=True, type=str, help='path to input pose file')
parser.add_argument('-o', required=True, type=str, help='path to output video file')
parser.add_argument('--normalize', action='store_true', help='Normalize pose before visualization')

args = parser.parse_args()

if not os.path.exists(args.i):
raise FileNotFoundError(f"Pose file {args.i} not found")

visualize_pose(args.i, args.o)
visualize_pose(args.i, args.o, args.normalize)
22 changes: 20 additions & 2 deletions src/python/pose_format/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,30 @@ def fake_pose(num_frames: int, fps=25, dims=2, components=OpenPose_Components):
return Pose(header, body)


def get_hand_wrist_index(pose: Pose, hand: str):
if pose.header.components[0].name == "POSE_LANDMARKS":
return pose.header._get_point_index(f'{hand.upper()}_HAND_LANDMARKS', 'WRIST')
elif pose.header.components[0].name == "pose_keypoints_2d":
return pose.header._get_point_index(f'hand_{hand.lower()}_keypoints_2d', 'BASE')
else:
raise ValueError("Unknown pose header schema for get_hand_wrist_index")


def get_body_hand_wrist_index(pose: Pose, hand: str):
if pose.header.components[0].name == "POSE_LANDMARKS":
return pose.header._get_point_index('POSE_LANDMARKS', f'{hand.upper()}_WRIST')
elif pose.header.components[0].name == "pose_keypoints_2d":
return pose.header._get_point_index('pose_keypoints_2d', f'{hand.upper()[0]}Wrist')
else:
raise ValueError("Unknown pose header schema for get_hand_wrist_index")


def correct_wrist(pose: Pose, hand: str) -> Pose:
wrist_index = pose.header._get_point_index(f'{hand}_HAND_LANDMARKS', 'WRIST')
wrist_index = get_hand_wrist_index(pose, hand)
wrist = pose.body.data[:, :, wrist_index]
wrist_conf = pose.body.confidence[:, :, wrist_index]

body_wrist_index = pose.header._get_point_index('POSE_LANDMARKS', f'{hand}_WRIST')
body_wrist_index = get_body_hand_wrist_index(pose, hand)
body_wrist = pose.body.data[:, :, body_wrist_index]
body_wrist_conf = pose.body.confidence[:, :, body_wrist_index]

Expand Down

0 comments on commit be400fa

Please sign in to comment.