Skip to content

Commit

Permalink
more py tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bipinkrish committed Mar 22, 2024
1 parent 581d4ce commit 55d5174
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 10 deletions.
15 changes: 9 additions & 6 deletions .github/workflows/python.yaml
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
name: Build
name: Python Tests

on:
push:
branches: [ master, main ]
branches: [master, main]
pull_request:
branches: [ master, main ]

branches: [master, main]

jobs:
test:
name: Run Tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ "3.10" ]
python-version: ["3.10"]

steps:
- uses: actions/checkout@v3
Expand All @@ -27,4 +26,8 @@ jobs:

- name: Run tests
working-directory: src/python
run: pytest pose_format
run: pytest pose_format

- name: Run additional tests
working-directory: src/python
run: pytest tests -s
4 changes: 2 additions & 2 deletions src/python/tests/hand_normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_hand_normalization(self):
"""
Test the normalization of hand pose data using the PoseNormalizer.
"""
with open('data/mediapipe.pose', 'rb') as f:
with open('tests/data/mediapipe.pose', 'rb') as f:
pose = Pose.read(f.read())
pose = pose.get_components(["RIGHT_HAND_LANDMARKS"])

Expand All @@ -80,7 +80,7 @@ def test_hand_normalization(self):
pose.body.data = tensor
pose.focus()

with open('data/mediapipe_hand_normalized.pose', 'rb') as f:
with open('tests/data/mediapipe_hand_normalized.pose', 'rb') as f:
pose_gold = Pose.read(f.read())

self.assertTrue(ma.allclose(pose.body.data, pose_gold.body.data))
Expand Down
4 changes: 2 additions & 2 deletions src/python/tests/optical_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_optical_flow(self):
"""
calculator = OpticalFlowCalculator(fps=30, distance=DistanceRepresentation())

with open('data/mediapipe.pose', 'rb') as f:
with open('tests/data/mediapipe.pose', 'rb') as f:
pose = Pose.read(f.read())
pose = pose.get_components(["POSE_LANDMARKS", "RIGHT_HAND_LANDMARKS", "LEFT_HAND_LANDMARKS"])

Expand All @@ -44,4 +44,4 @@ def test_optical_flow(self):
fp = tempfile.NamedTemporaryFile()
plt.savefig(fp.name, format='png')

self.assertTrue(compare_images('data/optical_flow.png', fp.name, 0.001) is None)
self.assertTrue(compare_images('tests/data/optical_flow.png', fp.name, 0.001) is None)
56 changes: 56 additions & 0 deletions src/python/tests/visualization_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import tempfile
import os
from unittest import TestCase

from pose_format import Pose
from pose_format.pose_visualizer import PoseVisualizer


class TestPoseVisualizer(TestCase):
"""
Test cases for PoseVisualizer functionality.
"""

def test_save_gif(self):
"""
Test saving pose visualization as GIF.
"""
with open("tests/data/mediapipe.pose", "rb") as f:
pose = Pose.read(f.read())

v = PoseVisualizer(pose)

with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as temp_gif:
v.save_gif(temp_gif.name, v.draw())
self.assertTrue(os.path.exists(temp_gif.name))
self.assertGreater(os.path.getsize(
temp_gif.name), 0)

def test_save_png(self):
"""
Test saving pose visualization as PNG.
"""
with open("tests/data/mediapipe_long.pose", "rb") as f:
pose = Pose.read(f.read())

v = PoseVisualizer(pose)

with tempfile.TemporaryDirectory() as temp_dir:
temp_png = os.path.join(temp_dir, 'example.png')
v.save_png(temp_png, v.draw(transparency=True))
self.assertTrue(os.path.exists(temp_png))
self.assertGreater(os.path.getsize(temp_png), 0)

def test_save_mp4(self):
"""
Test saving pose visualization as MP4 video.
"""
with open("tests/data/mediapipe_hand_normalized.pose", "rb") as f:
pose = Pose.read(f.read())

v = PoseVisualizer(pose)

with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_mp4:
v.save_video(temp_mp4.name, v.draw())
self.assertTrue(os.path.exists(temp_mp4.name))
self.assertGreater(os.path.getsize(temp_mp4.name), 0)

0 comments on commit 55d5174

Please sign in to comment.