Skip to content

Commit

Permalink
feat(pose_estimation): support multiprocessing videos_to_poses
Browse files Browse the repository at this point in the history
  • Loading branch information
J22Melody committed Dec 5, 2024
1 parent 1845ea4 commit 999a598
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 15 deletions.
87 changes: 76 additions & 11 deletions src/python/pose_format/bin/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,43 @@
from typing import List
import logging
from tqdm import tqdm
import multiprocessing
from multiprocessing import Pool
import multiprocessing.pool as mpp
import psutil
import os

# os.system("taskset -p 0xff %d" % os.getpid())

# Note: untested other than .mp4. Support for .webm may have issues: https://github.com/sign-language-processing/pose/pull/126
SUPPORTED_VIDEO_FORMATS = [".mp4", ".mov", ".avi", ".mkv", ".flv", ".wmv", ".webm"]


# https://stackoverflow.com/questions/57354700/starmap-combined-with-tqdm

def istarmap(self, func, iterable, chunksize=1):
"""starmap-version of imap
"""
self._check_running()
if chunksize < 1:
raise ValueError(
"Chunksize must be 1+, not {0:n}".format(
chunksize))

task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
result = mpp.IMapIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
mpp.starmapstar,
task_batches),
result._set_length
))
return (item for chunk in result for item in chunk)

mpp.Pool.istarmap = istarmap


def find_videos_with_missing_pose_files(
directory: Path,
video_suffixes: List[str] = None,
Expand Down Expand Up @@ -79,6 +111,21 @@ def get_corresponding_pose_path(video_path: Path, keep_video_suffixes: bool = Fa
return video_path.with_suffix(".pose")


def process_video(vid_path: Path, keep_video_suffixes: bool, pose_format: str, additional_config: dict) -> bool:
print(f'Estimating {vid_path} on CPU {psutil.Process().cpu_num()}')
try:
pose_path = get_corresponding_pose_path(video_path=vid_path, keep_video_suffixes=keep_video_suffixes)
if pose_path.is_file():
print(f"Skipping {vid_path}, corresponding .pose file already created.")
else:
pose_video(vid_path, pose_path, pose_format, additional_config, progress=False)
return True

except ValueError as e:
print(f"ValueError on {vid_path}")
logging.exception(e)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -114,6 +161,13 @@ def main():
default=SUPPORTED_VIDEO_FORMATS,
help="Video extensions to search for. Defaults to searching for all supported.",
)
parser.add_argument(
"--num-workers",
type=int,
default=1,
help="Number of multiprocessing workers.",
required=False
)
parser.add_argument(
"--additional-config",
type=str,
Expand Down Expand Up @@ -144,15 +198,26 @@ def main():

pose_with_no_errors_count = 0

for vid_path in tqdm(videos_with_missing_pose_files):
try:
pose_path = get_corresponding_pose_path(video_path=vid_path, keep_video_suffixes=args.keep_video_suffixes)
if pose_path.is_file():
print(f"Skipping {vid_path}, corresponding .pose file already created.")
continue
pose_video(vid_path, pose_path, args.format, additional_config)
pose_with_no_errors_count += 1
except ValueError as e:
print(f"ValueError on {vid_path}")
logging.exception(e)
if args.num_workers == 1:
print('Process sequentially ...')

for vid_path in tqdm(videos_with_missing_pose_files):
success = process_video(vid_path, args.keep_video_suffixes, args.format, additional_config)
if success:
pose_with_no_errors_count += 1
else:
print(f'Multiprocessing with {args.num_workers} workers on {len(os.sched_getaffinity(0))} available CPUs ...')

with Pool(args.num_workers) as pool:
params = [[
vid_path,
args.keep_video_suffixes,
args.format,
additional_config,
] for vid_path in videos_with_missing_pose_files]

for success in tqdm(pool.istarmap(process_video, params), total=len(params)):
if success:
pose_with_no_errors_count += 1

print(f"Successfully created pose files for {pose_with_no_errors_count}/{len(videos_with_missing_pose_files)} video files")
7 changes: 3 additions & 4 deletions src/python/pose_format/bin/pose_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def load_video_frames(cap: cv2.VideoCapture):
cap.release()


def pose_video(input_path: str, output_path: str, format: str, additional_config: dict):
def pose_video(input_path: str, output_path: str, format: str, additional_config: dict = {'model_complexity': 1}, progress: bool = True):
# Load video frames
print('Loading video ...')
cap = cv2.VideoCapture(input_path)
Expand All @@ -27,13 +27,12 @@ def pose_video(input_path: str, output_path: str, format: str, additional_config
# Perform pose estimation
print('Estimating pose ...')
if format == 'mediapipe':
additional_holistic_config = {'model_complexity': 1} | additional_config
pose = load_holistic(frames,
fps=fps,
width=width,
height=height,
progress=True,
additional_holistic_config=additional_holistic_config)
progress=progress,
additional_holistic_config=additional_config)
else:
raise NotImplementedError('Pose format not supported')

Expand Down

0 comments on commit 999a598

Please sign in to comment.