update talkshow 241112
WC committed Nov 12, 2024
1 parent 0ca5526 commit 764dc0d
Showing 1 changed file with 301 additions and 9 deletions.
310 changes: 301 additions & 9 deletions mmhuman3d/data/data_converters/talkshow.py
@@ -10,6 +10,7 @@
from tqdm import tqdm
import cv2
import glob
import pickle
from mmhuman3d.core.conventions.keypoints_mapping import convert_kps
from mmhuman3d.data.data_structures.human_data import HumanData

@@ -34,6 +35,29 @@ def __init__(self, modes: List = []) -> None:

self.device = torch.device(
'cuda') if torch.cuda.is_available() else torch.device('cpu')

self.misc_config = dict(
bbox_body_scale=1.2,
bbox_facehand_scale=1.0,
bbox_source='keypoints2d_original',
flat_hand_mean=False,
cam_param_type='perspective',
cam_param_source='original',
smplx_source='original',
)

self.smplx_shape = {
'betas': (-1, 10),
'transl': (-1, 3),
'global_orient': (-1, 3),
'body_pose': (-1, 21, 3),
'left_hand_pose': (-1, 15, 3),
'right_hand_pose': (-1, 15, 3),
'leye_pose': (-1, 3),
'reye_pose': (-1, 3),
'jaw_pose': (-1, 3),
'expression': (-1, 10)
}

super(TalkshowConverter, self).__init__(modes)

@@ -227,14 +251,16 @@ def _world2cam(self, world_coord, R, t):

def convert_by_mode(self, dataset_path: str, out_path: str,
mode: str) -> dict:
# deprecated hard-coded paths, kept for reference; paths are now
# derived from dataset_path
# data_root = "/mnt/lustrenew/share_data/zoetrope/data/datasets/talkshow/ExpressiveWholeBodyDatasetReleaseV1.0"
# split = 'train'
# vid_root = '/mnt/lustrenew/share_data/zoetrope/data/datasets/talkshow/raw_videos/'
# img_root = '/mnt/lustre/share_data/weichen1/talkshow_frames'
# root = '/mnt/lustre/share_data/weichen1'

# use HumanData to store all data
human_data = HumanData()

# build smplx model
smplx_model = build_body_model(
@@ -245,14 +271,280 @@ def convert_by_mode(self, dataset_path: str, out_path: str,
model_path='data/body_models/smplx',
gender='neutral',
use_face_contour=True,
flat_hand_mean=self.misc_config['flat_hand_mean'],
use_pca=False,
num_betas=300,
num_expression_coeffs=100)).to(self.device)
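# NOTE: the TalkShow fits use 300 shape and 100 expression coefficients;
# only the first 10 betas and 10 expression coefficients are kept when
# packing HumanData below, matching self.smplx_shape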

# init seed and size
seed, size = '241111', '999'
random.seed(int(seed))
np.set_printoptions(suppress=True)
random_ids = np.random.RandomState(seed=int(seed)).permutation(999999)
used_id_num = 0

# initialize output for human_data
smplx_ = {}
for key in self.smplx_shape.keys():
smplx_[key] = []
keypoints2d_, keypoints3d_ = [], []
bboxs_ = {}
for bbox_name in [
'bbox_xywh', 'face_bbox_xywh', 'lhand_bbox_xywh',
'rhand_bbox_xywh'
]:
bboxs_[bbox_name] = []
meta_ = {}
for meta_name in ['principal_point', 'focal_length', 'height', 'width',
'sequence_name', 'gender', 'RT']:
meta_[meta_name] = []
image_path_ = []

# all 4 speakers
speakers = ['chemistry', 'seth', 'conan', 'oliver']

for speaker_id, speaker_name in enumerate(speakers):

frames_path = os.path.join(dataset_path, 'talkshow_frames', speaker_name)
annots_path = os.path.join(dataset_path, 'talkshow_annots', speaker_name)

video_names = os.listdir(frames_path)

for vid_n in tqdm(video_names, desc=f"Processing {speaker_name}",
position=0, leave=False):

vid_folder = os.path.join(frames_path, vid_n)
# annots_folder = os.path.join(annots_path, vid_n)

sub_names = os.listdir(vid_folder)

for sub_name in tqdm(sub_names, desc=f"Processing {vid_n}",
position=1, leave=False):

# read annot
frame_folder = os.path.join(vid_folder, sub_name)
annot_folder = frame_folder.replace('talkshow_frames', 'talkshow_annots')
annot_path = os.path.join(annot_folder, f'{sub_name}.pkl')

if not os.path.exists(annot_path):
continue
# the pkl stores torch tensors (possibly saved from GPU), so loading
# may require a CUDA-capable device
with open(annot_path, 'rb') as f:
annot = pickle.load(f)

frame_list = glob.glob(frame_folder.replace(' ', '_') + '/*')
frame_list.sort()

if len(frame_list) != annot['batch_size']:
print(f'Skip {vid_n}, {sub_name} due to size mismatch')
continue

# hand poses are stored as PCA coefficients (12 components);
# revert them to axis-angle before further processing
annot = self._revert_smplx_hands_pca(annot, 12)
self._modify_pose(speaker_name, annot)
jaw_pose = torch.Tensor(annot['jaw_pose']).to(self.device)

batch_size = len(jaw_pose)
betas = torch.Tensor(annot['betas']).to(self.device).repeat(batch_size, 1)

leye_pose = torch.Tensor(annot['leye_pose']).to(self.device)
reye_pose = torch.Tensor(annot['reye_pose']).to(self.device)
global_orient = torch.Tensor(annot['global_orient']).squeeze().to(self.device)
body_pose = torch.Tensor(annot['body_pose_axis']).to(self.device)
left_hand_pose = torch.Tensor(annot['left_hand_pose']).to(self.device).reshape(-1, 15, 3)
right_hand_pose = torch.Tensor(annot['right_hand_pose']).to(self.device).reshape(-1, 15, 3)
transl = torch.Tensor(annot['transl']).to(self.device)
expression = torch.Tensor(annot['expression']).to(self.device)  # (B, 100)
focal = annot['focal_length']
princpt = annot['center']
K = np.array([
[focal, 0, princpt[0]],
[0, focal, princpt[1]],
[0, 0, 1]])
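# K above follows the pinhole model: a shared focal length (fx = fy)
# on the diagonal and the annotated image center as principal point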
T = annot['camera_transl']
# flip x/y of the camera translation, presumably converting the fit's
# camera convention to the OpenCV-style convention used below
T[0] *= -1
T[1] *= -1
smplx_res = smplx_model(
betas=betas,
body_pose=body_pose,
global_orient=global_orient,
transl=transl,
left_hand_pose=left_hand_pose,
right_hand_pose=right_hand_pose,
jaw_pose=jaw_pose,
reye_pose=reye_pose,
leye_pose=leye_pose,
expression=expression,
pose2rot=True,
return_full_pose=True)
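# build a 4x4 camera extrinsic from an identity rotation and the
# translation T, then re-express global_orient/transl in the camera
# frame, pivoting on the root joint (smplx_res['joints'][:, 0])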
rotation = torch.eye(3)[None].cpu()
camera_transform = transform_mat(
rotation, torch.Tensor(T).unsqueeze(dim=0).unsqueeze(dim=-1).cpu())

new_smplx_global_orient, new_smplx_transl = batch_transform_to_camera_frame(
annot['global_orient'].squeeze(1),
annot['transl'],
smplx_res['joints'][:, 0, :].cpu().detach().numpy(),
camera_transform[0].cpu().numpy())
smplx_res_new = smplx_model(
betas=betas,
body_pose=body_pose,
global_orient=torch.Tensor(new_smplx_global_orient).to(self.device),
transl=torch.Tensor(new_smplx_transl).to(self.device),
left_hand_pose=left_hand_pose,
right_hand_pose=right_hand_pose,
jaw_pose=jaw_pose,
reye_pose=reye_pose,
leye_pose=leye_pose,
expression=expression,
pose2rot=True,
return_full_pose=True)
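# the second forward pass uses the camera-frame pose, so the returned
# joints live in camera coordinates; below they are converted to
# millimetres and projected to pixels with _cam2pixel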


smplx_keypoints3d = smplx_res_new['joints'].clone()
smplx_keypoints3d_img = self._cam2pixel(
smplx_keypoints3d.cpu().detach() * 1000, focal, torch.Tensor(princpt))
smplx_keypoints2d = smplx_keypoints3d_img[:, :, :2]
smplx_pelvis = get_keypoint_idx('pelvis', 'smplx')
# pelvis-centered 3D keypoints (not used further downstream)
smplx_keypoints3d_root = smplx_keypoints3d - smplx_keypoints3d[:, [smplx_pelvis], :]


for fid, imgp in enumerate(tqdm(frame_list, desc='Frames',
position=2, leave=False)):

# save image path
image_path = imgp.replace(f'{dataset_path}/','')
image_path_.append(image_path)

# save smplx
for k in smplx_.keys():
if k in ['global_orient', 'transl', 'betas']:
continue
if k == 'body_pose':
smplx_[k].append(annot['body_pose_axis'][fid])
continue
if k == 'expression':
# keep the first 10 expression coefficients to match smplx_shape
smplx_[k].append(annot[k][fid][:10])
continue
smplx_[k].append(annot[k][fid])
smplx_['betas'].append(betas[fid, :10].cpu().detach().numpy())
smplx_['global_orient'].append(new_smplx_global_orient[fid])
smplx_['transl'].append(new_smplx_transl[fid])

# save bbox
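# body / face / hand boxes are derived by scaling the extent of the
# projected SMPL-X keypoints (body x1.2, face and hands x1.0,
# matching misc_config), then clipped to the image and converted
# from xyxy to xywh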
bbox_tmp_ = {}
bbox_tmp_['bbox_xywh'], bbox_tmp_['face_bbox_xywh'], bbox_tmp_['lhand_bbox_xywh'], bbox_tmp_['rhand_bbox_xywh'] = \
self._keypoints_to_scaled_bbox_bfh(smplx_keypoints2d[fid], body_scale=1.2, fh_scale=1, convention='smplx')
for bbox_name in ['bbox_xywh', 'face_bbox_xywh', 'lhand_bbox_xywh', 'rhand_bbox_xywh']:
bbox = bbox_tmp_[bbox_name]
xmin, ymin, xmax, ymax = bbox[:4]
if bbox_name == 'bbox_xywh':
bbox_conf = 1
else:
bbox_conf = bbox[-1]
bbox = np.array([max(0, xmin), max(0, ymin), min(annot['width'], xmax), min(annot['height'], ymax)])
bbox_xywh = self._xyxy2xywh(bbox)
bbox_xywh.append(bbox_conf)
bboxs_[bbox_name].append(bbox_xywh)

# save meta
RT = np.eye(4)
RT[:3,3] = T
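# RT: 4x4 extrinsic with identity rotation and the sign-flipped
# camera translation T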
meta_['focal_length'].append(focal)
meta_['principal_point'].append(princpt)
meta_['height'].append(annot['height'])
meta_['width'].append(annot['width'])
meta_['sequence_name'].append(f'{vid_n}/{sub_name}')
meta_['RT'].append(RT)
meta_['gender'].append('neutral')

keypoints2d_.append(smplx_keypoints2d[fid])
keypoints3d_.append(smplx_keypoints3d[fid].cpu().detach().numpy())

# get size (number of videos of the last processed speaker,
# used only in the output file name)
size_i = len(video_names)

# save keypoints 2d smplx
keypoints2d = np.concatenate(keypoints2d_, axis=0).reshape(-1, 144, 2)
keypoints2d_conf = np.ones([keypoints2d.shape[0], 144, 1])
keypoints2d = np.concatenate([keypoints2d, keypoints2d_conf], axis=-1)
keypoints2d, keypoints2d_mask = convert_kps(
keypoints2d, src='smplx', dst='human_data')
human_data['keypoints2d_smplx'] = keypoints2d
human_data['keypoints2d_smplx_mask'] = keypoints2d_mask
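# the *_mask arrays record which keypoints of the human_data
# convention are actually provided by the SMPL-X source convention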

# save keypoints 3d smplx
keypoints3d = np.concatenate(keypoints3d_, axis=0).reshape(-1, 144, 3)
keypoints3d_conf = np.ones([keypoints3d.shape[0], 144, 1])
keypoints3d = np.concatenate([keypoints3d, keypoints3d_conf], axis=-1)
keypoints3d, keypoints3d_mask = convert_kps(
keypoints3d, src='smplx', dst='human_data')
human_data['keypoints3d_smplx'] = keypoints3d
human_data['keypoints3d_smplx_mask'] = keypoints3d_mask

# save bbox
for bbox_name in [
'bbox_xywh', 'face_bbox_xywh', 'lhand_bbox_xywh',
'rhand_bbox_xywh'
]:
bbox_xywh_ = np.array(bboxs_[bbox_name]).reshape((-1, 5))
human_data[bbox_name] = bbox_xywh_

# save smplx
for key in smplx_.keys():
smplx_[key] = np.concatenate(
smplx_[key], axis=0).reshape(self.smplx_shape[key])

human_data['smplx'] = smplx_

# save image path
human_data['image_path'] = image_path_


# save meta and misc
human_data['config'] = f'talkshow_{mode}'
human_data['misc'] = self.misc_config
human_data['meta'] = meta_

os.makedirs(out_path, exist_ok=True)
out_file = os.path.join(
out_path, f'talkshow_{mode}_{seed}_{size_i:03d}.npz')
human_data.dump(out_file)
return

# init
image_path_, bbox_xywh_, keypoints2d_smplx_, keypoints3d_smplx_, \
keypoints2d_smpl_, keypoints3d_smpl_, keypoints2d_ori_ \
