-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(normalization): add mean pose calculation, recalculate for impro…
…ved face model
- Loading branch information
Showing
7 changed files
with
3,966 additions
and
3,248 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from functools import partial | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
from pose_format import Pose, PoseHeader | ||
from tqdm.contrib.concurrent import process_map | ||
|
||
from pose_anonymization.data.normalization import pre_process_pose | ||
|
||
CURRENT_DIR = Path(__file__).parent | ||
|
||
|
||
def process_file(file, pose_header: PoseHeader): | ||
with open(file, 'rb') as pose_file: | ||
pose = Pose.read(pose_file.read()) | ||
pose = pre_process_pose(pose, pose_header=pose_header) | ||
tensor = pose.body.data.filled(0) | ||
|
||
frames_sum = np.sum(tensor, axis=(0, 1)) | ||
frames_squared_sum = np.sum(np.square(tensor), axis=(0, 1)) | ||
unmasked_frames = pose.body.data[:, :, :, 0:1].mask == False | ||
num_unmasked_frames = np.sum(unmasked_frames, axis=(0, 1)) | ||
|
||
return frames_sum, frames_squared_sum, num_unmasked_frames | ||
|
||
|
||
def calc_mean_and_std(files, pose_header: PoseHeader): | ||
cumulative_sum, squared_sum, frames_count = None, None, None | ||
|
||
process_func = partial(process_file, pose_header=pose_header) | ||
results = process_map(process_func, files, max_workers=None, chunksize=1) | ||
|
||
for frames_sum, frames_squared_sum, num_unmasked_frames in results: | ||
cumulative_sum = frames_sum if cumulative_sum is None else cumulative_sum + frames_sum | ||
squared_sum = frames_squared_sum if squared_sum is None else squared_sum + frames_squared_sum | ||
frames_count = num_unmasked_frames if frames_count is None else frames_count + num_unmasked_frames | ||
|
||
mean = cumulative_sum / frames_count | ||
std = np.sqrt((squared_sum / frames_count) - np.square(mean)) | ||
|
||
return mean, std | ||
|
||
|
||
def main(poses_location: str): | ||
print("Listing files...") | ||
files = list(Path(poses_location).glob("*.pose")) | ||
print(f"Processing {len(files)} files") | ||
|
||
# get a single random pose | ||
with open(files[0], 'rb') as pose_file: | ||
pose = Pose.read(pose_file.read()) | ||
pose = pre_process_pose(pose) | ||
|
||
mean, std = calc_mean_and_std(files, pose.header) | ||
|
||
# store header | ||
with open(CURRENT_DIR / "header.poseheader", "wb") as f: | ||
pose.header.write(f) | ||
|
||
i = 0 | ||
mean_std_info = {} | ||
for component in pose.header.components: | ||
component_info = {} | ||
for point in component.points: | ||
component_info[point] = { | ||
"mean": mean[i].tolist(), | ||
"std": std[i].tolist() | ||
} | ||
i += 1 | ||
mean_std_info[component.name] = component_info | ||
|
||
import json | ||
|
||
with open(CURRENT_DIR / "pose_normalization.json", "w", encoding="utf-8") as f: | ||
json.dump(mean_std_info, f, indent=2) | ||
|
||
|
||
if __name__ == "__main__": | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser(description='Collect normalization info') | ||
parser.add_argument('--dir', type=str, help='Directory containing the pose files', | ||
default="/Volumes/Echo/GCS/sign-mt-poses") | ||
|
||
args = parser.parse_args() | ||
|
||
if not Path(args.dir).exists(): | ||
raise FileNotFoundError(f"Directory {args.dir} does not exist") | ||
|
||
main(args.dir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from pathlib import Path | ||
|
||
import cv2 | ||
from pose_format.pose_visualizer import PoseVisualizer | ||
from pose_format.utils.generic import reduce_holistic | ||
|
||
from pose_anonymization.appearance import normalize_pose_size | ||
from pose_anonymization.data import load_mean_and_std_pose | ||
from pose_anonymization.data.normalization import unshift_hands | ||
|
||
if __name__ == "__main__": | ||
mean_pose, _ = load_mean_and_std_pose() | ||
unshift_hands(mean_pose) | ||
normalize_pose_size(mean_pose) | ||
|
||
poses = { | ||
"full": mean_pose, | ||
"reduced": reduce_holistic(mean_pose) | ||
} | ||
for name, pose in poses.items(): | ||
pose.focus() | ||
|
||
v = PoseVisualizer(pose) | ||
image_path = Path(__file__).parent / f"mean_pose_{name}.png" | ||
cv2.imwrite(str(image_path), next(v.draw())) |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from pose_format import Pose, PoseHeader | ||
from pose_format.utils.generic import pose_normalization_info, correct_wrists, hands_components | ||
|
||
from pose_anonymization.data import load_mean_and_std, load_mean_and_std_pose | ||
|
||
|
||
def shift_hand(pose: Pose, hand_component: str, wrist_name: str): | ||
# pylint: disable=protected-access | ||
wrist_index = pose.header._get_point_index(hand_component, wrist_name) | ||
hand = pose.body.data[:, :, wrist_index: wrist_index + 21] | ||
wrist = hand[:, :, 0:1] | ||
pose.body.data[:, :, wrist_index: wrist_index + 21] = hand - wrist | ||
|
||
|
||
def shift_hands(pose: Pose): | ||
(left_hand_component, right_hand_component), _, (wrist, _) = hands_components(pose.header) | ||
shift_hand(pose, left_hand_component, wrist) | ||
shift_hand(pose, right_hand_component, wrist) | ||
|
||
|
||
def unshift_hand(pose: Pose, hand_component: str): | ||
# pylint: disable=protected-access | ||
wrist_index = pose.header._get_point_index(hand_component, "WRIST") | ||
hand = pose.body.data[:, :, wrist_index: wrist_index + 21] | ||
body_wrist_name = "LEFT_WRIST" if hand_component == "LEFT_HAND_LANDMARKS" else "RIGHT_WRIST" | ||
# pylint: disable=protected-access | ||
body_wrist_index = pose.header._get_point_index("POSE_LANDMARKS", body_wrist_name) | ||
body_wrist = pose.body.data[:, :, body_wrist_index: body_wrist_index + 1] | ||
pose.body.data[:, :, wrist_index: wrist_index + 21] = hand + body_wrist | ||
|
||
|
||
def unshift_hands(pose: Pose): | ||
(left_hand_component, right_hand_component), _, _ = hands_components(pose.header) | ||
unshift_hand(pose, left_hand_component) | ||
unshift_hand(pose, right_hand_component) | ||
|
||
|
||
def pose_like(pose: Pose, pose_header: PoseHeader): | ||
component_names = [component.name for component in pose_header.components] | ||
component_points = {component.name: component.points for component in pose.header.components} | ||
return pose.get_components(component_names, component_points) | ||
|
||
|
||
def pre_process_pose(pose: Pose, pose_header: PoseHeader = None): | ||
if pose_header is not None: | ||
pose = pose_like(pose, pose_header) | ||
|
||
# Align hand wrists with body wrists | ||
correct_wrists(pose) | ||
# Adjust pose based on shoulder positions | ||
pose = pose.normalize(pose_normalization_info(pose.header)) | ||
# Shift hands to origin | ||
shift_hands(pose) | ||
return pose | ||
|
||
|
||
def load_mean_and_std_for_pose(pose: Pose): | ||
mean_pose, std_pose = load_mean_and_std_pose() | ||
mean_pose = pose_like(mean_pose, pose.header) | ||
std_pose = pose_like(std_pose, pose.header) | ||
return (mean_pose.body.data.reshape((-1, 3)), | ||
std_pose.body.data.reshape((-1, 3))) | ||
|
||
|
||
def normalize_mean_std(pose: Pose): | ||
pose = pre_process_pose(pose) | ||
mean, std = load_mean_and_std_for_pose(pose) | ||
pose.body.data = (pose.body.data - mean) / std | ||
return pose | ||
|
||
|
||
def unnormalize_mean_std(pose: Pose): | ||
mean, std = load_mean_and_std() | ||
pose.body.data = (pose.body.data * std) + mean | ||
unshift_hands(pose) | ||
return pose |
Oops, something went wrong.