Skip to content

Commit

Permalink
Fix correct_wrist modifying input, and wrong shape for stacked conf. …
Browse files Browse the repository at this point in the history
…Also added a function to check fake_pose and its outputs
  • Loading branch information
cleong110 committed Jan 13, 2025
1 parent 5480987 commit a0bb83a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
19 changes: 14 additions & 5 deletions src/python/pose_format/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,23 @@ def get_standard_components_for_known_format(known_pose_format: KnownPoseFormat)
raise NotImplementedError(f"Unsupported pose header schema {known_pose_format}")


def fake_pose(num_frames: int, fps=25, dims=2, components=None)->Pose:
def fake_pose(num_frames: int, fps: int=25, components: Union[List[PoseHeaderComponent],None]=None)->Pose:
if components is None:
components = copy.deepcopy(OpenPose_Components) # fixes W0102, dangerous default value
dimensions = PoseHeaderDimensions(width=1, height=1, depth=1)

if components[0].format == "XYZC":
dimensions = PoseHeaderDimensions(width=1, height=1, depth=1)
elif components[0].format == "XYC":
dimensions = PoseHeaderDimensions(width=1, height=1)
else:
raise ValueError(f"Unknown point format: {components[0].format}")
header = PoseHeader(version=0.2, dimensions=dimensions, components=components)

total_points = header.total_points()
data = np.random.randn(num_frames, 1, total_points, dims)
data = np.random.randn(num_frames, 1, total_points, header.num_dims())
confidence = np.random.randn(num_frames, 1, total_points)
masked_data = ma.masked_array(data)


body = NumPyPoseBody(fps=int(fps), data=masked_data, confidence=confidence)

Expand Down Expand Up @@ -237,6 +244,7 @@ def get_body_hand_wrist_index(pose: Pose, hand: str)-> int:


def correct_wrist(pose: Pose, hand: str) -> Pose:
pose = copy.deepcopy(pose) # was previously modifying the input
wrist_index = get_hand_wrist_index(pose, hand)
wrist = pose.body.data[:, :, wrist_index]
wrist_conf = pose.body.confidence[:, :, wrist_index]
Expand All @@ -245,13 +253,14 @@ def correct_wrist(pose: Pose, hand: str) -> Pose:
body_wrist = pose.body.data[:, :, body_wrist_index]
body_wrist_conf = pose.body.confidence[:, :, body_wrist_index]

stacked_conf = np.stack([wrist_conf] * 3, axis=-1)
point_coordinate_count = wrist.shape[-1]
stacked_conf = np.stack([wrist_conf] * point_coordinate_count, axis=-1)
new_wrist_data = ma.where(stacked_conf == 0, body_wrist, wrist)
new_wrist_conf = ma.where(wrist_conf == 0, body_wrist_conf, wrist_conf)

pose.body.data[:, :, body_wrist_index] = new_wrist_data
pose.body.confidence[:, :, body_wrist_index] = new_wrist_conf
return pose
return pose


def correct_wrists(pose: Pose) -> Pose:
Expand Down
48 changes: 47 additions & 1 deletion src/python/pose_format/utils/generic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_body_hand_wrist_index,
correct_wrists,
hands_components,
fake_pose,
)

TEST_POSE_FORMATS = list(get_args(KnownPoseFormat))
Expand Down Expand Up @@ -154,8 +155,10 @@ def test_correct_wrists(fake_poses: List[Pose]):

else:
corrected_pose = correct_wrists(pose)
assert np.array_equal(corrected_pose.body.data, pose.body.data) is False
assert corrected_pose != pose
assert np.array_equal(corrected_pose.body.data, pose.body.data) is False




@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"])
Expand All @@ -169,3 +172,46 @@ def test_hands_components(fake_poses: List[Pose]):
hands_components_returned = hands_components(pose.header)
assert "LEFT" in hands_components_returned[0][0].upper()
assert "RIGHT" in hands_components_returned[0][1].upper()


@pytest.mark.parametrize("known_pose_format", TEST_POSE_FORMATS)
def test_fake_pose(known_pose_format: KnownPoseFormat):

for frame_count in [1, 10, 100]:
for fps in [1, 15, 25, 100]:
standard_components = get_standard_components_for_known_format(known_pose_format)

pose = fake_pose(frame_count, fps=fps, components=standard_components)
point_formats = [c.format for c in pose.header.components]
data_dimension_expected = 0

# they should all be consistent
for point_format in point_formats:
# something like "XYC" or "XYZC"
assert point_format == point_formats[0]

data_dimension_expected = len(point_formats[0]) - 1


detected_format = detect_known_pose_format(pose)

if detected_format == 'holistic':
assert point_formats[0] == "XYZC"
elif detected_format == 'openpose':
assert point_formats[0] == "XYC"
elif detected_format == 'openpose_135':
assert point_formats[0] == "XYC"

assert detected_format == known_pose_format
assert pose.body.fps == fps
assert pose.body.data.shape == (frame_count, 1, pose.header.total_points(), data_dimension_expected)
assert pose.body.data.shape[0] == frame_count
assert pose.header.num_dims() == pose.body.data.shape[-1]

poses = [fake_pose(25) for _ in range(5)]






0 comments on commit a0bb83a

Please sign in to comment.