Commit: add version 2
LizhenWangT committed May 13, 2022
1 parent 7a958a3 commit f0530bc
Showing 6 changed files with 86 additions and 23 deletions.
25 changes: 21 additions & 4 deletions README.md
@@ -35,6 +35,23 @@ Please download the zip file of **version 0** or **version 1** (recommended) and

**Fig.3** Single-image reconstruction results of **version 1** (base model, detail model and expression refined final model).

**FaceVerse version 2** [[download]](https://drive.google.com/file/d/1_ooP9hvR7kUUO8WhtXRU_D4nM5fr8BT_/view?usp=sharing):

- Fits the expression components to the 52 blendshapes defined by Apple's ARKit. See 'exp_name_list' in faceverse_simple_v2.npy for the mapping (a loading sketch follows this list).

- Provides a simplification option (normal model: 28632 vertices; simplified model: 6335 vertices). You can use the selected points of FaceVerse v2 with:

```
python tracking_online.py --version 2 --use_simplification
python tracking_offline.py --input example/videos/test.mp4 --res_folder example/video_results --version 2 --use_simplification
```

- Refines the shape of the base PCA model via orthogonalization.
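To inspect the blendshape mapping, you can load the model file and print the names. A minimal sketch, assuming only that faceverse_simple_v2.npy sits in data/ and is loaded the same way model/__init__.py loads it:

```
import numpy as np

# load the FaceVerse v2 model dict (same pattern as model/__init__.py)
faceverse_dict = np.load('data/faceverse_simple_v2.npy', allow_pickle=True).item()

# one name per expression coefficient, in the order of the 52 blendshapes
for i, name in enumerate(faceverse_dict['exp_name_list']):
    print(i, name)
```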

![v2](./docs/tracking_v2.gif)

**Fig.4** Real-time online tracking results (30 fps) of **version 2**. The real-time version is accelerated by point-based rendering in CUDA; this accelerated version has not been released.

## Requirements

- Python 3.9
@@ -73,17 +90,17 @@ Note: the detailed refinement is based on differentiable rendering, which is qui

![offline_tracking](./docs/offline_tracking.gif)

-Offline tracking input with a video (our code will crop the face region using the first frame):
+Offline tracking with a video input (our code will crop the face region using the first frame; --use_simplification can only be used with version >= 2):

```
-python tracking_offline.py --input example/videos/test.mp4 --res_folder example/video_results
+python tracking_offline.py --input example/videos/test.mp4 --res_folder example/video_results --version 2
```


-Online tracking using your PC camera (our code will crop the face region using the first frame):
+Online tracking using your PC camera (our code will crop the face region using the first frame; --use_simplification can only be used with version >= 2):

```
-python tracking_online.py
+python tracking_online.py --version 2
```

![online_tracking](./docs/online_tracking.gif)
8 changes: 7 additions & 1 deletion data/README.md
@@ -1,9 +1,15 @@
### FaceVerse version 0
id_dim 120, exp_dim 64, tex_dim 200
-base uv in 256 x 256 (not symmetrical)
+base uv in 200 x 200 (not symmetrical)

### FaceVerse version 1
id_dim 150, exp_dim 74, tex_dim 251
base uv in 199 x 199 (symmetrical)

### FaceVerse version 2
id_dim 150, exp_dim 52, tex_dim 251
base uv in 199 x 199 (symmetrical)
normal version: 28632 vertices
simplified version: 6335 vertices
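The dimensions above can be checked directly from the npy files. A minimal sketch, assuming the files sit in data/ and use the key names from model/FaceVerseModel.py:

```
import numpy as np

faceverse_dict = np.load('data/faceverse_simple_v2.npy', allow_pickle=True).item()

# id/exp/tex dims are the column counts of the PCA bases
print('id_dim: ', faceverse_dict['idBase'].shape[1])    # expect 150
print('exp_dim:', faceverse_dict['exBase'].shape[1])    # expect 52
print('tex_dim:', faceverse_dict['texBase'].shape[1])   # expect 251

# meanshape stores flattened (x, y, z) coordinates, three per vertex
print('vertices:', faceverse_dict['meanshape'].reshape(-1, 3).shape[0])  # expect 28632 (normal resolution)
```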


46 changes: 32 additions & 14 deletions model/FaceVerseModel.py
@@ -8,7 +8,7 @@

 class FaceVerseModel(nn.Module):
     def __init__(self, model_dict, batch_size=1,
-                 focal=1315, img_size=256, device='cuda:0'):
+                 focal=1315, img_size=256, use_simplification=False, device='cuda:0'):
         super(FaceVerseModel, self).__init__()
 
         self.focal = focal
@@ -23,24 +23,42 @@ def __init__(self, model_dict, batch_size=1,
 
         self.renderer = ModelRenderer(self.focal, self.img_size, self.device)
 
-        self.skinmask = torch.tensor(model_dict['skinmask'], requires_grad=False, device=self.device)
-
-        self.kp_inds = torch.tensor(model_dict['keypoints'].reshape(-1, 1), requires_grad=False).squeeze().long().to(self.device)
-
-        self.meanshape = torch.tensor(model_dict['meanshape'].reshape(1, -1), dtype=torch.float32, requires_grad=False, device=self.device)
-        self.meantex = torch.tensor(model_dict['meantex'].reshape(1, -1), dtype=torch.float32, requires_grad=False, device=self.device)
-
-        self.idBase = torch.tensor(model_dict['idBase'], dtype=torch.float32, requires_grad=False, device=self.device)
-        self.expBase = torch.tensor(model_dict['exBase'], dtype=torch.float32, requires_grad=False, device=self.device)
-        self.texBase = torch.tensor(model_dict['texBase'], dtype=torch.float32, requires_grad=False, device=self.device)
-
-        self.tri = torch.tensor(model_dict['tri'], dtype=torch.int64, requires_grad=False, device=self.device)
-        self.point_buf = torch.tensor(model_dict['point_buf'], dtype=torch.int64, requires_grad=False, device=self.device)
-
-        self.num_vertex = model_dict['meanshape'].reshape(-1, 3).shape[0]
-        self.id_dims = model_dict['idBase'].shape[1]
-        self.tex_dims = model_dict['texBase'].shape[1]
-        self.exp_dims = model_dict['exBase'].shape[1]
+        if use_simplification:
+            self.select_id = model_dict['select_id']
+            self.select_id_tris = np.vstack((self.select_id * 3, self.select_id * 3 + 1, self.select_id * 3 + 2)).transpose().flatten()
+            self.skinmask = torch.tensor(model_dict['skinmask_select'], requires_grad=False, device=self.device)
+
+            self.kp_inds = torch.tensor(model_dict['keypoints_select'].reshape(-1, 1), requires_grad=False).squeeze().long().to(self.device)
+
+            self.meanshape = torch.tensor(model_dict['meanshape'].reshape(1, -1)[:, self.select_id_tris], dtype=torch.float32, requires_grad=False, device=self.device)
+            self.meantex = torch.tensor(model_dict['meantex'].reshape(1, -1)[:, self.select_id_tris], dtype=torch.float32, requires_grad=False, device=self.device)
+
+            self.idBase = torch.tensor(model_dict['idBase'][self.select_id_tris], dtype=torch.float32, requires_grad=False, device=self.device)
+            self.expBase = torch.tensor(model_dict['exBase'][self.select_id_tris], dtype=torch.float32, requires_grad=False, device=self.device)
+            self.texBase = torch.tensor(model_dict['texBase'][self.select_id_tris], dtype=torch.float32, requires_grad=False, device=self.device)
+
+            self.tri = torch.tensor(model_dict['tri_select'], dtype=torch.int64, requires_grad=False, device=self.device)
+            self.point_buf = torch.tensor(model_dict['point_buf_select'], dtype=torch.int64, requires_grad=False, device=self.device)
+        else:
+            self.skinmask = torch.tensor(model_dict['skinmask'], requires_grad=False, device=self.device)
+
+            self.kp_inds = torch.tensor(model_dict['keypoints'].reshape(-1, 1), requires_grad=False).squeeze().long().to(self.device)
+
+            self.meanshape = torch.tensor(model_dict['meanshape'].reshape(1, -1), dtype=torch.float32, requires_grad=False, device=self.device)
+            self.meantex = torch.tensor(model_dict['meantex'].reshape(1, -1), dtype=torch.float32, requires_grad=False, device=self.device)
+
+            self.idBase = torch.tensor(model_dict['idBase'], dtype=torch.float32, requires_grad=False, device=self.device)
+            self.expBase = torch.tensor(model_dict['exBase'], dtype=torch.float32, requires_grad=False, device=self.device)
+            self.texBase = torch.tensor(model_dict['texBase'], dtype=torch.float32, requires_grad=False, device=self.device)
+
+            self.tri = torch.tensor(model_dict['tri'], dtype=torch.int64, requires_grad=False, device=self.device)
+            self.point_buf = torch.tensor(model_dict['point_buf'], dtype=torch.int64, requires_grad=False, device=self.device)
+
+        self.num_vertex = self.meanshape.shape[1] // 3
+        self.id_dims = self.idBase.shape[1]
+        self.tex_dims = self.texBase.shape[1]
+        self.exp_dims = self.expBase.shape[1]
 
         self.all_dims = self.id_dims + self.tex_dims + self.exp_dims
 
         self.init_coeff_tensors()
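The `select_id_tris` line above converts vertex indices into indices into the flattened (x, y, z) coordinate arrays, which is why the simplified bases can be built by plain indexing. A standalone sketch of the same idiom, with hypothetical indices:

```
import numpy as np

select_id = np.array([0, 5, 9])  # hypothetical selected vertices

# vertex v occupies positions 3v, 3v+1, 3v+2 of a flattened coordinate array
select_id_tris = np.vstack((select_id * 3, select_id * 3 + 1, select_id * 3 + 2)).transpose().flatten()
print(select_id_tris)  # [ 0  1  2 15 16 17 27 28 29]
```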
2 changes: 2 additions & 0 deletions model/__init__.py
@@ -6,6 +6,8 @@ def get_faceverse(version, **kargs):
         model_path = 'data/faceverse_base_v0.npy'
     elif version == 1:
         model_path = 'data/faceverse_base_v1.npy'
+    elif version == 2:
+        model_path = 'data/faceverse_simple_v2.npy'
     faceverse_dict = np.load(model_path, allow_pickle=True).item()
     faceverse_model = FaceVerseModel(faceverse_dict, **kargs)
     return faceverse_model, faceverse_dict
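For reference, the tracking scripts call this factory roughly as follows. A sketch mirroring the call in tracking_offline.py, with the script's default values filled in:

```
from model import get_faceverse

# version=2 with use_simplification=True loads data/faceverse_simple_v2.npy
# and keeps only the selected 6335 vertices inside FaceVerseModel
faceverse_model, faceverse_dict = get_faceverse(
    version=2, batch_size=1, focal=1315, img_size=512,
    use_simplification=True, device='cuda:0')
```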
14 changes: 12 additions & 2 deletions tracking_offline.py
@@ -26,7 +26,7 @@ def init_optim_with_id(args, faceverse_model):


 def tracking(args, device):
-    faceverse_model, faceverse_dict = get_faceverse(version=args.version, batch_size=1, focal=1315, img_size=args.tar_size, device=device)
+    faceverse_model, faceverse_dict = get_faceverse(version=args.version, batch_size=1, focal=1315, img_size=args.tar_size, use_simplification=args.use_simplification, device=device)
     lm_weights = losses.get_lm_weights(device)
     offreader = OfflineReader(args.input)
     print(args.input, 'FPS:', offreader.fps)
@@ -81,6 +81,10 @@ def tracking(args, device):

             total_loss.backward()
             rigid_optimizer.step()
+
+            if args.version == 2:
+                with torch.no_grad():
+                    faceverse_model.exp_tensor[faceverse_model.exp_tensor < 0] *= 0
 
         # fitting with differentiable rendering
         for i in range(num_iters_nrf):
@@ -105,6 +109,10 @@ def tracking(args, device):
             loss.backward()
             nonrigid_optimizer.step()
+
+            if args.version == 2:
+                with torch.no_grad():
+                    faceverse_model.exp_tensor[faceverse_model.exp_tensor < 0] *= 0
 
         # save data
         with torch.no_grad():
             pred_dict = faceverse_model(faceverse_model.get_packed_tensors(), render=True, texture=True)
@@ -143,13 +151,15 @@

     parser.add_argument('--input', type=str, required=True,
                         help='input video path')
+    parser.add_argument('--use_simplification', action='store_true',
+                        help='use the simplified FaceVerse model.')
     parser.add_argument('--res_folder', type=str, required=True,
                         help='output directory')
     parser.add_argument('--save_ply', action="store_true",
                         help='save the output ply or not')
     parser.add_argument('--save_coeff', action="store_true",
                         help='save the output coeff or not')
-    parser.add_argument('--version', type=int, default=1,
+    parser.add_argument('--version', type=int, default=2,
                         help='FaceVerse model version.')
     parser.add_argument('--tar_size', type=int, default=512,
                         help='size for rendering window. We use a square window.')
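The `exp_tensor` clamp added in both tracking scripts is a projection step: after each optimizer update, negative expression coefficients are zeroed so the 52 blendshape weights stay non-negative. A minimal standalone sketch of the same pattern, with a hypothetical tensor and loss (not the repository's classes):

```
import torch

exp_tensor = torch.zeros(1, 52, requires_grad=True)  # blendshape coefficients
optimizer = torch.optim.Adam([exp_tensor], lr=1e-2)

loss = ((exp_tensor + 0.5) ** 2).sum()  # stand-in loss that drives coefficients negative
loss.backward()
optimizer.step()

# project back onto the non-negative orthant, as the version-2 code does
with torch.no_grad():
    exp_tensor[exp_tensor < 0] *= 0
```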
14 changes: 12 additions & 2 deletions tracking_online.py
@@ -26,7 +26,7 @@ def init_optim_with_id(args, faceverse_model):


 def tracking(args, device):
-    faceverse_model, faceverse_dict = get_faceverse(version=args.version, batch_size=1, focal=1315, img_size=args.tar_size, device=device)
+    faceverse_model, faceverse_dict = get_faceverse(version=args.version, batch_size=1, focal=1315, img_size=args.tar_size, use_simplification=args.use_simplification, device=device)
     lm_weights = losses.get_lm_weights(device)
     onreader = OnlineReader(camera_id=0, width=1920, height=1080)
     onreader.start()
@@ -75,6 +75,10 @@ def tracking(args, device):

             total_loss.backward()
             rigid_optimizer.step()
+
+            if args.version == 2:
+                with torch.no_grad():
+                    faceverse_model.exp_tensor[faceverse_model.exp_tensor < 0] *= 0
 
         # fitting with differentiable rendering
         for i in range(num_iters_nrf):
@@ -98,6 +102,10 @@

             loss.backward()
             nonrigid_optimizer.step()
+
+            if args.version == 2:
+                with torch.no_grad():
+                    faceverse_model.exp_tensor[faceverse_model.exp_tensor < 0] *= 0
 
         # show data
         with torch.no_grad():
@@ -146,8 +154,10 @@ def tracking(args, device):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="FaceVerse online tracker")
 
-    parser.add_argument('--version', type=int, default=1,
+    parser.add_argument('--version', type=int, default=2,
                         help='FaceVerse model version.')
+    parser.add_argument('--use_simplification', action='store_true',
+                        help='use the simplified FaceVerse model.')
     parser.add_argument('--tar_size', type=int, default=512,
                         help='size for rendering window. We use a square window.')
     parser.add_argument('--padding_ratio', type=float, default=1.0,
