Skip to content

Commit

Permalink
Preprocessing for poses, and some type annotations, and a bit of refa…
Browse files Browse the repository at this point in the history
…ctoring/renaming
  • Loading branch information
cleong110 committed Jan 7, 2025
1 parent 1f5767d commit 07225cb
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 27 deletions.
4 changes: 2 additions & 2 deletions pose_evaluation/metrics/distance_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from pose_format import Pose

from pose_evaluation.metrics.base_pose_metric import PoseMetric

ValidDistanceKinds = Literal["euclidean", "manhattan"]

class DistanceMetric(PoseMetric):
def __init__(self, kind: Literal["l1", "l2"] = "l2"):
def __init__(self, kind: ValidDistanceKinds = "euclidean"):
super().__init__(f"DistanceMetric {kind}", higher_is_better=False)
self.kind = kind

Expand Down
9 changes: 6 additions & 3 deletions pose_evaluation/metrics/ham2pose_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
# and then code was copied to this repo by @cleong110

import os
from pathlib import Path
import numpy as np
from scipy.spatial.distance import euclidean
from fastdtw import fastdtw

from pose_evaluation.utils.pose_utils import get_preprocessed_pose, pose_hide_low_conf
from pose_evaluation.utils.pose_utils import preprocess_pose, load_pose_file, pose_hide_low_conf

# rootdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# sys.path.insert(0, rootdir)
Expand Down Expand Up @@ -98,8 +99,10 @@ def APE(trajectory1, trajectory2):


def compare_pose_videos(pose1_id, pose2_id, keypoints_path, distance_function=fastdtw):
pose1 = get_preprocessed_pose(os.path.join(keypoints_path, pose1_id), pose1_id)
pose2 = get_preprocessed_pose(os.path.join(keypoints_path, pose2_id), pose2_id)
pose1 = load_pose_file(Path(keypoints_path/ pose1_id), pose1_id)
pose1 = preprocess_pose(os.path.join(keypoints_path, pose1_id), pose1_id)
pose2 = preprocess_pose(os.path.join(keypoints_path, pose2_id), pose2_id)
pose2 = load_pose_file(Path(keypoints_path/ pose2_id), pose2_id)
return compare_poses(pose1, pose2, distance_function=distance_function)


Expand Down
18 changes: 15 additions & 3 deletions pose_evaluation/metrics/ndtw_mje.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,25 @@

from pose_format import Pose

from pose_evaluation.metrics.distance_metric import DistanceMetric
from pose_evaluation.utils.pose_utils import pose_hide_low_conf, get_preprocessed_pose
from pose_evaluation.metrics.distance_metric import DistanceMetric, ValidDistanceKinds
from pose_evaluation.utils.pose_utils import pose_hide_low_conf, preprocess_pose

class DynamicTimeWarpingMeanJointError(DistanceMetric):
def __init__(self, kind: Literal["manhattan", "euclidean"] = "euclidean", normalize_missing:bool=True):
def __init__(self, kind: ValidDistanceKinds = "euclidean",
normalize_poses:bool=True,
reduce_poses:bool=False,
remove_legs:bool=True,
remove_world_landmarks:bool=False,
conf_threshold_to_drop_points:None|int=None,
):
super().__init__(kind)

self.normalize_poses = normalize_poses
self.reduce_reference_poses = reduce_poses
self.remove_legs = remove_legs
self.remove_world_landmarks = remove_world_landmarks
self.conf_threshold_to_drop_points = conf_threshold_to_drop_points

def score_all(self, hypotheses:List[Pose], references:List[Pose], progress_bar=True):
# TODO:
return super().score_all(hypotheses, references, progress_bar)
Expand Down
71 changes: 52 additions & 19 deletions pose_evaluation/utils/pose_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from pose_format.utils.generic import pose_normalization_info, pose_hide_legs


def remove_world_landmarks(pose: Pose):
remove_specified_landmarks(pose, "POSE_WORLD_LANDMARKS")
def pose_remove_world_landmarks(pose: Pose):
return remove_specified_landmarks(pose, "POSE_WORLD_LANDMARKS")


def remove_specified_landmarks(pose: Pose, landmark_names: List[str]):
Expand All @@ -18,6 +18,7 @@ def remove_specified_landmarks(pose: Pose, landmark_names: List[str]):
new_pose = pose.get_components(components_without_specified_names)
pose.body = new_pose.body
pose.header = new_pose.header
return new_pose


def get_chosen_components_from_pose(
Expand All @@ -36,33 +37,65 @@ def get_face_and_hands_from_pose(pose: Pose) -> Pose:
return pose.get_components(components_to_keep)


def get_preprocessed_pose(pose_path: Path | str) -> Pose:

def load_pose_file(pose_path: Path) -> Pose:
pose_path = Path(pose_path).resolve()

with pose_path.open("rb") as f:
pose = Pose.read(f.read())
return pose


def reduce_pose_components_to_intersection(poses: List[Pose]) -> List[Pose]:
component_names = [pose.header.components for pose in poses]
set_of_common_components = list(set.intersection(*component_names))
poses = [pose.get_components(set_of_common_components) for pose in poses]


def preprocess_poses(
poses: List[Pose],
normalize_poses: bool = True,
reduce_poses_to_common_components: bool = False,
remove_legs: bool = True,
remove_world_landmarks: bool = False,
conf_threshold_to_drop_points: None | float = None,
) -> List[Pose]:
# NOTE: this is a lot of arguments. Perhaps a list may be better?
if reduce_poses_to_common_components:
reduce_pose_components_to_intersection(poses)

poses = [
preprocess_pose(
pose,
normalize_poses=normalize_poses,
remove_legs=remove_legs,
remove_world_landmarks=remove_world_landmarks,
conf_threshold_to_drop_points=conf_threshold_to_drop_points,
)
for pose in poses
]
return poses


# normalize
# TODO: confirm this is correct. Previously used pose_format.utils.generic.pose_normalization_info(),
# but pose-format README advises using pose.normalize_distribution()
pose.normalize(pose_normalization_info(pose))
def preprocess_pose(
pose: Pose,
normalize_poses: bool = True,
remove_legs: bool = True,
remove_world_landmarks: bool = False,
conf_threshold_to_drop_points: None | int = None,
) -> Pose:
if normalize_poses:
# note: latest version (not yet released) does it automatically
pose = pose.normalize(pose_normalization_info(pose))

# Drop legs
pose_hide_legs(pose)
if remove_legs:
pose_hide_legs(pose)

# not used, typically.
remove_world_landmarks(pose)
if remove_world_landmarks:
pose_remove_world_landmarks(pose)

# hide low conf
pose_hide_low_conf(pose)

pose.focus()

# TODO: prune leading/trailing frames without useful data (e.g. no hands, only zeroes, almost no face)
for frame_index in enumerate(pose.body.data):
# https://github.com/rotem-shalev/Ham2Pose/blob/main/metrics.py#L44-L60
pass
pose_hide_low_conf(pose, confidence_threshold=conf_threshold_to_drop_points)

return pose

Expand Down

0 comments on commit 07225cb

Please sign in to comment.