Skip to content

Commit

Permalink
feat(normalization): add mean pose calculation, recalculate for impro…
Browse files Browse the repository at this point in the history
…ved face model
  • Loading branch information
AmitMY committed Nov 6, 2024
1 parent 58aa25c commit fd10f15
Show file tree
Hide file tree
Showing 7 changed files with 3,966 additions and 3,248 deletions.
11 changes: 10 additions & 1 deletion pose_anonymization/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from pathlib import Path

import numpy as np
from pose_format import PoseHeader
from pose_format import PoseHeader, Pose
from pose_format.numpy import NumPyPoseBody
from pose_format.utils.reader import BufferReader

CURRENT_DIR = Path(__file__).parent
Expand Down Expand Up @@ -31,3 +32,11 @@ def load_mean_and_std():
std[std == 0] = 1

return np.array(mean), std

@lru_cache(maxsize=1)
def load_mean_and_std_pose():
pose_header = load_pose_header()
mean, std = load_mean_and_std()
mean_body = NumPyPoseBody(fps=1, data=mean.reshape((1, 1, -1, 3)), confidence=np.ones((1, 1, len(mean))))
std_body = NumPyPoseBody(fps=1, data=std.reshape((1, 1, -1, 3)), confidence=np.ones((1, 1, len(std))))
return Pose(header=pose_header, body=mean_body), Pose(header=pose_header, body=std_body)
90 changes: 90 additions & 0 deletions pose_anonymization/data/calc_mean_std.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from functools import partial
from pathlib import Path

import numpy as np
from pose_format import Pose, PoseHeader
from tqdm.contrib.concurrent import process_map

from pose_anonymization.data.normalization import pre_process_pose

CURRENT_DIR = Path(__file__).parent


def process_file(file, pose_header: PoseHeader):
with open(file, 'rb') as pose_file:
pose = Pose.read(pose_file.read())
pose = pre_process_pose(pose, pose_header=pose_header)
tensor = pose.body.data.filled(0)

frames_sum = np.sum(tensor, axis=(0, 1))
frames_squared_sum = np.sum(np.square(tensor), axis=(0, 1))
unmasked_frames = pose.body.data[:, :, :, 0:1].mask == False
num_unmasked_frames = np.sum(unmasked_frames, axis=(0, 1))

return frames_sum, frames_squared_sum, num_unmasked_frames


def calc_mean_and_std(files, pose_header: PoseHeader):
cumulative_sum, squared_sum, frames_count = None, None, None

process_func = partial(process_file, pose_header=pose_header)
results = process_map(process_func, files, max_workers=None, chunksize=1)

for frames_sum, frames_squared_sum, num_unmasked_frames in results:
cumulative_sum = frames_sum if cumulative_sum is None else cumulative_sum + frames_sum
squared_sum = frames_squared_sum if squared_sum is None else squared_sum + frames_squared_sum
frames_count = num_unmasked_frames if frames_count is None else frames_count + num_unmasked_frames

mean = cumulative_sum / frames_count
std = np.sqrt((squared_sum / frames_count) - np.square(mean))

return mean, std


def main(poses_location: str):
print("Listing files...")
files = list(Path(poses_location).glob("*.pose"))
print(f"Processing {len(files)} files")

# get a single random pose
with open(files[0], 'rb') as pose_file:
pose = Pose.read(pose_file.read())
pose = pre_process_pose(pose)

mean, std = calc_mean_and_std(files, pose.header)

# store header
with open(CURRENT_DIR / "header.poseheader", "wb") as f:
pose.header.write(f)

i = 0
mean_std_info = {}
for component in pose.header.components:
component_info = {}
for point in component.points:
component_info[point] = {
"mean": mean[i].tolist(),
"std": std[i].tolist()
}
i += 1
mean_std_info[component.name] = component_info

import json

with open(CURRENT_DIR / "pose_normalization.json", "w", encoding="utf-8") as f:
json.dump(mean_std_info, f, indent=2)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description='Collect normalization info')
parser.add_argument('--dir', type=str, help='Directory containing the pose files',
default="/Volumes/Echo/GCS/sign-mt-poses")

args = parser.parse_args()

if not Path(args.dir).exists():
raise FileNotFoundError(f"Directory {args.dir} does not exist")

