From f319e3e86c44f7297be12c996c2babd912ae850d Mon Sep 17 00:00:00 2001 From: cleong110 <122366389+cleong110@users.noreply.github.com> Date: Tue, 14 Jan 2025 16:18:12 -0500 Subject: [PATCH] Feature/detect format function (#144) * CDL: minor doc typo fix * Undoing some changes that got mixed in * Add detect_pose_format function and SupportedPoseFormat Literal * detect_known_pose_format and tests for it. * various cleanup changes, style changes * missing import * undo black formatting for face contours and ignore_names * SupportedPoseFormat->KnownPoseFormat * Unreachable raise ValueErrors fixed * generic utils type annotations * change detect_known_format to take Pose or PoseHeader * Reraise ImportError if mediapipe is not installed * conftest update to supply unknown-format fake poses * nicer formatting for plane_info and line_info * fix import in generic_test.py * add some pylint disables, consistent with pose-evaluation * Change import in conftest.py * change import style in generic.py * change more imports * Fix a few type issues * Change matrix strategy fail-fast to false, so that we can still run tests if Python 3.8 does not work * Union for type annotation backwards compatibility * Add checks for NotImplementedError * Fix correct_wrist modifying input, and wrong shape for stacked conf. Also added a function to check fake_pose and its outputs * Simplify get_component_names and fix spacing * fix test_get_component_names --- .github/workflows/python.yaml | 1 + .gitignore | 4 + src/python/ComfyUI-Pose-Format/nodes.py | 2 +- src/python/pose_format/bin/pose_visualizer.py | 2 +- src/python/pose_format/pose.py | 6 +- src/python/pose_format/utils/conftest.py | 23 ++ src/python/pose_format/utils/generic.py | 226 +++++++++++++----- src/python/pose_format/utils/generic_test.py | 213 +++++++++++++++++ src/python/pose_format/utils/openpose_135.py | 2 +- .../pose_format/utils/pose_converter.py | 2 +- src/python/pyproject.toml | 10 + src/python/tests/visualization_test.py | 2 +- 12 files changed, 419 insertions(+), 74 deletions(-) create mode 100644 src/python/pose_format/utils/conftest.py create mode 100644 src/python/pose_format/utils/generic_test.py diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 16856e4..db0becf 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -13,6 +13,7 @@ jobs: strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + fail-fast: false steps: - uses: actions/checkout@v3 diff --git a/.gitignore b/.gitignore index f32e31a..63a48a6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,6 @@ .idea/ .DS_Store +.vscode/ +.coverage +.coveragerc +coverage.lcov \ No newline at end of file diff --git a/src/python/ComfyUI-Pose-Format/nodes.py b/src/python/ComfyUI-Pose-Format/nodes.py index 33483e6..01401fe 100644 --- a/src/python/ComfyUI-Pose-Format/nodes.py +++ b/src/python/ComfyUI-Pose-Format/nodes.py @@ -2,7 +2,7 @@ import cv2 import torch -from pose_format import Pose +from pose_format.pose import Pose from pose_format.pose_visualizer import PoseVisualizer from pose_format.utils.generic import reduce_holistic from pose_format.utils.openpose import OpenPose_Components diff --git a/src/python/pose_format/bin/pose_visualizer.py b/src/python/pose_format/bin/pose_visualizer.py index 80cbe8e..8f30618 100644 --- a/src/python/pose_format/bin/pose_visualizer.py +++ b/src/python/pose_format/bin/pose_visualizer.py @@ -3,7 +3,7 @@ import argparse import os -from pose_format import Pose +from pose_format.pose import Pose from pose_format.pose_visualizer import PoseVisualizer from pose_format.utils.generic import pose_normalization_info diff --git a/src/python/pose_format/pose.py b/src/python/pose_format/pose.py index 664ced0..06dbfab 100644 --- a/src/python/pose_format/pose.py +++ b/src/python/pose_format/pose.py @@ -1,5 +1,5 @@ from itertools import chain -from typing import BinaryIO, Dict, List, Tuple, Type +from typing import BinaryIO, Dict, List, Tuple, Type, Union import numpy as np import numpy.ma as ma @@ -87,7 +87,7 @@ def focus(self): dimensions = (maxs - mins).tolist() self.header.dimensions = PoseHeaderDimensions(*dimensions) - def normalize(self, info: PoseNormalizationInfo|None=None, scale_factor: float = 1) -> "Pose": + def normalize(self, info: Union[PoseNormalizationInfo,None]=None, scale_factor: float = 1) -> "Pose": """ Normalize the points to a fixed distance between two particular points. @@ -203,7 +203,7 @@ def frame_dropout_normal(self, dropout_mean: float = 0.5, dropout_std: float = 0 body, selected_indexes = self.body.frame_dropout_normal(dropout_mean=dropout_mean, dropout_std=dropout_std) return Pose(header=self.header, body=body), selected_indexes - def get_components(self, components: List[str], points: Dict[str, List[str]] = None): + def get_components(self, components: List[str], points: Union[Dict[str, List[str]],None] = None): """ get pose components based on criteria. diff --git a/src/python/pose_format/utils/conftest.py b/src/python/pose_format/utils/conftest.py new file mode 100644 index 0000000..4f4ae30 --- /dev/null +++ b/src/python/pose_format/utils/conftest.py @@ -0,0 +1,23 @@ +import copy +from typing import List, get_args +import pytest +from pose_format.pose import Pose +from pose_format.utils.generic import get_standard_components_for_known_format, fake_pose, KnownPoseFormat + +@pytest.fixture +def fake_poses(request) -> List[Pose]: + # Access the parameter passed to the fixture + known_format = request.param + count = getattr(request, "count", 3) + known_formats = get_args(KnownPoseFormat) + if known_format in known_formats: + + components = get_standard_components_for_known_format(known_format) + return copy.deepcopy([fake_pose(i * 10 + 10, components=components) for i in range(count)]) + else: + # get openpose + fake_poses_list = [fake_pose(i * 10 + 10) for i in range(count)] + for i, pose in enumerate(fake_poses_list): + for component in pose.header.components: + component.name = f"unknown_component_{i}_formerly_{component.name}" + return copy.deepcopy(fake_poses_list) diff --git a/src/python/pose_format/utils/generic.py b/src/python/pose_format/utils/generic.py index 4a06c13..940d43a 100644 --- a/src/python/pose_format/utils/generic.py +++ b/src/python/pose_format/utils/generic.py @@ -1,12 +1,57 @@ -from typing import Tuple - +from pathlib import Path +from typing import Tuple, Literal, List, Union +import copy import numpy as np from numpy import ma -from pose_format import Pose +from pose_format.pose import Pose from pose_format.numpy import NumPyPoseBody -from pose_format.pose_header import PoseHeader, PoseHeaderDimensions +from pose_format.pose_header import PoseHeader, PoseHeaderDimensions, PoseHeaderComponent, PoseNormalizationInfo from pose_format.utils.normalization_3d import PoseNormalizer from pose_format.utils.openpose import OpenPose_Components +from pose_format.utils.openpose_135 import OpenPose_Components as OpenPose135_Components + +# from pose_format.utils.holistic import holistic_components +# The import above creates an error: ImportError: Please install mediapipe with: pip install mediapipe + +KnownPoseFormat = Literal["holistic", "openpose", "openpose_135"] + + +def get_component_names( + pose_or_header_or_components: Union[Pose,PoseHeader]) -> List[str]: + if isinstance(pose_or_header_or_components, Pose): + return [c.name for c in pose_or_header_or_components.header.components] + if isinstance(pose_or_header_or_components, PoseHeader): + return [c.name for c in pose_or_header_or_components.components] + raise ValueError(f"Could not get component_names from {pose_or_header_or_components}") + + +def detect_known_pose_format(pose_or_header: Union[Pose,PoseHeader]) -> KnownPoseFormat: + component_names= get_component_names(pose_or_header) + + # would be better to import from pose_format.utils.holistic but that creates a dep on mediapipe + mediapipe_components = [ + "POSE_LANDMARKS", + "FACE_LANDMARKS", + "LEFT_HAND_LANDMARKS", + "RIGHT_HAND_LANDMARKS", + "POSE_WORLD_LANDMARKS", + ] + + openpose_components = [c.name for c in OpenPose_Components] + + openpose_135_components = [c.name for c in OpenPose135_Components] + + for component_name in component_names: + if component_name in mediapipe_components: + return "holistic" + if component_name in openpose_components: + return "openpose" + if component_name in openpose_135_components: + return "openpose_135" + + raise ValueError( + f"Could not detect pose format, unknown pose header schema with component names: {component_names}" + ) def normalize_pose_size(pose: Pose, target_width: int = 512): @@ -18,7 +63,8 @@ def normalize_pose_size(pose: Pose, target_width: int = 512): def pose_hide_legs(pose: Pose): - if pose.header.components[0].name == "POSE_LANDMARKS": + known_pose_format = detect_known_pose_format(pose) + if known_pose_format == "holistic": point_names = ["KNEE", "ANKLE", "HEEL", "FOOT_INDEX"] # pylint: disable=protected-access points = [ @@ -28,70 +74,90 @@ def pose_hide_legs(pose: Pose): ] pose.body.data[:, :, points, :] = 0 pose.body.confidence[:, :, points] = 0 - elif pose.header.components[0].name == "pose_keypoints_2d": + elif known_pose_format == "openpose": point_names = ["Hip", "Knee", "Ankle", "BigToe", "SmallToe", "Heel"] # pylint: disable=protected-access points = [ - pose.header._get_point_index("pose_keypoints_2d", side + n) - for n in point_names - for side in ["L", "R"] + pose.header._get_point_index("pose_keypoints_2d", side + n) for n in point_names for side in ["L", "R"] ] pose.body.data[:, :, points, :] = 0 pose.body.confidence[:, :, points] = 0 else: - raise ValueError("Unknown pose header schema for hiding legs") + raise NotImplementedError( + f"Unsupported pose header schema {known_pose_format} for {pose_hide_legs.__name__}: {pose.header}" + ) + +def pose_shoulders(pose_header: PoseHeader) -> Tuple[Tuple[str, str], Tuple[str, str]]: + known_pose_format = detect_known_pose_format(pose_header) -def pose_shoulders(pose_header: PoseHeader): - if pose_header.components[0].name == "POSE_LANDMARKS": + if known_pose_format == "holistic": return ("POSE_LANDMARKS", "RIGHT_SHOULDER"), ("POSE_LANDMARKS", "LEFT_SHOULDER") - if pose_header.components[0].name == "BODY_135": + if known_pose_format == "openpose_135": return ("BODY_135", "RShoulder"), ("BODY_135", "LShoulder") - if pose_header.components[0].name == "pose_keypoints_2d": + if known_pose_format == "openpose": return ("pose_keypoints_2d", "RShoulder"), ("pose_keypoints_2d", "LShoulder") - raise ValueError("Unknown pose header schema for normalization") + raise NotImplementedError( + f"Unsupported pose header schema {known_pose_format} for {pose_shoulders.__name__}: {pose_header}" + ) -def hands_indexes(pose_header: PoseHeader): - if pose_header.components[0].name == "POSE_LANDMARKS": - return [pose_header._get_point_index("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"), - pose_header._get_point_index("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP")] +def hands_indexes(pose_header: PoseHeader)-> List[int]: + known_pose_format = detect_known_pose_format(pose_header) + if known_pose_format == "holistic": + return [ + pose_header._get_point_index("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"), + pose_header._get_point_index("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"), + ] - if pose_header.components[0].name == "pose_keypoints_2d": - return [pose_header._get_point_index("hand_left_keypoints_2d", "M_CMC"), - pose_header._get_point_index("hand_right_keypoints_2d", "M_CMC")] + if known_pose_format == "openpose": + return [ + pose_header._get_point_index("hand_left_keypoints_2d", "M_CMC"), + pose_header._get_point_index("hand_right_keypoints_2d", "M_CMC"), + ] + raise NotImplementedError( + f"Unsupported pose header schema {known_pose_format} for {hands_indexes.__name__}: {pose_header}" + ) -def pose_normalization_info(pose_header: PoseHeader): +def pose_normalization_info(pose_header: PoseHeader) ->PoseNormalizationInfo: (c1, p1), (c2, p2) = pose_shoulders(pose_header) return pose_header.normalization_info(p1=(c1, p1), p2=(c2, p2)) -def hands_components(pose_header: PoseHeader): - if pose_header.components[0].name in ["POSE_LANDMARKS", "LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"]: - return ("LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"), \ - ("WRIST", "PINKY_MCP", "INDEX_FINGER_MCP"), \ - ("WRIST", "MIDDLE_FINGER_MCP") +def hands_components(pose_header: PoseHeader)-> Tuple[Tuple[str, str], Tuple[str, str, str], Tuple[str, str]]: + known_pose_format = detect_known_pose_format(pose_header) + if known_pose_format == "holistic": + return ( + ("LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"), + ("WRIST", "PINKY_MCP", "INDEX_FINGER_MCP"), + ("WRIST", "MIDDLE_FINGER_MCP"), + ) - if pose_header.components[0].name in ["pose_keypoints_2d", "hand_left_keypoints_2d", "hand_right_keypoints_2d"]: - return ("hand_left_keypoints_2d", "hand_right_keypoints_2d"), \ - ("BASE", "P_CMC", "I_CMC"), \ - ("BASE", "M_CMC") + if known_pose_format == "openpose": + return ("hand_left_keypoints_2d", "hand_right_keypoints_2d"), ("BASE", "P_CMC", "I_CMC"), ("BASE", "M_CMC") - raise ValueError("Unknown pose header") + raise NotImplementedError( + f"Unsupported pose header schema '{known_pose_format}' for {hands_components.__name__}: {pose_header}" + ) def normalize_component_3d(pose, component_name: str, plane: Tuple[str, str, str], line: Tuple[str, str]): hand_pose = pose.get_components([component_name]) - plane = hand_pose.header.normalization_info(p1=(component_name, plane[0]), - p2=(component_name, plane[1]), - p3=(component_name, plane[2])) - line = hand_pose.header.normalization_info(p1=(component_name, line[0]), - p2=(component_name, line[1])) - normalizer = PoseNormalizer(plane=plane, line=line) + plane_info = hand_pose.header.normalization_info( + p1=(component_name, plane[0]), + p2=(component_name, plane[1]), + p3=(component_name, plane[2]) + ) + line_info = hand_pose.header.normalization_info( + p1=(component_name, line[0]), + p2=(component_name, line[1]) + ) + + normalizer = PoseNormalizer(plane=plane_info, line=line_info) normalized_hand = normalizer(hand_pose.body.data) # Add normalized hand to pose @@ -107,39 +173,68 @@ def normalize_hands_3d(pose: Pose, left_hand=True, right_hand=True): normalize_component_3d(pose, right_hand_component, plane, line) -def fake_pose(num_frames: int, fps=25, dims=2, components=OpenPose_Components): - dimensions = PoseHeaderDimensions(width=1, height=1, depth=1) +def get_standard_components_for_known_format(known_pose_format: KnownPoseFormat) -> List[PoseHeaderComponent]: + if known_pose_format == "holistic": + try: + import pose_format.utils.holistic as holistic_utils + return holistic_utils.holistic_components() + except ImportError as e: + raise e + if known_pose_format == "openpose": + return OpenPose_Components + if known_pose_format == "openpose_135": + return OpenPose135_Components + + raise NotImplementedError(f"Unsupported pose header schema {known_pose_format}") + + +def fake_pose(num_frames: int, fps: int=25, components: Union[List[PoseHeaderComponent],None]=None)->Pose: + if components is None: + components = copy.deepcopy(OpenPose_Components) # fixes W0102, dangerous default value + + if components[0].format == "XYZC": + dimensions = PoseHeaderDimensions(width=1, height=1, depth=1) + elif components[0].format == "XYC": + dimensions = PoseHeaderDimensions(width=1, height=1) + else: + raise ValueError(f"Unknown point format: {components[0].format}") header = PoseHeader(version=0.2, dimensions=dimensions, components=components) total_points = header.total_points() - data = np.random.randn(num_frames, 1, total_points, dims) + data = np.random.randn(num_frames, 1, total_points, header.num_dims()) confidence = np.random.randn(num_frames, 1, total_points) masked_data = ma.masked_array(data) + body = NumPyPoseBody(fps=int(fps), data=masked_data, confidence=confidence) return Pose(header, body) -def get_hand_wrist_index(pose: Pose, hand: str): - if pose.header.components[0].name == "POSE_LANDMARKS": - return pose.header._get_point_index(f'{hand.upper()}_HAND_LANDMARKS', 'WRIST') - elif pose.header.components[0].name == "pose_keypoints_2d": - return pose.header._get_point_index(f'hand_{hand.lower()}_keypoints_2d', 'BASE') - else: - raise ValueError("Unknown pose header schema for get_hand_wrist_index") +def get_hand_wrist_index(pose: Pose, hand: str)-> int: + known_pose_format = detect_known_pose_format(pose) + if known_pose_format == "holistic": + return pose.header._get_point_index(f"{hand.upper()}_HAND_LANDMARKS", "WRIST") + if known_pose_format == "openpose": + return pose.header._get_point_index(f"hand_{hand.lower()}_keypoints_2d", "BASE") + raise NotImplementedError( + f"Unsupported pose header schema {known_pose_format} for {get_hand_wrist_index.__name__}: {pose.header}" + ) -def get_body_hand_wrist_index(pose: Pose, hand: str): - if pose.header.components[0].name == "POSE_LANDMARKS": - return pose.header._get_point_index('POSE_LANDMARKS', f'{hand.upper()}_WRIST') - elif pose.header.components[0].name == "pose_keypoints_2d": - return pose.header._get_point_index('pose_keypoints_2d', f'{hand.upper()[0]}Wrist') - else: - raise ValueError("Unknown pose header schema for get_hand_wrist_index") +def get_body_hand_wrist_index(pose: Pose, hand: str)-> int: + known_pose_format = detect_known_pose_format(pose) + if known_pose_format == "holistic": + return pose.header._get_point_index("POSE_LANDMARKS", f"{hand.upper()}_WRIST") + if known_pose_format == "openpose": + return pose.header._get_point_index("pose_keypoints_2d", f"{hand.upper()[0]}Wrist") + raise NotImplementedError( + f"Unsupported pose header schema {known_pose_format} for {get_body_hand_wrist_index.__name__}: {pose.header}" + ) def correct_wrist(pose: Pose, hand: str) -> Pose: + pose = copy.deepcopy(pose) # was previously modifying the input wrist_index = get_hand_wrist_index(pose, hand) wrist = pose.body.data[:, :, wrist_index] wrist_conf = pose.body.confidence[:, :, wrist_index] @@ -148,7 +243,8 @@ def correct_wrist(pose: Pose, hand: str) -> Pose: body_wrist = pose.body.data[:, :, body_wrist_index] body_wrist_conf = pose.body.confidence[:, :, body_wrist_index] - stacked_conf = np.stack([wrist_conf] * 3, axis=-1) + point_coordinate_count = wrist.shape[-1] + stacked_conf = np.stack([wrist_conf] * point_coordinate_count, axis=-1) new_wrist_data = ma.where(stacked_conf == 0, body_wrist, wrist) new_wrist_conf = ma.where(wrist_conf == 0, body_wrist_conf, wrist_conf) @@ -158,13 +254,14 @@ def correct_wrist(pose: Pose, hand: str) -> Pose: def correct_wrists(pose: Pose) -> Pose: - pose = correct_wrist(pose, 'LEFT') - pose = correct_wrist(pose, 'RIGHT') + pose = correct_wrist(pose, "LEFT") + pose = correct_wrist(pose, "RIGHT") return pose def reduce_holistic(pose: Pose) -> Pose: - if pose.header.components[0].name != "POSE_LANDMARKS": + known_pose_format = detect_known_pose_format(pose) + if known_pose_format != "holistic": return pose """ @@ -191,11 +288,8 @@ def reduce_holistic(pose: Pose) -> Pose: "KNEE", "ANKLE", "HEEL", "FOOT_INDEX" # Feet ] - body_component = [c for c in pose.header.components if c.name == 'POSE_LANDMARKS'][0] + body_component = [c for c in pose.header.components if c.name == "POSE_LANDMARKS"][0] body_no_face_no_hands = [p for p in body_component.points if all([i not in p for i in ignore_names])] - components = [c.name for c in pose.header.components if c.name != 'POSE_WORLD_LANDMARKS'] - return pose.get_components(components, { - "FACE_LANDMARKS": face_contours, - "POSE_LANDMARKS": body_no_face_no_hands - }) + components = [c.name for c in pose.header.components if c.name != "POSE_WORLD_LANDMARKS"] + return pose.get_components(components, {"FACE_LANDMARKS": face_contours, "POSE_LANDMARKS": body_no_face_no_hands}) diff --git a/src/python/pose_format/utils/generic_test.py b/src/python/pose_format/utils/generic_test.py new file mode 100644 index 0000000..4ef8070 --- /dev/null +++ b/src/python/pose_format/utils/generic_test.py @@ -0,0 +1,213 @@ +from typing import List, get_args +import numpy as np +import pytest +from pose_format.pose import Pose +from pose_format.pose_header import PoseNormalizationInfo +from pose_format.utils.generic import ( + detect_known_pose_format, + get_component_names, + get_standard_components_for_known_format, + KnownPoseFormat, + pose_hide_legs, + pose_shoulders, + hands_indexes, + normalize_pose_size, + pose_normalization_info, + get_hand_wrist_index, + get_body_hand_wrist_index, + correct_wrists, + hands_components, + fake_pose, +) + +TEST_POSE_FORMATS = list(get_args(KnownPoseFormat)) +TEST_POSE_FORMATS_WITH_UNKNOWN = list(get_args(KnownPoseFormat)) +["unknown"] + +@pytest.mark.parametrize( + "fake_poses, expected_type", [(fmt, fmt) for fmt in TEST_POSE_FORMATS_WITH_UNKNOWN], indirect=["fake_poses"] +) +def test_detect_format(fake_poses, expected_type): + known_formats = get_args(KnownPoseFormat) + for pose in fake_poses: + if expected_type in known_formats: + detected_format = detect_known_pose_format(pose) + assert detected_format == expected_type + else: + with pytest.raises( + ValueError, match="Could not detect pose format, unknown pose header schema with component names:" + ): + detect_known_pose_format(pose) + + +@pytest.mark.parametrize( + "fake_poses, known_pose_format", [(fmt, fmt) for fmt in TEST_POSE_FORMATS], indirect=["fake_poses"] +) +def test_get_component_names(fake_poses: List[Pose], known_pose_format: KnownPoseFormat): + + standard_components_for_format = get_standard_components_for_known_format(known_pose_format) + names_for_standard_components_for_format = sorted([c.name for c in standard_components_for_format]) + for pose in fake_poses: + + names_from_poses = sorted(get_component_names(pose)) + names_from_headers = sorted(get_component_names(pose.header)) + assert names_from_poses == names_from_headers + assert names_for_standard_components_for_format == names_from_headers + with pytest.raises(ValueError, match="Could not get component_names"): + get_component_names(pose.body) # type: ignore + + +@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"]) +def test_pose_hide_legs(fake_poses: List[Pose]): + for pose in fake_poses: + + orig_nonzeros_count = np.count_nonzero(pose.body.data) + + detected_format = detect_known_pose_format(pose) + if detected_format == "openpose_135": + with pytest.raises(NotImplementedError, match="Unsupported pose header schema"): + pose_hide_legs(pose) + return + else: + pose_hide_legs(pose) + new_nonzeros_count = np.count_nonzero(pose.body.data) + + assert orig_nonzeros_count > new_nonzeros_count + + +@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) +def test_pose_shoulders(fake_poses: List[Pose]): + for pose in fake_poses: + shoulders = pose_shoulders(pose.header) + assert len(shoulders) == 2 + assert "RIGHT" in shoulders[0][1] or shoulders[0][1][0] == "R" + assert "LEFT" in shoulders[1][1] or shoulders[1][1][0] == "L" + + +@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) +def test_hands_indexes(fake_poses: List[Pose]): + for pose in fake_poses: + detected_format = detect_known_pose_format(pose) + if detected_format == "openpose_135": + with pytest.raises(NotImplementedError, match="Unsupported pose header schema"): + indices = hands_indexes(pose.header) + else: + indices = hands_indexes(pose.header) + assert len(indices) > 0 + + +@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) +def test_normalize_pose_size(fake_poses: List[Pose]): + for pose in fake_poses: + normalize_pose_size(pose) + # TODO: more tests, compare with test data + + +@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) +def test_pose_normalization_info(fake_poses: List[Pose]): + for pose in fake_poses: + info = pose_normalization_info(pose.header) + assert isinstance(info, PoseNormalizationInfo) + assert info.p1 is not None + assert info.p2 is not None + assert info.p3 is None + # TODO: more tests, compare with test data + + +@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) +def test_get_hand_wrist_index(fake_poses: List[Pose]): + for pose in fake_poses: + detected_format = detect_known_pose_format(pose) + for hand in ["LEFT", "RIGHT"]: + if detected_format == "openpose_135": + with pytest.raises(NotImplementedError, match="Unsupported pose header schema"): + index = get_hand_wrist_index(pose, hand) + else: + index = get_hand_wrist_index(pose, hand) + + # TODO: what are the expected values? + + +@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) +def test_get_body_hand_wrist_index(fake_poses: List[Pose]): + for pose in fake_poses: + for hand in ["LEFT", "RIGHT"]: + detected_format = detect_known_pose_format(pose) + if detected_format == "openpose_135": + with pytest.raises(NotImplementedError, match="Unsupported pose header schema"): + index = get_body_hand_wrist_index(pose, hand) + # TODO: what are the expected values? + else: + index = get_body_hand_wrist_index(pose, hand) + + + +@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) +def test_correct_wrists(fake_poses: List[Pose]): + for pose in fake_poses: + detected_format = detect_known_pose_format(pose) + if detected_format == "openpose_135": + with pytest.raises(NotImplementedError, match="Unsupported pose header schema"): + corrected_pose = correct_wrists(pose) + + else: + corrected_pose = correct_wrists(pose) + assert corrected_pose != pose + assert np.array_equal(corrected_pose.body.data, pose.body.data) is False + + + + +@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) +def test_hands_components(fake_poses: List[Pose]): + for pose in fake_poses: + detected_format = detect_known_pose_format(pose) + if detected_format == "openpose_135": + with pytest.raises(NotImplementedError, match="Unsupported pose header schema"): + hands_components_returned = hands_components(pose.header) + else: + hands_components_returned = hands_components(pose.header) + assert "LEFT" in hands_components_returned[0][0].upper() + assert "RIGHT" in hands_components_returned[0][1].upper() + + +@pytest.mark.parametrize("known_pose_format", TEST_POSE_FORMATS) +def test_fake_pose(known_pose_format: KnownPoseFormat): + + for frame_count in [1, 10, 100]: + for fps in [1, 15, 25, 100]: + standard_components = get_standard_components_for_known_format(known_pose_format) + + pose = fake_pose(frame_count, fps=fps, components=standard_components) + point_formats = [c.format for c in pose.header.components] + data_dimension_expected = 0 + + # they should all be consistent + for point_format in point_formats: + # something like "XYC" or "XYZC" + assert point_format == point_formats[0] + + data_dimension_expected = len(point_formats[0]) - 1 + + + detected_format = detect_known_pose_format(pose) + + if detected_format == 'holistic': + assert point_formats[0] == "XYZC" + elif detected_format == 'openpose': + assert point_formats[0] == "XYC" + elif detected_format == 'openpose_135': + assert point_formats[0] == "XYC" + + assert detected_format == known_pose_format + assert pose.body.fps == fps + assert pose.body.data.shape == (frame_count, 1, pose.header.total_points(), data_dimension_expected) + assert pose.body.data.shape[0] == frame_count + assert pose.header.num_dims() == pose.body.data.shape[-1] + + poses = [fake_pose(25) for _ in range(5)] + + + + + + \ No newline at end of file diff --git a/src/python/pose_format/utils/openpose_135.py b/src/python/pose_format/utils/openpose_135.py index 9acfa3c..0eddf8a 100644 --- a/src/python/pose_format/utils/openpose_135.py +++ b/src/python/pose_format/utils/openpose_135.py @@ -1,4 +1,4 @@ -from pose_format import Pose +from pose_format.pose import Pose from pose_format.pose_header import (PoseHeader, PoseHeaderComponent, PoseHeaderDimensions) from pose_format.utils.openpose import limbs_index, load_openpose_directory diff --git a/src/python/pose_format/utils/pose_converter.py b/src/python/pose_format/utils/pose_converter.py index 46566ca..04397f1 100644 --- a/src/python/pose_format/utils/pose_converter.py +++ b/src/python/pose_format/utils/pose_converter.py @@ -2,7 +2,7 @@ import numpy as np -from pose_format import Pose, PoseHeader +from pose_format.pose import Pose, PoseHeader from pose_format.numpy import NumPyPoseBody from pose_format.pose_header import PoseHeaderComponent from pose_format.pose_visualizer import PoseVisualizer diff --git a/src/python/pyproject.toml b/src/python/pyproject.toml index 581fa8e..0bf4f34 100644 --- a/src/python/pyproject.toml +++ b/src/python/pyproject.toml @@ -58,6 +58,16 @@ testpaths = [ based_on_style = "google" column_limit = 120 +[tool.pylint.format] +max-line-length = 120 +disable = [ + "C0114", # Missing module docstring + "C0115", # Missing class docstring + "C0116", # Missing function or method docstring + "W0511", # TODO + "W1203", # use lazy % formatting in logging functions +] + [project.scripts] pose_info = "pose_format.bin.pose_info:main" video_to_pose = "pose_format.bin.pose_estimation:main" diff --git a/src/python/tests/visualization_test.py b/src/python/tests/visualization_test.py index 9b9fb01..f5a24a2 100644 --- a/src/python/tests/visualization_test.py +++ b/src/python/tests/visualization_test.py @@ -2,7 +2,7 @@ import os from unittest import TestCase -from pose_format import Pose +from pose_format.pose import Pose from pose_format.pose_visualizer import PoseVisualizer