Skip to content

Commit

Permalink
Bug fix for camera ready release
Browse files Browse the repository at this point in the history
  • Loading branch information
zhou13 committed Jan 10, 2020
1 parent 14545f5 commit d3ce4f0
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 17 deletions.
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ NeurVPS is an end-to-end trainable deep network with *geometry-inspired* convolu
| ------------------------------------------------------------ | ------------------------------------------------------------ | ----------------------------------- |
| ![blend](figs/su3.png) | ![tmm17](figs/tmm17.png) | ![scannet](figs/scannet.png) |

Some random sampled results can also be found in the [supplementary material](https://yichaozhou.com/publication/1905neurvps/appendix.pdf) of the paper.

### Quantitative Measures

| [SceneCity Urban 3D (SU3)](https://arxiv.org/abs/1905.07482) | [Natural Scene (TMM17)](https://faculty.ist.psu.edu/zzhou/projects/vpdetection/) | [ScanNet](http://www.scan-net.org/) |
| ------------------------------------------------------------ | ------------------------------------------------------------ | -------------------------------------- |
| ![su3_AA6](figs/su3_AA6.svg) | ![tmm17_AA20](figs/tmm17_AA20.svg) | ![scannet_AA20](figs/scannet_AA20.svg) |

Here, the x-axis represents the angular error of the detected vanishing points and the y-axis represents the percentage of the results whose error is less than that. Our conic convolutional networks outperform all the baseline methods and previous state-of-the-art vanishing point detection approaches, while naive CNN implementations might underperform those traditional methods, espeically in the high-accuracy regions.
Here, the x-axis represents the angle accuracy of the detected vanishing points and the y-axis represents the percentage of the results whose error is less than that. Our conic convolutional networks outperform all the baseline methods and previous state-of-the-art vanishing point detection approaches, while naive CNN implementations might under-perform those traditional methods, especially in the high-accuracy regions.

### Code Structure

Expand Down Expand Up @@ -108,7 +110,7 @@ The checkpoints and logs will be written to `logs/` accordingly. It has been rep
### Pre-trained Models

You can download our reference pre-trained models from [Google
Drive](https://drive.google.com/drive/folders/1srniSE2JD6ptAwc_QRnpl7uQnB5jLNIZ). Those pretrained
Drive](https://drive.google.com/drive/folders/1srniSE2JD6ptAwc_QRnpl7uQnB5jLNIZ). Those pre-trained
models should be able to reproduce the numbers in our paper.

### Evaluation
Expand All @@ -121,12 +123,15 @@ python eval.py -d 0 logs/YOUR_LOG/config.yaml logs/YOUR_LOG/checkpoint_best.pth.
### FAQ

#### What is the unit of focal length in the yaml and why do I need it?
**A:** The focal length in our implementation is in the unit of 2/w pixel (w is the image width. only a square image is supported). This follows the convention of the OpenGL projection matrix so that to make it resolution invariant. The focal length is used for uniform sampling of the position of vanishing points. If it is not known, you can set it to some common focal length for your catorgories of images, as we do in [config/tmm17.yaml](https://github.com/zhou13/neurvps/blob/master/config/tmm17.yaml).
**A:** The focal length in our implementation is in the unit of 2/w pixel (w is the image width. only a square image is supported). This follows the convention of the OpenGL projection matrix so that to make it resolution invariant. The focal length is used for uniform sampling of the position of vanishing points. If it is not known, you can set it to some common focal length for your categories of images, as we do in [config/tmm17.yaml](https://github.com/zhou13/neurvps/blob/master/config/tmm17.yaml).

You can also check the function `to_label` and `to_pixel`, which use the focal length to convert the 3D line direction from and to a 2D vanishing point.

#### I have a question. How could I get help?
**A:** You can post an issue on Github, which may help other people that have the same question. You can also send me an email if you think that is more appropriate.
**A:** You can post an issue on Github, which might help other people that have the same question. You can also send me an email if you think that is more appropriate.

### Acknowledgement
We thank Yikai Li from SJTU and Jiajun Wu from MIT for pointing out an error in the data augmentation code for the TMM17 Natural Scene dataset. This work is partially supported by the funding from Berkeley EECS Startup fund, Berkeley FHL Vive Center for Enhanced Reality, research grants from Sony Research, and Bytedance Research Lab (Silicon Valley).

### Citing NeurVPS

Expand Down
1 change: 1 addition & 0 deletions config/scannet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ model:
smp_rnd: 3
smp_multiplier: 2
multires:
# those magic numbers are calculated by misc/find-radius.py
- 0.0200483803479500
- 0.0774278195486317
- 0.2995648108645650
Expand Down
1 change: 1 addition & 0 deletions config/su3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ model:
smp_rnd: 3
smp_multiplier: 2
multires:
# those magic numbers are calculated by misc/find-radius.py
- 0.0013457768043554
- 0.0051941870036646
- 0.0200483803479500
Expand Down
11 changes: 6 additions & 5 deletions config/tmm17.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ io:
num_workers: 4
tensorboard_port: 0
validation_interval: 8000
validation_debug: 160
validation_debug: 0
focal_length: 1.0
num_vpts: 1
augmentation_level: 2
Expand All @@ -27,6 +27,7 @@ model:
smp_rnd: 3
smp_multiplier: 2
multires:
# those magic numbers are calculated by misc/find-radius.py
- 0.0051941870036646
- 0.0200483803479500
- 0.0774278195486317
Expand All @@ -38,8 +39,8 @@ model:

optim:
name: Adam
lr: 3.0e-4
lr: 1.0e-4
amsgrad: True
weight_decay: 3.0e-4
max_epoch: 50
lr_decay_epoch: 30
weight_decay: 6.0e-4
max_epoch: 100
lr_decay_epoch: 60
18 changes: 12 additions & 6 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
-h --help Show this screen
-d --devices <devices> Comma seperated GPU devices [default: 0]
-o --output <output> Path to the output AA curve [default: error.npz]
--dump <output-dir> Optionally, save the vanishing points in npz format
--dump <output-dir> Optionally, save the vanishing points to npz format.
The coordinate of VPs is in the camera space, see
`to_label` and `to_pixel` in neurvps/models/vanishing_net.py
for more details.
--noimshow Do not show result
"""

import os
Expand Down Expand Up @@ -102,7 +106,7 @@ def main():
Dataset(C.io.datadir, split="valid"),
batch_size=1,
shuffle=False,
num_workers=C.io.num_workers,
num_workers=C.io.num_workers if os.name != "nt" else 0,
pin_memory=True,
)

Expand Down Expand Up @@ -151,10 +155,12 @@ def main():
err = np.sort(np.array(err))
np.savez(args["--output"], err=err)
y = (1 + np.arange(len(err))) / len(loader) / n
plt.plot(err, y, label="Conic")
print(" | ".join([f"{AA(err, y, th):.3f}" for th in [0.2, 0.5, 1.0, 2.0, 4.0]]))
plt.legend()
plt.show()

if not args["--noimshow"]:
plt.plot(err, y, label="Conic")
print(" | ".join([f"{AA(err, y, th):.3f}" for th in [0.5, 1, 2, 5, 10, 20]]))
plt.legend()
plt.show()


def sample_sphere(v, alpha, num_pts):
Expand Down
2 changes: 1 addition & 1 deletion neurvps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def __getitem__(self, idx):
vpts = np.array([[xy[0] / 256 - 1, 1 - xy[1] / 256, C.io.focal_length]])
vpts[0] /= LA.norm(vpts[0])

image = np.rollaxis(image, 2)
image, vpts = augment(image, vpts, idx // len(self.filelist))
image = np.rollaxis(image, 2)
return (torch.tensor(image * 255).float(), {"vpts": torch.tensor(vpts).float()})


Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def main():
datadir = C.io.datadir
kwargs = {
"batch_size": batch_size,
"num_workers": C.io.num_workers,
"num_workers": C.io.num_workers if os.name != "nt" else 0,
"pin_memory": True,
}
if C.io.dataset.upper() == "WIREFRAME":
Expand Down

0 comments on commit d3ce4f0

Please sign in to comment.