Skip to content

Commit

Permalink
Ref NeRF good to go
Browse files Browse the repository at this point in the history
  • Loading branch information
Enigmatisms committed Aug 24, 2022
1 parent 0dd5741 commit 0c482d0
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 31 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ If you are interested in the implementation of NeRFs and you don't want to read
- If you are interested in CUDA implementations (there were, once), please refer to: [issue#4](https://github.com/Enigmatisms/NeRF/issues/4) and [issue#6](https://github.com/Enigmatisms/NeRF/issues/6)

Some Ref NeRF results:
(Latest-commit `e4907564`) Shinny blender "helmet" dataset trained for 3 hours (not completed, PSNR around 18). Oops, gif file to big to upload. Fine, just imagine the output, its better than the older commit (in terms of normal prediction)
(Latest-commit `e4907564`) Shinny blender "helmet" dataset trained for 3 hours (not completed, PSNR around 27). Oops, gif file to big to upload. Fine, just imagine the output, its better than the older commit (in terms of normal prediction)

(Older-commit `847fdb9d`) Shinny blender "helmet" dataset trained for 6-7 hours (not completed, PSNR around 19.5)
(Older-commit `847fdb9d`) Shinny blender "helmet" dataset trained for 6-7 hours (not completed, PSNR around 28.5.)
![ezgif-1-8207b1faa2](https://user-images.githubusercontent.com/46109954/185753069-d5cbd05e-1f66-4423-9503-1a5cd126ed89.gif)


Expand Down
4 changes: 4 additions & 0 deletions Updates.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

---

### 8.24 Update

It turns out that PSNR calculation is to blame. My loss function is not MSE (PSNR is calculated using MSE), it is SoftL1 (`sqrt(e^2 + ε)`), which is bigger than expected. Therefore "PSNR" is low (around 19.). Currently, PSNR of the model's (trained for only 7 hours) is around 28.5.

### 8.19 Update

CVPR 2022 best student honorable mention: [Ref-NeRF: Structured View-Dependent Appearance for Neural Radiance Fields](https://arxiv.org/abs/2112.03907) is implemented in this repo. This repo can turn Ref NeRF part on/off with one flag: `-t`. Ref NeRF is implemented upon (proposal network + NeRF) framework. Currently, the result is not so satisfying as I expected. This may be caused by insufficient time for training (limited training device, 6GB global memory, can only use up to batch size 2^9 (rays), while the paper uses 2^14).
Expand Down
14 changes: 14 additions & 0 deletions mkdir.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
if [ ! -d output ]; then
echo "Creating folder \"output\" and \"output/sphere\"";
mkdir -p output/sphere
fi

if [ ! -d check_points ]; then
echo "Creating folder \"check_points\"";
mkdir check_points
fi

if [ ! -d model ]; then
echo "Creating folder \"model\"";
mkdir model
fi
2 changes: 1 addition & 1 deletion py/addtional.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, epsilon = 0.001) -> None:
super().__init__()
self.eps = epsilon
def forward(self, pred:torch.Tensor, target:torch.Tensor):
return torch.mean(torch.sqrt(self.eps ** 2 + (pred - target) ** 2))
return torch.mean((pred - target) ** 2)

class LossPSNR(nn.Module):
__LOG_10__ = 2.3025851249694824
Expand Down
14 changes: 11 additions & 3 deletions py/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def forward(self, input:Image.Image):
return F.resize(input, size, self.interpolation, self.max_size, self.antialias)

class CustomDataSet(data.Dataset):
def __init__(self, root_dir, transform, scene_scale = 1.0, is_train = True, use_alpha = False, white_bkg = False):
def __init__(self, root_dir, transform, scene_scale = 1.0, is_train = True, use_alpha = False, white_bkg = False, is_cuda = False):
self.is_train = is_train
self.root_dir = root_dir
self.main_dir = root_dir + ("train/" if is_train else "test/")
Expand All @@ -43,6 +43,7 @@ def __init__(self, root_dir, transform, scene_scale = 1.0, is_train = True, use_
self.use_alpha = use_alpha
self.scene_scale = scene_scale
self.white_bkg = white_bkg
self.is_cuda = is_cuda
self.cam_fov, self.tfs = self.__get_camera_param()

def __len__(self):
Expand All @@ -52,16 +53,23 @@ def __getitem__(self, idx):
img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
image = Image.open(img_loc, mode = 'r').convert("RGBA" if self.use_alpha or self.white_bkg else "RGB")
tensor_image = self.transform(image)
if self.is_cuda:
tensor_image = tensor_image.cuda()
tf = self.tfs[idx].cuda()
else:
tf = self.tfs[idx].clone()
if self.white_bkg:
tensor_image = tensor_image[:3, ...]*tensor_image[-1:, ...] + (1.-tensor_image[-1:, ...])
tf = self.tfs[idx].clone()
tf[:3, -1] *= self.scene_scale
return tensor_image, tf
return tensor_image.squeeze(0), tf

def r_c(self):
image, _ = self.__getitem__(0)
return image.shape[1], image.shape[2]

def cuda(self, flag = True):
self.is_cuda = flag

@staticmethod
def readFromJson(path:str):
with open(path, "r") as file:
Expand Down
43 changes: 30 additions & 13 deletions py/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from py.ref_model import RefNeRF
from torchvision import transforms
from py.dataset import CustomDataSet, AdaptiveResize
from py.addtional import ProposalNetwork
from py.addtional import ProposalNetwork, SoftL1Loss, LossPSNR
from torch.nn.functional import softplus
from torchvision.utils import save_image
from py.mip_methods import maxBlurFilter
Expand Down Expand Up @@ -107,8 +107,9 @@ def render_only(args, model_path: str, opt_level: str):
use_white_bkg = args.white_bkg
opt_mode = args.opt_mode
use_ref_nerf = args.ref_nerf
render_normal = args.render_normal
render_depth = args.render_depth
eval_poses = args.eval_poses
render_normal = args.render_normal & (not eval_poses)
render_depth = args.render_depth & (not eval_poses)
transform_funcs = transforms.Compose([
AdaptiveResize(img_scale),
transforms.ToTensor(),
Expand All @@ -117,46 +118,61 @@ def render_only(args, model_path: str, opt_level: str):

cam_fov_test, _ = testset.getCameraParam()
r_c = testset.r_c()
del testset
if eval_poses:
all_poses = testset.tfs.cuda()
loss_func = SoftL1Loss()
psnr_func = LossPSNR()
else:
all_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in torch.linspace(-180,180,120 + 1)[:-1]], 0).cuda()
del testset
test_focal = fov2Focal(cam_fov_test, r_c)

if use_ref_nerf:
from py.ref_model import RefNeRF
mip_net = RefNeRF(10, args.ide_level, hidden_unit = 256, perturb_bottle_neck_w = args.bottle_neck_noise, use_srgb = args.use_srgb).cuda()
mip_net = RefNeRF(10, args.ide_level, hidden_unit = args.nerf_net_width, perturb_bottle_neck_w = args.bottle_neck_noise, use_srgb = args.use_srgb).cuda()
else:
from py.mip_model import MipNeRF
mip_net = MipNeRF(10, 4, hidden_unit = 256)
prop_net = ProposalNetwork(10, hidden_unit = 256).cuda()
mip_net = MipNeRF(10, 4, hidden_unit = args.nerf_net_width)
prop_net = ProposalNetwork(10, hidden_unit = args.prop_net_width).cuda()
if use_amp and opt_mode != "native":
from apex import amp
[mip_net, prop_net] = amp.initialize([mip_net, prop_net], None, opt_level = opt_level)
mip_net.loadFromFile(load_path_mip, use_amp and opt_mode != "native")
prop_net.loadFromFile(load_path_prop, use_amp and opt_mode != "native")

all_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in torch.linspace(-180,180,120 + 1)[:-1]], 0).cuda()
mip_net.eval()
prop_net.eval()
with torch.no_grad():
for i, pose in tqdm(list(enumerate(all_poses))):
pose[:3, -1] *= scene_scale
if opt_mode == "native":
with autocast():
result = render_image(mip_net, prop_net, pose[:-1, :], r_c, test_focal, near_t, far_t, 128,
result = render_image(mip_net, prop_net, pose[:3, :], r_c, test_focal, near_t, far_t, 128,
white_bkg = use_white_bkg, render_normal = render_normal, render_depth = render_depth)
else:
result = render_image(mip_net, prop_net, pose[:-1, :], r_c, test_focal, near_t, far_t, 128,
result = render_image(mip_net, prop_net, pose[:3, :], r_c, test_focal, near_t, far_t, 128,
white_bkg = use_white_bkg, render_normal = render_normal, render_depth = render_depth)
save_image(list(result.values()), "./output/sphere/result_%03d.png"%(i), nrow = 1 + render_depth + render_depth)
if eval_poses == True:
gt_img, _ = testset[i]
gt_img = gt_img.cuda()
loss = loss_func(result['rgb'], gt_img)
psnr = psnr_func(loss)
print("Image loss:%.6f\tPSNR:%.4f"%(loss.item(), psnr.item()))
result['gt_img'] = gt_img
output_dir = "given" if eval_poses else "sphere"
save_image(list(result.values()), "./output/%s/result_%03d.png"%(output_dir, i), nrow = 1 + render_depth + render_depth + eval_poses)

def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type = int, default = 2000, help = "Training lasts for . epochs")
parser.add_argument("--epochs", type = int, default = 2400, help = "Training lasts for . epochs")
parser.add_argument("--sample_ray_num", type = int, default = 1024, help = "<x> rays to sample per training time")
parser.add_argument("--coarse_sample_pnum", type = int, default = 64, help = "Points to sample in coarse net")
parser.add_argument("--fine_sample_pnum", type = int, default = 128, help = "Points to sample in fine net")
parser.add_argument("--eval_time", type = int, default = 5, help = "Tensorboard output interval (train time)")
parser.add_argument("--output_time", type = int, default = 20, help = "Image output interval (train time)")
parser.add_argument("--center_crop_iter", type = int, default = 0, help = "Produce center")
parser.add_argument("--prop_net_width", type = int, default = 256, help = "Width of proposal network")
parser.add_argument("--nerf_net_width", type = int, default = 256, help = "Width of nerf network")
parser.add_argument("--near", type = float, default = 2., help = "Nearest sample depth")
parser.add_argument("--far", type = float, default = 6., help = "Farthest sample depth")
parser.add_argument("--center_crop_x", type = float, default = 0.5, help = "Center crop x axis ratio")
Expand All @@ -165,7 +181,7 @@ def get_parser():
parser.add_argument("--dataset_name", type = str, default = "lego", help = "Input dataset name in nerf synthetic dataset")
parser.add_argument("--img_scale", type = float, default = 0.5, help = "Scale of the image")
parser.add_argument("--scene_scale", type = float, default = 1.0, help = "Scale of the scene")
parser.add_argument("--grad_clip", type = float, default = 1e-3, help = "Gradient clipping parameter")
parser.add_argument("--grad_clip", type = float, default = -0.01, help = "Gradient clipping parameter (Negative number means no clipping)")
parser.add_argument("--pe_period_scale", type = float, default = 0.5, help = "Scale of positional encoding")
# opt related
parser.add_argument("--opt_mode", type = str, default = "O1", help = "Optimization mode: none, native (torch amp), O1, O2 (apex amp)")
Expand All @@ -184,6 +200,7 @@ def get_parser():
parser.add_argument("-w", "--white_bkg", default = False, action = "store_true", help = "Output white background")
parser.add_argument("-t", "--ref_nerf", default = False, action = "store_true", help = "Test Ref NeRF")
parser.add_argument("-u", "--use_srgb", default = False, action = "store_true", help = "Whether to use srgb in the output or not")
parser.add_argument("-e", "--eval_poses", default = False, action = "store_true", help = "Whether to use test set poses to render image")
# long bool options
parser.add_argument("--render_depth", default = False, action = "store_true", help = "Render depth image")
parser.add_argument("--render_normal", default = False, action = "store_true", help = "Render normal image")
Expand Down
21 changes: 10 additions & 11 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ def main(args):
from py.ref_model import RefNeRF, WeightedNormalLoss, BackFaceLoss
normal_loss_func = WeightedNormalLoss(True)
bf_loss_func = BackFaceLoss()
mip_net = RefNeRF(10, args.ide_level, hidden_unit = 256, perturb_bottle_neck_w = args.bottle_neck_noise, use_srgb = args.use_srgb).cuda()
mip_net = RefNeRF(10, args.ide_level, hidden_unit = args.nerf_net_width, perturb_bottle_neck_w = args.bottle_neck_noise, use_srgb = args.use_srgb).cuda()
else:
from py.mip_model import MipNeRF
mip_net = MipNeRF(10, 4, hidden_unit = 256).cuda()
prop_net = ProposalNetwork(10, hidden_unit = 256).cuda()
mip_net = MipNeRF(10, 4, hidden_unit = args.nerf_net_width).cuda()
prop_net = ProposalNetwork(10, hidden_unit = args.prop_net_width).cuda()

if debugging:
for submodule in mip_net.modules():
Expand All @@ -92,8 +92,10 @@ def main(args):
])

