Commit

update tracking
LizhenWangT committed Mar 30, 2022
1 parent 42d8333 commit 38e6ac6
Showing 36 changed files with 5,419 additions and 0 deletions.
292 changes: 292 additions & 0 deletions model/FaceVerseModel.py
@@ -0,0 +1,292 @@
import torch
from torch import nn
import numpy as np

from model.ModelRenderer import ModelRenderer
from pytorch3d.structures import Meshes
from pytorch3d.renderer import TexturesVertex

class FaceVerseModel(nn.Module):
def __init__(self, model_dict, batch_size=1,
focal=1315, img_size=256, device='cuda:0'):
super(FaceVerseModel, self).__init__()

self.focal = focal
self.batch_size = batch_size
self.img_size = img_size
self.device = torch.device(device)

self.p_mat = self._get_p_mat(device)
self.reverse_z = self._get_reverse_z(device)
self.camera_pos = self._get_camera_pose(device)
self.rotXYZ = torch.eye(3).view(1, 3, 3).repeat(3, 1, 1).view(3, 1, 3, 3).to(self.device)

self.renderer = ModelRenderer(self.focal, self.img_size, self.device)

self.skinmask = torch.tensor(model_dict['skinmask'], requires_grad=False, device=self.device)

self.kp_inds = torch.tensor(model_dict['keypoints'].reshape(-1, 1), requires_grad=False).squeeze().long().to(self.device)

self.meanshape = torch.tensor(model_dict['meanshape'].reshape(1, -1), dtype=torch.float32, requires_grad=False, device=self.device)
self.meantex = torch.tensor(model_dict['meantex'].reshape(1, -1), dtype=torch.float32, requires_grad=False, device=self.device)

self.idBase = torch.tensor(model_dict['idBase'], dtype=torch.float32, requires_grad=False, device=self.device)
self.expBase = torch.tensor(model_dict['exBase'], dtype=torch.float32, requires_grad=False, device=self.device)
self.texBase = torch.tensor(model_dict['texBase'], dtype=torch.float32, requires_grad=False, device=self.device)

self.tri = torch.tensor(model_dict['tri'], dtype=torch.int64, requires_grad=False, device=self.device)
self.point_buf = torch.tensor(model_dict['point_buf'], dtype=torch.int64, requires_grad=False, device=self.device)

self.num_vertex = model_dict['meanshape'].reshape(-1, 3).shape[0]
self.id_dims = model_dict['idBase'].shape[1]
self.tex_dims = model_dict['texBase'].shape[1]
self.exp_dims = model_dict['exBase'].shape[1]
self.all_dims = self.id_dims + self.tex_dims + self.exp_dims

self.init_coeff_tensors()

        # for tracking by landmarks: pre-select the basis rows that belong to the
        # keypoint vertices (three rows per keypoint, one each for x, y, z)
self.kp_inds_view = torch.cat([self.kp_inds[:, None] * 3, self.kp_inds[:, None] * 3 + 1, self.kp_inds[:, None] * 3 + 2], dim=1).flatten()
self.idBase_view = self.idBase[self.kp_inds_view, :].detach().clone()
self.expBase_view = self.expBase[self.kp_inds_view, :].detach().clone()
self.meanshape_view = self.meanshape[:, self.kp_inds_view].detach().clone()

    def init_coeff_tensors(self):
        # Identity and texture are shared across the batch (a single row each,
        # repeated in get_packed_tensors); expression, pose and lighting are
        # kept per frame.
        self.id_tensor = torch.zeros(
            (1, self.id_dims), dtype=torch.float32,
            requires_grad=True, device=self.device)

        self.tex_tensor = torch.zeros(
            (1, self.tex_dims), dtype=torch.float32,
            requires_grad=True, device=self.device)

self.exp_tensor = torch.zeros(
(self.batch_size, self.exp_dims), dtype=torch.float32,
requires_grad=True, device=self.device)

self.gamma_tensor = torch.zeros(
(self.batch_size, 27), dtype=torch.float32,
requires_grad=True, device=self.device)

self.trans_tensor = torch.zeros(
(self.batch_size, 3), dtype=torch.float32,
requires_grad=False, device=self.device)
        self.trans_tensor[:, 2] += 6  # start 6 units in front of the origin (the camera sits at z = 10)
self.trans_tensor.requires_grad = True

self.rot_tensor = torch.zeros(
(self.batch_size, 3), dtype=torch.float32,
requires_grad=False, device=self.device)
        self.rot_tensor[:, 0] += np.pi  # 180-degree flip about x so the model faces the camera
self.rot_tensor.requires_grad = True

def get_lms(self, vs):
lms = vs[:, self.kp_inds, :]
return lms

def split_coeffs(self, coeffs):
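        # Packed layout: [id | exp | tex | angles(3) | gamma(27) | translation(3)],
        # i.e. all_dims + 33 columns in total.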
id_coeff = coeffs[:, :self.id_dims] # identity(shape) coeff
exp_coeff = coeffs[:, self.id_dims:self.id_dims+self.exp_dims] # expression coeff
tex_coeff = coeffs[:, self.id_dims+self.exp_dims:self.all_dims] # texture(albedo) coeff
        angles = coeffs[:, self.all_dims:self.all_dims+3]  # Euler angles (x, y, z) for rotation, dim 3
        gamma = coeffs[:, self.all_dims+3:self.all_dims+30]  # lighting coeffs for the 3-channel SH model, dim 27
        translation = coeffs[:, self.all_dims+30:]  # translation, dim 3

return id_coeff, exp_coeff, tex_coeff, angles, gamma, translation

def merge_coeffs(self, id_coeff, exp_coeff, tex_coeff, angles, gamma, translation):
coeffs = torch.cat([id_coeff, exp_coeff, tex_coeff, angles, gamma, translation], dim=1)
return coeffs

def get_packed_tensors(self):
return self.merge_coeffs(self.id_tensor.repeat(self.batch_size, 1),
self.exp_tensor,
self.tex_tensor.repeat(self.batch_size, 1),
self.rot_tensor, self.gamma_tensor,
self.trans_tensor)

def forward(self, coeffs, render=True, texture=True):
id_coeff, exp_coeff, tex_coeff, angles, gamma, translation = self.split_coeffs(coeffs)
rotation = self.compute_rotation_matrix(angles)

if render:
vs = self.get_vs(id_coeff, exp_coeff)
vs_t = self.rigid_transform(vs, rotation, translation)

lms_t = self.get_lms(vs_t)
lms_proj = self.project_vs(lms_t)
            lms_proj = torch.stack(
                [lms_proj[:, :, 0], self.img_size - lms_proj[:, :, 1]], dim=2)  # flip y into image coordinates
face_texture = self.get_color(tex_coeff)
face_norm = self.compute_norm(vs, self.tri, self.point_buf)
face_norm_r = face_norm.bmm(rotation)
face_color = self.add_illumination(face_texture, face_norm_r, gamma)

if texture:
face_color_tv = TexturesVertex(face_color)
mesh = Meshes(vs_t, self.tri.repeat(self.batch_size, 1, 1), face_color_tv)
rendered_img = self.renderer.alb_renderer(mesh)
else:
                face_color_tv = TexturesVertex(face_color * 0 + 200)  # uniform gray: geometry-only shading
mesh = Meshes(vs_t, self.tri.repeat(self.batch_size, 1, 1), face_color_tv)
rendered_img = self.renderer.sha_renderer(mesh)

return {'rendered_img': rendered_img,
'lms_proj': lms_proj,
'face_texture': face_texture,
'vs': vs_t,
'tri': self.tri,
'color': face_color}
else:
lms = self.get_vs_lms(id_coeff, exp_coeff)
lms_t = self.rigid_transform(
lms, rotation, translation)

lms_proj = self.project_vs(lms_t)
            lms_proj = torch.stack(
                [lms_proj[:, :, 0], self.img_size - lms_proj[:, :, 1]], dim=2)  # flip y into image coordinates
return {'lms_proj': lms_proj}

def get_vs(self, id_coeff, exp_coeff):
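        # Linear 3DMM shape: meanshape + idBase @ id_coeff + expBase @ exp_coeff,
        # reshaped to (batch_size, num_vertex, 3).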
face_shape = torch.einsum('ij,aj->ai', self.idBase, id_coeff) + \
torch.einsum('ij,aj->ai', self.expBase, exp_coeff) + self.meanshape
face_shape = face_shape.view(self.batch_size, -1, 3)
return face_shape

def get_vs_lms(self, id_coeff, exp_coeff):
face_shape = torch.einsum('ij,aj->ai', self.idBase_view, id_coeff) + \
torch.einsum('ij,aj->ai', self.expBase_view, torch.abs(exp_coeff)) + self.meanshape_view
face_shape = face_shape.view(self.batch_size, -1, 3)
return face_shape

def get_color(self, tex_coeff):
face_texture = torch.einsum('ij,aj->ai', self.texBase, tex_coeff) + self.meantex
face_texture = face_texture.view(self.batch_size, -1, 3)
return face_texture

def get_skinmask(self):
return self.skinmask

def _get_camera_pose(self, device):
camera_pos = torch.tensor([0.0, 0.0, 10.0], device=device).reshape(1, 1, 3)
return camera_pos

def _get_p_mat(self, device):
half_image_width = self.img_size // 2
p_matrix = np.array([self.focal, 0.0, half_image_width,
0.0, self.focal, half_image_width,
0.0, 0.0, 1.0], dtype=np.float32).reshape(1, 3, 3)
return torch.tensor(p_matrix, device=device)

def _get_reverse_z(self, device):
reverse_z = np.reshape(np.array([1.0, 0, 0, 0, 1, 0, 0, 0, -1.0], dtype=np.float32), [1, 3, 3])
return torch.tensor(reverse_z, device=device)

def compute_norm(self, vs, tri, point_buf):
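        # tri holds the three vertex indices of each triangle; point_buf lists,
        # for every vertex, the indices of its adjacent triangles, so per-vertex
        # normals are the (normalized) sum of the adjacent face normals.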
face_id = tri
point_id = point_buf
v1 = vs[:, face_id[:, 0], :]
v2 = vs[:, face_id[:, 1], :]
v3 = vs[:, face_id[:, 2], :]
e1 = v1 - v2
e2 = v2 - v3
        face_norm = e1.cross(e2, dim=-1)  # per-face normals (dim made explicit)

v_norm = face_norm[:, point_id, :].sum(2)
v_norm = v_norm / (v_norm.norm(dim=2).unsqueeze(2) + 1e-9)

return v_norm

def project_vs(self, vs):
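        # Flip z (reverse_z), shift by the camera position (0, 0, 10), then apply
        # the pinhole intrinsics p_mat and divide by depth.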
vs = torch.matmul(vs, self.reverse_z.repeat((self.batch_size, 1, 1))) + self.camera_pos
aug_projection = torch.matmul(vs, self.p_mat.repeat((self.batch_size, 1, 1)).permute((0, 2, 1)))
face_projection = aug_projection[:, :, :2] / torch.reshape(aug_projection[:, :, 2], [self.batch_size, -1, 1])
return face_projection

def compute_rotation_matrix(self, angles):
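        # Build one rotation matrix per axis from the Euler angles, compose them
        # as Rz @ Ry @ Rx, and transpose for the row-vector convention used in
        # rigid_transform (vs @ R).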
sinx = torch.sin(angles[:, 0])
siny = torch.sin(angles[:, 1])
sinz = torch.sin(angles[:, 2])
cosx = torch.cos(angles[:, 0])
cosy = torch.cos(angles[:, 1])
cosz = torch.cos(angles[:, 2])

if self.batch_size != 1:
rotXYZ = self.rotXYZ.repeat(1, self.batch_size * 3, 1, 1)
else:
rotXYZ = self.rotXYZ.detach().clone()

rotXYZ[0, :, 1, 1] = cosx
rotXYZ[0, :, 1, 2] = -sinx
rotXYZ[0, :, 2, 1] = sinx
rotXYZ[0, :, 2, 2] = cosx
rotXYZ[1, :, 0, 0] = cosy
rotXYZ[1, :, 0, 2] = siny
rotXYZ[1, :, 2, 0] = -siny
rotXYZ[1, :, 2, 2] = cosy
rotXYZ[2, :, 0, 0] = cosz
rotXYZ[2, :, 0, 1] = -sinz
rotXYZ[2, :, 1, 0] = sinz
rotXYZ[2, :, 1, 1] = cosz

rotation = rotXYZ[2].bmm(rotXYZ[1]).bmm(rotXYZ[0])

return rotation.permute(0, 2, 1)

def add_illumination(self, face_texture, norm, gamma):
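        # Lambertian shading with the first three bands (9 terms) of the real
        # spherical harmonics; gamma holds 9 coefficients per RGB channel
        # (27 in total), with a 0.8 ambient offset added to the DC term.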
gamma = gamma.view(-1, 3, 9).clone()
gamma[:, :, 0] += 0.8
gamma = gamma.permute(0, 2, 1)

a0 = np.pi
a1 = 2 * np.pi / np.sqrt(3.0)
a2 = 2 * np.pi / np.sqrt(8.0)
c0 = 1 / np.sqrt(4 * np.pi)
c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
d0 = 0.5 / np.sqrt(3.0)

norm = norm.view(-1, 3)
nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]
arrH = []

arrH.append(a0 * c0 * (nx * 0 + 1))
arrH.append(-a1 * c1 * ny)
arrH.append(a1 * c1 * nz)
arrH.append(-a1 * c1 * nx)
arrH.append(a2 * c2 * nx * ny)
arrH.append(-a2 * c2 * ny * nz)
arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
arrH.append(-a2 * c2 * nx * nz)
arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))

H = torch.stack(arrH, 1)
Y = H.view(self.batch_size, self.num_vertex, 9)
lighting = Y.bmm(gamma)

face_color = face_texture * lighting
return face_color

def rigid_transform(self, vs, rot, trans):
vs_r = torch.matmul(vs, rot)
vs_t = vs_r + trans.view(-1, 1, 3)
return vs_t

def get_rot_tensor(self):
return self.rot_tensor

def get_trans_tensor(self):
return self.trans_tensor

def get_exp_tensor(self):
return self.exp_tensor

def get_tex_tensor(self):
return self.tex_tensor

def get_id_tensor(self):
return self.id_tensor

def get_gamma_tensor(self):
return self.gamma_tensor
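
A minimal usage sketch (an assumption-laden example, not part of the commit): it presumes a CUDA device and a model dictionary with the keys read in __init__ above ('meanshape', 'meantex', 'idBase', 'exBase', 'texBase', 'tri', 'point_buf', 'keypoints', 'skinmask'); the .npy path follows model/__init__.py below.

import numpy as np
import torch
from model.FaceVerseModel import FaceVerseModel

# Load the base model dictionary and build the parametric model.
model_dict = np.load('data/faceverse_base_v1.npy', allow_pickle=True).item()
model = FaceVerseModel(model_dict, batch_size=1, focal=1315, img_size=256, device='cuda:0')

# Zero coefficients give the mean (neutral) face at the initial pose.
coeffs = model.get_packed_tensors()
with torch.no_grad():
    out = model(coeffs, render=True, texture=True)
print(out['rendered_img'].shape)  # (1, 256, 256, 4) RGBA render
print(out['lms_proj'].shape)      # (1, n_keypoints, 2) projected landmarks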

60 changes: 60 additions & 0 deletions model/ModelRenderer.py
@@ -0,0 +1,60 @@
import numpy as np
import torch

from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
look_at_view_transform,
FoVPerspectiveCameras,
PointLights,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
HardFlatShader,
TexturesVertex,
blending
)

class ModelRenderer:
def __init__(self, focal=1315, img_size=224, device='cuda:0'):
self.img_size = img_size
self.focal = focal
self.device = device

self.alb_renderer = self._get_renderer(albedo=True)
self.sha_renderer = self._get_renderer(albedo=False)

def _get_renderer(self, albedo=True):
        R, T = look_at_view_transform(10, 0, 0)  # camera 10 units from the origin, zero elevation/azimuth
        cameras = FoVPerspectiveCameras(device=self.device, R=R, T=T, znear=0.01, zfar=50,
                                        fov=2 * np.arctan((self.img_size // 2) / self.focal) * 180. / np.pi)

if albedo:
lights = PointLights(device=self.device, location=[[0.0, 0.0, 1e5]],
ambient_color=[[1, 1, 1]],
specular_color=[[0., 0., 0.]], diffuse_color=[[0., 0., 0.]])
else:
lights = PointLights(device=self.device, location=[[0.0, 0.0, 1e5]],
ambient_color=[[0.1, 0.1, 0.1]],
specular_color=[[0.0, 0.0, 0.0]], diffuse_color=[[0.95, 0.95, 0.95]])

raster_settings = RasterizationSettings(
image_size=self.img_size,
blur_radius=0.0,
faces_per_pixel=1,
)
blend_params = blending.BlendParams(background_color=[0, 0, 0])

renderer = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras,
raster_settings=raster_settings
),
shader=HardFlatShader(
device=self.device,
cameras=cameras,
lights=lights,
blend_params=blend_params
)
)
return renderer
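
Note on the camera setup (a sanity check, not part of the commit): the FoV handed to FoVPerspectiveCameras is derived from the same pinhole intrinsics as FaceVerseModel._get_p_mat, via fov = 2 * atan((img_size / 2) / focal).

import numpy as np

focal, img_size = 1315, 224  # the defaults above
fov_deg = 2 * np.arctan((img_size / 2) / focal) * 180.0 / np.pi
print(fov_deg)  # ~9.7 degrees: a narrow view, consistent with the long focal length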

12 changes: 12 additions & 0 deletions model/__init__.py
@@ -0,0 +1,12 @@
from model.FaceVerseModel import FaceVerseModel
import numpy as np

def get_faceverse(version, **kwargs):
    if version == 0:
        model_path = 'data/faceverse_base_v0.npy'
    elif version == 1:
        model_path = 'data/faceverse_base_v1.npy'
    else:
        raise ValueError('unknown FaceVerse version: {}'.format(version))
    faceverse_dict = np.load(model_path, allow_pickle=True).item()
    faceverse_model = FaceVerseModel(faceverse_dict, **kwargs)
return faceverse_model, faceverse_dict
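
A short sketch of the intended entry point (assuming the .npy model files are available under data/):

from model import get_faceverse

faceverse_model, faceverse_dict = get_faceverse(version=1, batch_size=1, focal=1315, img_size=256)
print(faceverse_model.id_dims, faceverse_model.exp_dims, faceverse_model.tex_dims)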
