Commit

regular update
Wei-Chen-hub committed Oct 17, 2024
1 parent c64c1ef commit bdb973d
Showing 4 changed files with 235 additions and 35 deletions.
116 changes: 104 additions & 12 deletions mmhuman3d/data/data_converters/mpii_neural_annot.py
@@ -23,6 +23,8 @@
from .base_converter import BaseModeConverter
from .builder import DATA_CONVERTERS

from pycocotools.coco import COCO

import pdb


@@ -60,6 +62,68 @@ def __init__(self, modes: List = []) -> None:
super(MpiiNeuralConverter, self).__init__(modes)


def _keypoints_to_scaled_bbox_bfh(self,
keypoints,
occ=None,
body_scale=1.0,
fh_scale=1.0,
convention='smplx'):
        '''Obtain scaled bboxes in xyxy format given keypoints.

        Args:
            keypoints (np.ndarray): Keypoints of shape (1, n, k) or (n, k), k = 2 or 3.
            occ (np.ndarray, optional): Per-keypoint occlusion flags; a face/hand
                part that is at least 10% occluded gets conf = 0.
            body_scale (float): Scale factor for the body bbox.
            fh_scale (float): Scale factor for the face/hand bboxes.
            convention (str): Keypoint convention used to look up parts.

        Returns:
            bboxs (list): One (5,) np.ndarray per part, ordered body, head,
                left_hand, right_hand, holding [xmin, ymin, xmax, ymax, conf].
        '''
bboxs = []

# supported kps.shape: (1, n, k) or (n, k), k = 2 or 3
if keypoints.ndim == 3:
keypoints = keypoints[0]
if keypoints.shape[-1] != 2:
keypoints = keypoints[:, :2]

for body_part in ['body', 'head', 'left_hand', 'right_hand']:
if body_part == 'body':
scale = body_scale
kps = keypoints
else:
scale = fh_scale
kp_id = get_keypoint_idxs_by_part(
body_part, convention=convention)
kps = keypoints[kp_id]

if occ is not None:
occ_p = occ[kp_id]
if np.sum(occ_p) / len(kp_id) >= 0.1:
conf = 0
else:
conf = 1
else:
conf = 1
if body_part == 'body':
conf = 1

xmin, ymin = np.amin(kps, axis=0)
xmax, ymax = np.amax(kps, axis=0)

width = (xmax - xmin) * scale
height = (ymax - ymin) * scale

x_center = 0.5 * (xmax + xmin)
y_center = 0.5 * (ymax + ymin)
xmin = x_center - 0.5 * width
xmax = x_center + 0.5 * width
ymin = y_center - 0.5 * height
ymax = y_center + 0.5 * height

bbox = np.stack([xmin, ymin, xmax, ymax, conf],
axis=0).astype(np.float32)
bboxs.append(bbox)

return bboxs
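    # Editorial sketch, not part of the commit: example usage of the helper
    # above. `converter` and the 144-keypoint 'smplx' convention are
    # illustrative assumptions.
    #
    #   import numpy as np
    #   keypoints_2d = np.random.rand(144, 2) * 512  # dummy 2d keypoints
    #   bboxs = converter._keypoints_to_scaled_bbox_bfh(
    #       keypoints_2d, body_scale=1.2, fh_scale=1.0, convention='smplx')
    #   # order: body, head, left_hand, right_hand; each entry is a (5,)
    #   # float32 array of [xmin, ymin, xmax, ymax, conf]
    #   body_bbox, head_bbox, lhand_bbox, rhand_bbox = bboxs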


def convert_by_mode(self,
dataset_path: str,
out_path: str,
@@ -86,28 +150,33 @@ def convert_by_mode(self,
        keypoints2d_smplx_, keypoints3d_smplx_ = [], []
        keypoints2d_orig_ = []
bboxs_ = {}
-        for bbox_name in ['bbox_xywh']:
for bbox_name in [
'bbox_xywh', 'face_bbox_xywh', 'lhand_bbox_xywh',
'rhand_bbox_xywh'
]:
bboxs_[bbox_name] = []
meta_ = {}
for key in ['focal_length', 'principal_point', 'height', 'width']:
meta_[key] = []
image_path_ = []

        # load split data
-        split_path = os.path.join(dataset_path, 'annotations', f'{mode}_reformat.json')
-        with open(split_path, 'r') as f:
-            image_data = json.load(f)
split_path = os.path.join(dataset_path, 'annotations', f'{mode}.json')
db = COCO(split_path)

# with open(split_path, 'r') as f:
# image_data = json.load(f)

# load smplx annot
smplx_path = os.path.join(dataset_path, 'annotations', f'MPII_train_SMPLX_NeuralAnnot.json')
with open(smplx_path, 'r') as f:
smplx_data = json.load(f)

# get targeted frame list
-        image_list = list(image_data.keys())
image_list = list(db.anns.keys())
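        # Editorial sketch, not part of the commit: the pycocotools access
        # pattern relied on here. COCO(path) indexes a COCO-format annotation
        # file; db.anns maps annotation id -> annotation dict, and
        # db.loadImgs(image_id) returns the matching image record(s).
        #
        #   from pycocotools.coco import COCO
        #   db = COCO('annotations/train.json')  # illustrative path
        #   for aid in list(db.anns.keys())[:3]:
        #       ann = db.anns[aid]
        #       img = db.loadImgs(ann['image_id'])[0]
        #       print(img['file_name'], img['width'], img['height'])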

# init seed and size
-        seed, size = '230814', '90999'
seed, size = '231016', '90999'
size_i = min(int(size), len(image_list))
random.seed(int(seed))
image_list = image_list[:size_i]
@@ -127,19 +196,25 @@ def convert_by_mode(self,
use_pca=False,
batch_size=1)).to(self.device)

-        for fname in tqdm(image_list, desc=f'Converting MPII {mode} data'):
for aid in tqdm(db.anns.keys(), desc=f'Converting MPII {mode} data'):
# for aid in db.anns.keys():

# get info slice
-            image_info = image_data[fname]
ann = db.anns[aid]
image_info = ann
img = db.loadImgs(ann['image_id'])[0]
# pdb.set_trace()

# prepare image path
-            image_path = os.path.join('images', f'{fname}')
image_path = img['file_name']
imgp = os.path.join(dataset_path, image_path)
if not os.path.exists(imgp):
continue

# access image info
annot_id = image_info['id']
-            width = image_info['width']
-            height = image_info['height']
width = img['width']
height = img['height']

# read keypoints2d and bbox
j2d = np.array(image_info['keypoints']).reshape(-1, 3)
@@ -210,7 +285,24 @@ def convert_by_mode(self,
keypoints2d_orig_.append(j2d)

# append bbox
-            bboxs_['bbox_xywh'].append(bbox_xywh)
bboxs = self._keypoints_to_scaled_bbox_bfh(
keypoints_2d,
body_scale=1.2,
fh_scale=1.0)
for i, bbox_name in enumerate([
'bbox_xywh', 'face_bbox_xywh', 'lhand_bbox_xywh',
'rhand_bbox_xywh'
]):
xmin, ymin, xmax, ymax, conf = bboxs[i]
bbox = np.array([
max(0, xmin),
max(0, ymin),
min(width, xmax),
min(height, ymax)
])
bbox_xywh = self._xyxy2xywh(bbox) # list of len 4
bbox_xywh.append(conf) # (5,)
bboxs_[bbox_name].append(bbox_xywh)
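            # Editorial sketch, not part of the commit: _xyxy2xywh comes from
            # the base converter; a minimal equivalent, assuming the
            # list-returning behaviour relied on above, would be
            #
            #   def _xyxy2xywh(bbox_xyxy):
            #       xmin, ymin, xmax, ymax = bbox_xyxy
            #       return [xmin, ymin, xmax - xmin, ymax - ymin]
            #
            # which is why appending conf yields [x, y, w, h, conf].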

# append smpl
for key in smplx_param.keys():
10 changes: 7 additions & 3 deletions mmhuman3d/data/data_converters/mscoco_neural_annot.py
@@ -113,7 +113,8 @@ def convert_by_mode(self,
bboxs_[bbox_name] = []
meta_ = {}
for key in ['focal_length', 'principal_point', 'height', 'width',
-                    'lefthand_valid', 'righthand_valid', 'face_valid']:
'lefthand_valid', 'righthand_valid', 'face_valid',
'iscrowd', 'num_keypoints']:
meta_[key] = []
image_path_ = []

@@ -265,8 +266,11 @@ def convert_by_mode(self,
# append meta
meta_['principal_point'].append(principal_point)
meta_['focal_length'].append(focal_length)
-            meta_['height'].append(height)
-            meta_['width'].append(width)

for key in ['height', 'width',
'lefthand_valid', 'righthand_valid', 'face_valid',
'iscrowd', 'num_keypoints']:
meta_[key].append(info_anno[key])

# extra smplx params
smplx_extra_['left_hand_valid'].append(lefthand_valid)
70 changes: 53 additions & 17 deletions mmhuman3d/data/data_converters/pw3d_bedlam.py
@@ -50,7 +50,8 @@ def __init__(self, modes: List = []):
bbox_source='keypoints2d_smplx',
smpl_source='original',
cam_param_type='prespective',
-                bbox_scale=1.2,
bbox_body_scale=1.2,
bbox_facehand_scale=1.0,
kps3d_root_aligned=False,
flat_hand_mean=False,
has_gender=True,
@@ -133,7 +134,7 @@ def convert_by_mode(self,
        keypoints2d_smplx_, keypoints3d_smplx_ = [], []
keypoints2d_orig_ = []
bboxs_ = {}
-        for bbox_name in ['bbox_xywh']:
for bbox_name in ['bbox_xywh', 'face_bbox_xywh', 'lhand_bbox_xywh', 'rhand_bbox_xywh']:
bboxs_[bbox_name] = []
meta_ = {}
for key in ['focal_length', 'principal_point', 'height', 'width']:
@@ -169,19 +170,21 @@ def convert_by_mode(self,
random.seed(int(seed))
targeted_frame_ids = targeted_frame_ids[:size_i]

-        # init smplx model
-        smplx_model = build_body_model(
-            dict(
-                type='SMPLX',
-                keypoint_src='smplx',
-                keypoint_dst='smplx',
-                model_path='data/body_models/smplx',
-                gender='neutral',
-                num_betas=10,
-                use_face_contour=True,
-                flat_hand_mean=False,
-                use_pca=False,
-                batch_size=1)).to(self.device)
# init gendered smplx model
smplx_model_dict = {}
for gender in ['male', 'female', 'neutral']:
smplx_model_dict[gender] = build_body_model(
dict(
type='SMPLX',
keypoint_src='smplx',
keypoint_dst='smplx',
model_path='data/body_models/smplx',
gender=gender,
num_betas=10,
use_face_contour=True,
flat_hand_mean=False,
use_pca=False,
batch_size=1)).to(self.device)
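        # Editorial sketch, not part of the commit: with the dictionary built
        # above, the model for a normalized gender label is a plain lookup
        # (names below are illustrative):
        #
        #   gender = 'male'  # normalized from the raw 'm'/'f'/'n' labels
        #   smplx_model = smplx_model_dict.get(gender,
        #                                      smplx_model_dict['neutral'])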

print('Converting...')
for sid in tqdm(targeted_frame_ids):
@@ -214,6 +217,8 @@ def convert_by_mode(self,

bedlam_pose = annot_param['smplx_pose'][aid].reshape(-1, 3)[1:22, :]

                # pdb.set_trace()

                # calculate error
loss = torch.abs(torch.tensor(aa_to_rotmat(bedlam_pose)) -
torch.tensor(aa_to_rotmat(neural_pose)))
@@ -228,6 +233,14 @@ def convert_by_mode(self,

# select one with lowest loss
aid = aids[np.argmin(losses)]
gender = annot_param['gender'][aid]
if gender == 'm':
gender = 'male'
elif gender == 'f':
gender = 'female'
elif gender == 'n':
gender = 'neutral'

bedlam_betas = annot_param['smplx_shape'][aid][:10].reshape(1, 10)
bedlam_global_orient = annot_param['smplx_pose'][aid].reshape(-1, 3)[0:1, :]
bedlam_pose = annot_param['smplx_pose'][aid].reshape(-1, 3)[1:22, :]
@@ -268,6 +281,7 @@ def convert_by_mode(self,
np.array(smplx_param[key]).reshape(self.smplx_shape[key]),
device=self.device, dtype=torch.float32)
for key in intersect_keys}
smplx_model = smplx_model_dict['neutral']
output = smplx_model(**body_model_param_tensor, return_joints=True)

# get kps2d and 3d
@@ -304,8 +318,30 @@ def convert_by_mode(self,
keypoints3d_smplx_.append(keypoints_3d)
keypoints2d_orig_.append(j2d)

-            # append bbox
-            bboxs_['bbox_xywh'].append(bbox_xywh)
# # append bbox
# bboxs_['bbox_xywh'].append(bbox_xywh)

# get bbox from 2d keypoints
bboxs = self._keypoints_to_scaled_bbox_bfh(
keypoints_2d,
body_scale=self.misc_config['bbox_body_scale'],
fh_scale=self.misc_config['bbox_facehand_scale'])
            # convert xyxy to xywh
for i, bbox_name in enumerate([
'bbox_xywh', 'face_bbox_xywh',
'lhand_bbox_xywh', 'rhand_bbox_xywh'
]):
xmin, ymin, xmax, ymax, conf = bboxs[i]
bbox = np.array([
max(0, xmin),
max(0, ymin),
min(width, xmax),
min(height, ymax)
])
bbox_xywh = self._xyxy2xywh(bbox) # list of len 4
bbox_xywh.append(conf) # (5,)
bboxs_[bbox_name].append(bbox_xywh)
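            # Editorial note, not part of the commit: the max/min clamping
            # above keeps the scaled boxes inside the image bounds before the
            # xyxy -> xywh conversion, mirroring the MPII converter earlier in
            # this commit.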


# append smpl
for key in smplx_param.keys():