# 数据集加载
trainset = CustomDataSet("../dataset/refnerf/%s/"%(dataset_name), transform_funcs, scene_scale, True, use_alpha = False, white_bkg = use_white_bkg)
testset = CustomDataSet("../dataset/refnerf/%s/"%(dataset_name), transform_funcs, scene_scale, False, use_alpha = False, white_bkg = use_white_bkg)
trainset = CustomDataSet("../dataset/refnerf/%s/"%(dataset_name), transform_funcs,
scene_scale, True, use_alpha = False, white_bkg = use_white_bkg, is_cuda = True)
testset = CustomDataSet("../dataset/refnerf/%s/"%(dataset_name), transform_funcs,
scene_scale, False, use_alpha = False, white_bkg = use_white_bkg)
cam_fov_train, train_cam_tf = trainset.getCameraParam()
r_c = trainset.r_c()
train_cam_tf = train_cam_tf.cuda()
Expand Down Expand Up @@ -126,6 +128,7 @@ def grad_clip_func(parameters, grad_clip):
test_views = []
for i in (1, 4):
test_views.append(testset[i])
del testset
torch.cuda.empty_cache()

# ====== tensorboard summary writer ======
Expand All @@ -141,8 +144,6 @@ def grad_clip_func(parameters, grad_clip):
epoch_timer.tic()
for i, (train_img, train_tf) in enumerate(train_loader):
train_timer.tic()
train_img = train_img.cuda().squeeze(0)
train_tf = train_tf.cuda().squeeze(0)
now_crop = (center_crop if train_cnt < center_crop_iter else (1., 1.))
valid_pixels, valid_coords = randomFromOneImage(train_img, now_crop)

