Skip to content

Commit

Permalink
Feature/detect format function (#144)
Browse files Browse the repository at this point in the history
* CDL: minor doc typo fix

* Undoing some changes that got mixed in

* Add detect_pose_format function and SupportedPoseFormat Literal

* detect_known_pose_format and tests for it.

* various cleanup changes, style changes

* missing import

* undo black formatting for face contours and ignore_names

* SupportedPoseFormat->KnownPoseFormat

* Unreachable raise ValueErrors fixed

* generic utils type annotations

* change detect_known_format to take Pose or PoseHeader

* Reraise ImportError if mediapipe is not installed

* conftest update to supply unknown-format fake poses

* nicer formatting for plane_info and line_info

* fix import in generic_test.py

* add some pylint disables, consistent with pose-evaluation

* Change import in conftest.py

* change import style in generic.py

* change more imports

* Fix a few type issues

* Change matrix strategy fail-fast to false, so that we can still run tests if Python 3.8 does not work

* Union for type annotation backwards compatibility

* Add checks for NotImplementedError

* Fix correct_wrist modifying input, and wrong shape for stacked conf. Also added a function to check fake_pose and its outputs

* Simplify get_component_names and fix spacing

* fix test_get_component_names
  • Loading branch information
cleong110 authored Jan 14, 2025
1 parent 56e6717 commit f319e3e
Show file tree
Hide file tree
Showing 12 changed files with 419 additions and 74 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
fail-fast: false

steps:
- uses: actions/checkout@v3
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
.idea/
.DS_Store
.vscode/
.coverage
.coveragerc
coverage.lcov
2 changes: 1 addition & 1 deletion src/python/ComfyUI-Pose-Format/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import cv2
import torch
from pose_format import Pose
from pose_format.pose import Pose
from pose_format.pose_visualizer import PoseVisualizer
from pose_format.utils.generic import reduce_holistic
from pose_format.utils.openpose import OpenPose_Components
Expand Down
2 changes: 1 addition & 1 deletion src/python/pose_format/bin/pose_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
import os

from pose_format import Pose
from pose_format.pose import Pose
from pose_format.pose_visualizer import PoseVisualizer
from pose_format.utils.generic import pose_normalization_info

Expand Down
6 changes: 3 additions & 3 deletions src/python/pose_format/pose.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import chain
from typing import BinaryIO, Dict, List, Tuple, Type
from typing import BinaryIO, Dict, List, Tuple, Type, Union

import numpy as np
import numpy.ma as ma
Expand Down Expand Up @@ -87,7 +87,7 @@ def focus(self):
dimensions = (maxs - mins).tolist()
self.header.dimensions = PoseHeaderDimensions(*dimensions)

def normalize(self, info: PoseNormalizationInfo|None=None, scale_factor: float = 1) -> "Pose":
def normalize(self, info: Union[PoseNormalizationInfo,None]=None, scale_factor: float = 1) -> "Pose":
"""
Normalize the points to a fixed distance between two particular points.
Expand Down Expand Up @@ -203,7 +203,7 @@ def frame_dropout_normal(self, dropout_mean: float = 0.5, dropout_std: float = 0
body, selected_indexes = self.body.frame_dropout_normal(dropout_mean=dropout_mean, dropout_std=dropout_std)
return Pose(header=self.header, body=body), selected_indexes

def get_components(self, components: List[str], points: Dict[str, List[str]] = None):
def get_components(self, components: List[str], points: Union[Dict[str, List[str]],None] = None):
"""
get pose components based on criteria.
Expand Down
23 changes: 23 additions & 0 deletions src/python/pose_format/utils/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import copy
from typing import List, get_args
import pytest
from pose_format.pose import Pose
from pose_format.utils.generic import get_standard_components_for_known_format, fake_pose, KnownPoseFormat

@pytest.fixture
def fake_poses(request) -> List[Pose]:
# Access the parameter passed to the fixture
known_format = request.param
count = getattr(request, "count", 3)
known_formats = get_args(KnownPoseFormat)
if known_format in known_formats:

components = get_standard_components_for_known_format(known_format)
return copy.deepcopy([fake_pose(i * 10 + 10, components=components) for i in range(count)])
else:
# get openpose
fake_poses_list = [fake_pose(i * 10 + 10) for i in range(count)]
for i, pose in enumerate(fake_poses_list):
for component in pose.header.components:
component.name = f"unknown_component_{i}_formerly_{component.name}"
return copy.deepcopy(fake_poses_list)
Loading

0 comments on commit f319e3e

Please sign in to comment.