diff --git a/configs/pymafx/pymafx.py b/configs/pymafx/pymafx.py index ef2d5034..f684c52b 100644 --- a/configs/pymafx/pymafx.py +++ b/configs/pymafx/pymafx.py @@ -1,4 +1,8 @@ -_base_ = ['../_base_/default_runtime.py'] +mmhuman3d_data_path = 'data' +mmhuman3d_config_path = 'configs' + +_base_ = [f'../../{mmhuman3d_config_path}/_base_/default_runtime.py'] + maf_on = True __bhf_mode__ = 'full_body' # full_body or body_hand __grid_align__ = dict( @@ -61,7 +65,7 @@ __mesh_model__ = dict( name='smplx', - smpl_mean_params='data/body_models/smpl_mean_params.npz', + smpl_mean_params=f'{mmhuman3d_data_path}/body_models/smpl_mean_params.npz', gender='neutral') model = dict( @@ -78,6 +82,7 @@ bhf_mode=__bhf_mode__, grid_feat=False, grid_align=__grid_align__, + mmhuman3d_data_path=mmhuman3d_data_path, ), regressor=dict( type='Regressor', @@ -85,13 +90,15 @@ bhf_mode=__bhf_mode__, use_iwp_cam=True, n_iter=3, - smpl_model_dir='data/body_models/smpl', + smpl_model_dir=f'{mmhuman3d_data_path}/body_models/smpl', smpl_mean_params=__mesh_model__['smpl_mean_params'], ), - attention_config='configs/pymafx/bert_base_uncased_config.py', - extra_joints_regressor='data/body_models/J_regressor_extra.npy', - smplx_to_smpl='data/body_models/smplx/smplx_to_smpl.npz', - smplx_model_dir='data/body_models/smplx', + mmhuman3d_data_path=mmhuman3d_data_path, + attention_config=f'{mmhuman3d_config_path}/pymafx/bert_base_uncased_config.py', + extra_joints_regressor=f'{mmhuman3d_data_path}/body_models/J_regressor_extra.npy', + smplx_to_smpl=f'{mmhuman3d_data_path}/body_models/smplx/smplx_to_smpl.npz', + smplx_model_dir=f'{mmhuman3d_data_path}/body_models/smplx', + partial_mesh_path=f'{mmhuman3d_data_path}/partial_mesh', mesh_model=__mesh_model__, bhf_mode=__bhf_mode__, maf_on=maf_on, diff --git a/mmhuman3d/models/architectures/pymafx.py b/mmhuman3d/models/architectures/pymafx.py index 16db3463..687c8543 100644 --- a/mmhuman3d/models/architectures/pymafx.py +++ b/mmhuman3d/models/architectures/pymafx.py @@ -90,6 +90,8 @@ def __init__(self, extra_joints_regressor: str, smplx_to_smpl: str, smplx_model_dir: str, + partial_mesh_path: str, + mmhuman3d_data_path: str, mesh_model: dict, bhf_mode: str, maf_on: bool, @@ -112,6 +114,8 @@ def __init__(self, self.use_iwp_cam = use_iwp_cam self.backbone = backbone self.smplx_model_dir = smplx_model_dir + self.partial_mesh_path = partial_mesh_path + self.mmhuman3d_data_path = mmhuman3d_data_path self.smplx_to_smpl = smplx_to_smpl self.hf_model_cfg = hf_model_cfg self.mesh_model = mesh_model @@ -183,7 +187,7 @@ def _create_encoder(self): if 'body' in self.bhf_names and self.backbone is not None: self.encoders['body'] = build_backbone(self.backbone) - self.mesh_sampler = Mesh_Sampler(type='smpl') + self.mesh_sampler = Mesh_Sampler(type='smpl', data_path=self.mmhuman3d_data_path) if not self.grid_feat: self.ma_feat_dim = self.mesh_sampler.Dmap.shape[ 0] * self.mlp_dim[-1] @@ -225,7 +229,8 @@ def _create_attention_modules(self, attention_config): hidden_feat_dim = self.mlp_dim[self.att_feat_dim_idx] self.bhf_att_feat_dim.update({'body': 2048}) if 'hand' in self.bhf_names: - self.mano_sampler = Mesh_Sampler(type='mano', level=1) + self.mano_sampler = Mesh_Sampler(type='mano', level=1, + data_path=self.mmhuman3d_data_path) self.mano_ds_len = self.mano_sampler.Dmap.shape[0] self.bhf_ma_feat_dim.update( @@ -353,12 +358,14 @@ def _create_maf_extractor(self): MAF_Extractor( filter_channels=filter_channels_default[ self.att_feat_dim_idx:], - iwp_cam_mode=self.use_iwp_cam)) + iwp_cam_mode=self.use_iwp_cam, + data_path=self.mmhuman3d_data_path)) else: self.maf_extractor[part].append( MAF_Extractor( filter_channels=filter_channels, - iwp_cam_mode=self.use_iwp_cam)) + iwp_cam_mode=self.use_iwp_cam, + data_path=self.mmhuman3d_data_path)) def forward_train(self, **kwargs): """Forward function for general training. diff --git a/mmhuman3d/models/body_models/smplx.py b/mmhuman3d/models/body_models/smplx.py index f018db01..a3e8c097 100644 --- a/mmhuman3d/models/body_models/smplx.py +++ b/mmhuman3d/models/body_models/smplx.py @@ -396,7 +396,7 @@ def __init__(self, keypoint_approximate: bool = True, extra_joints_regressor: str = None, smplx_to_smpl: str = None, - partial_mesh_path: str = 'data/partial_mesh/', + partial_mesh_path: str = None, batch_size: int = 1, use_face_contour: bool = True, **kwargs) -> None: @@ -472,6 +472,7 @@ def __init__(self, torch.tensor( smplx_to_smpl['matrix'][None], dtype=torch.float32)) if partial_mesh_path is not None: + import pdb; pdb.set_trace() smpl2limb_vert_faces = get_partial_smpl(partial_mesh_path) self.smpl2lhand = torch.from_numpy( smpl2limb_vert_faces['lhand']['vids']).long() diff --git a/mmhuman3d/models/heads/pymafx_head.py b/mmhuman3d/models/heads/pymafx_head.py index d925b233..2b4bc1db 100644 --- a/mmhuman3d/models/heads/pymafx_head.py +++ b/mmhuman3d/models/heads/pymafx_head.py @@ -53,6 +53,7 @@ class Mesh_Sampler(nn.Module): def __init__(self, type='smpl', + data_path='data', level=2, device=torch.device('cuda'), option=None): @@ -61,7 +62,7 @@ def __init__(self, # downsample SMPL mesh and assign part labels if type == 'smpl': smpl_mesh_graph = np.load( - 'data/smpl_downsampling.npz', + f'{data_path}/smpl_downsampling.npz', allow_pickle=True, encoding='latin1') @@ -70,7 +71,7 @@ def __init__(self, elif type == 'mano': # TODO: replace path mano_mesh_graph = np.load( - 'data/mano_downsampling.npz', + f'{data_path}/mano_downsampling.npz', allow_pickle=True, encoding='latin1') @@ -143,6 +144,7 @@ class MAF_Extractor(nn.Module): def __init__(self, filter_channels, + data_path='data', device=torch.device('cuda'), iwp_cam_mode=True, option=None): @@ -169,7 +171,7 @@ def __init__(self, # downsample SMPL mesh and assign part labels # https://github.com/nkolot/GraphCMR/blob/master/data/mesh_downsampling.npz smpl_mesh_graph = np.load( - 'data/smpl_downsampling.npz', allow_pickle=True, encoding='latin1') + f'{data_path}/smpl_downsampling.npz', allow_pickle=True, encoding='latin1') U = smpl_mesh_graph['U'] D = smpl_mesh_graph['D'] # shape: (2,) @@ -1451,6 +1453,7 @@ def __init__(self, bhf_names, hf_root_idx, mano_ds_len, + mmhuman3d_data_path=None, grid_feat=False, hf_box_center=True, use_iwp_cam=True, @@ -1466,12 +1469,14 @@ def __init__(self, self.grid_align = grid_align self.bhf_names = bhf_names self.opt_wrist = True - self.mano_sampler = Mesh_Sampler(type='mano', level=1) - self.mesh_sampler = Mesh_Sampler(type='smpl') + self.mano_sampler = Mesh_Sampler(type='mano', level=1, + data_path=mmhuman3d_data_path) + self.mesh_sampler = Mesh_Sampler(type='smpl', + data_path=mmhuman3d_data_path) self.init_mesh_output = None self.batch_size = 1 self.n_iter = n_iter - smpl2limb_vert_faces = get_partial_smpl() + smpl2limb_vert_faces = get_partial_smpl(f'{mmhuman3d_data_path}/partial_mesh') self.smpl2lhand = torch.from_numpy( smpl2limb_vert_faces['lhand']['vids']).long() self.smpl2rhand = torch.from_numpy(