Expand Down Expand Up @@ -228,7 +229,7 @@ def run(is_ref_model = False):
with torch.no_grad():
eval_timer.tic()
test_results = []
test_loss = torch.zeros(1).cuda()
test_loss = 0.
for test_img, test_tf in test_views:
test_result = render_image(
mip_net, prop_net, test_tf.cuda(), r_c, test_focal, near_t, far_t, fine_sample_pnum,
Expand All @@ -242,9 +243,7 @@ def run(is_ref_model = False):
print("Evaluation in epoch: %4d / %4d\t, test counter: %d test loss: %.4f\taverage time: %.4lf\tremaining eval time:%s"%(
ep, epochs, test_cnt, test_loss.item() / 2, eval_timer.get_mean_time(), eval_timer.remaining_time(epochs - ep - 1)
))
images_to_save = []
images_to_save.extend(test_results)
save_image(images_to_save, "./output/result_%03d.png"%(test_cnt), nrow = 1 + render_normal + render_depth)
save_image(test_results, "./output/result_%03d.png"%(test_cnt), nrow = 1 + render_normal + render_depth)
# ======== Saving checkpoints ========
saveModel(mip_net, "%schkpt_%d_mip.pt"%(default_chkpt_path, train_cnt), {"train_cnt": train_cnt, "epoch": ep}, opt = opt, amp = (amp) if use_amp and opt_mode != "native" else None)
saveModel(prop_net, "%schkpt_%d_prop.pt"%(default_chkpt_path, train_cnt), opt = None, amp = (amp) if use_amp and opt_mode != "native" else None)
Expand Down
13 changes: 13 additions & 0 deletions train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
if [ "$1" = "" -o "$2" = "" ]; then
echo "Usage: ./train.sh <ray number (batch size)> <dataset name>"
echo "For example: ./train.sh 1024 car"
if [ "a"$1 = "a" ]; then
echo "Please specify batch size (ray number)"
fi
if [ "a"$2 = "a" ]; then
echo "Please specify dataset name (object name, for example: lego)"
fi
exit
fi

python3 ./train.py -s -t -u --sample_ray_num $1 --dataset_name $2 --render_depth --render_normal
1 change: 0 additions & 1 deletion train_helmet_v1.sh

This file was deleted.

0 comments on commit 0c482d0

Please sign in to comment.