main(args.dir)
25 changes: 25 additions & 0 deletions pose_anonymization/data/draw_mean_pose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pathlib import Path

import cv2
from pose_format.pose_visualizer import PoseVisualizer
from pose_format.utils.generic import reduce_holistic

from pose_anonymization.appearance import normalize_pose_size
from pose_anonymization.data import load_mean_and_std_pose
from pose_anonymization.data.normalization import unshift_hands

if __name__ == "__main__":
mean_pose, _ = load_mean_and_std_pose()
unshift_hands(mean_pose)
normalize_pose_size(mean_pose)

poses = {
"full": mean_pose,
"reduced": reduce_holistic(mean_pose)
}
for name, pose in poses.items():
pose.focus()

v = PoseVisualizer(pose)
image_path = Path(__file__).parent / f"mean_pose_{name}.png"
cv2.imwrite(str(image_path), next(v.draw()))
Binary file modified pose_anonymization/data/header.poseheader
Binary file not shown.
76 changes: 76 additions & 0 deletions pose_anonymization/data/normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from pose_format import Pose, PoseHeader
from pose_format.utils.generic import pose_normalization_info, correct_wrists, hands_components

from pose_anonymization.data import load_mean_and_std, load_mean_and_std_pose


def shift_hand(pose: Pose, hand_component: str, wrist_name: str):
# pylint: disable=protected-access
wrist_index = pose.header._get_point_index(hand_component, wrist_name)
hand = pose.body.data[:, :, wrist_index: wrist_index + 21]
wrist = hand[:, :, 0:1]
pose.body.data[:, :, wrist_index: wrist_index + 21] = hand - wrist


def shift_hands(pose: Pose):
(left_hand_component, right_hand_component), _, (wrist, _) = hands_components(pose.header)
shift_hand(pose, left_hand_component, wrist)
shift_hand(pose, right_hand_component, wrist)


def unshift_hand(pose: Pose, hand_component: str):
# pylint: disable=protected-access
wrist_index = pose.header._get_point_index(hand_component, "WRIST")
hand = pose.body.data[:, :, wrist_index: wrist_index + 21]
body_wrist_name = "LEFT_WRIST" if hand_component == "LEFT_HAND_LANDMARKS" else "RIGHT_WRIST"
# pylint: disable=protected-access
body_wrist_index = pose.header._get_point_index("POSE_LANDMARKS", body_wrist_name)
body_wrist = pose.body.data[:, :, body_wrist_index: body_wrist_index + 1]
pose.body.data[:, :, wrist_index: wrist_index + 21] = hand + body_wrist


def unshift_hands(pose: Pose):
(left_hand_component, right_hand_component), _, _ = hands_components(pose.header)
unshift_hand(pose, left_hand_component)
unshift_hand(pose, right_hand_component)


def pose_like(pose: Pose, pose_header: PoseHeader):
component_names = [component.name for component in pose_header.components]
component_points = {component.name: component.points for component in pose.header.components}
return pose.get_components(component_names, component_points)


def pre_process_pose(pose: Pose, pose_header: PoseHeader = None):
if pose_header is not None:
pose = pose_like(pose, pose_header)

# Align hand wrists with body wrists
correct_wrists(pose)
# Adjust pose based on shoulder positions
pose = pose.normalize(pose_normalization_info(pose.header))
# Shift hands to origin
shift_hands(pose)
return pose


def load_mean_and_std_for_pose(pose: Pose):
mean_pose, std_pose = load_mean_and_std_pose()
mean_pose = pose_like(mean_pose, pose.header)
std_pose = pose_like(std_pose, pose.header)
return (mean_pose.body.data.reshape((-1, 3)),
std_pose.body.data.reshape((-1, 3)))


def normalize_mean_std(pose: Pose):
pose = pre_process_pose(pose)
mean, std = load_mean_and_std_for_pose(pose)
pose.body.data = (pose.body.data - mean) / std
return pose


def unnormalize_mean_std(pose: Pose):
mean, std = load_mean_and_std()
pose.body.data = (pose.body.data * std) + mean
unshift_hands(pose)
return pose
Loading

0 comments on commit fd10f15

Please sign in to comment.