Skip to content

Commit

Permalink
feat(pose_estimation): support multiprocessing videos_to_poses (#130)
Browse files Browse the repository at this point in the history
* feat(pose_estimation): support multiprocessing videos_to_poses

* review: use tqdm process_map
  • Loading branch information
J22Melody authored Dec 6, 2024
1 parent 1845ea4 commit 7a95e70
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
46 changes: 36 additions & 10 deletions src/python/pose_format/bin/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from typing import List
import logging
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
import psutil
import os
from functools import partial

# Note: untested other than .mp4. Support for .webm may have issues: https://github.com/sign-language-processing/pose/pull/126
SUPPORTED_VIDEO_FORMATS = [".mp4", ".mov", ".avi", ".mkv", ".flv", ".wmv", ".webm"]
Expand Down Expand Up @@ -79,6 +83,22 @@ def get_corresponding_pose_path(video_path: Path, keep_video_suffixes: bool = Fa
return video_path.with_suffix(".pose")


def process_video(keep_video_suffixes: bool, pose_format: str, additional_config: dict, vid_path: Path) -> bool:
print(f'Estimating {vid_path} on CPU {psutil.Process().cpu_num()}')

try:
pose_path = get_corresponding_pose_path(video_path=vid_path, keep_video_suffixes=keep_video_suffixes)
if pose_path.is_file():
print(f"Skipping {vid_path}, corresponding .pose file already created.")
else:
pose_video(vid_path, pose_path, pose_format, additional_config, progress=False)
return True

except ValueError as e:
print(f"ValueError on {vid_path}")
logging.exception(e)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -114,6 +134,13 @@ def main():
default=SUPPORTED_VIDEO_FORMATS,
help="Video extensions to search for. Defaults to searching for all supported.",
)
parser.add_argument(
"--num-workers",
type=int,
default=1,
help="Number of multiprocessing workers.",
required=False
)
parser.add_argument(
"--additional-config",
type=str,
Expand Down Expand Up @@ -144,15 +171,14 @@ def main():

pose_with_no_errors_count = 0

for vid_path in tqdm(videos_with_missing_pose_files):
try:
pose_path = get_corresponding_pose_path(video_path=vid_path, keep_video_suffixes=args.keep_video_suffixes)
if pose_path.is_file():
print(f"Skipping {vid_path}, corresponding .pose file already created.")
continue
pose_video(vid_path, pose_path, args.format, additional_config)
if args.num_workers == 1:
print('Process sequentially ...')
else:
print(f'Multiprocessing with {args.num_workers} workers on {len(os.sched_getaffinity(0))} available CPUs ...')

func = partial(process_video, args.keep_video_suffixes, args.format, additional_config)
for success in process_map(func, videos_with_missing_pose_files, max_workers=args.num_workers):
if success:
pose_with_no_errors_count += 1
except ValueError as e:
print(f"ValueError on {vid_path}")
logging.exception(e)

print(f"Successfully created pose files for {pose_with_no_errors_count}/{len(videos_with_missing_pose_files)} video files")
7 changes: 3 additions & 4 deletions src/python/pose_format/bin/pose_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def load_video_frames(cap: cv2.VideoCapture):
cap.release()


def pose_video(input_path: str, output_path: str, format: str, additional_config: dict):
def pose_video(input_path: str, output_path: str, format: str, additional_config: dict = {'model_complexity': 1}, progress: bool = True):
# Load video frames
print('Loading video ...')
cap = cv2.VideoCapture(input_path)
Expand All @@ -27,13 +27,12 @@ def pose_video(input_path: str, output_path: str, format: str, additional_config
# Perform pose estimation
print('Estimating pose ...')
if format == 'mediapipe':
additional_holistic_config = {'model_complexity': 1} | additional_config
pose = load_holistic(frames,
fps=fps,
width=width,
height=height,
progress=True,
additional_holistic_config=additional_holistic_config)
progress=progress,
additional_holistic_config=additional_config)
else:
raise NotImplementedError('Pose format not supported')

Expand Down

0 comments on commit 7a95e70

Please sign in to comment.