Initial commit
zhou13 committed Oct 14, 2019
0 parents commit cdd3f2a
Showing 33 changed files with 15,226 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .gitignore
@@ -0,0 +1,8 @@
/neurvps/models/cpp/build/
/data
/logs
.autoenv*
.idea/
*.pyc
/last_err.npz
/error.npz
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Yichao Zhou

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
133 changes: 133 additions & 0 deletions README.md
@@ -0,0 +1,133 @@
# NeurVPS: Neural Vanishing Point Scanning via Conic Convolution

This repository contains the official PyTorch implementation of the paper: *[Yichao Zhou](https://yichaozhou.com), [Haozhi Qi](http://haozhi.io), [Jingwei Huang](http://stanford.edu/~jingweih/), [Yi Ma](https://people.eecs.berkeley.edu/~yima/). ["NeurVPS: Neural Vanishing Point Scanning via Conic Convolution"](https://arXiv/Link)*.

## Introduction

NeurVPS is an end-to-end trainable deep network with *geometry-inspired* convolutional operators for detecting vanishing points in images. With the power of data-driven approaches and geometrical priors, NeurVPS is able to outperform the previous state-of-the-art vanishing point detection methods such as [LSD/J-Linkage](https://github.com/simbaforrest/vpdetection) and [Contour (TMM17)](https://github.com/heiheiknight/vpdet_tmm17).

## Main Results

### Qualitative Measures

| [SceneCity Urban 3D (SU3)](https://arxiv.org/abs/1905.07482) | [Natural Scene (TMM17)](https://faculty.ist.psu.edu/zzhou/projects/vpdetection/) | [ScanNet](http://www.scan-net.org/) |
| ------------------------------------------------------------ | ------------------------------------------------------------ | ----------------------------------- |
| ![blend](figs/su3.png) | ![tmm17](figs/tmm17.png) | ![scannet](figs/scannet.png) |

### Quantitative Measures

| [SceneCity Urban 3D (SU3)](https://arxiv.org/abs/1905.07482) | [Natural Scene (TMM17)](https://faculty.ist.psu.edu/zzhou/projects/vpdetection/) | [ScanNet](http://www.scan-net.org/) |
| ------------------------------------------------------------ | ------------------------------------------------------------ | -------------------------------------- |
| ![su3_AA6](figs/su3_AA6.svg) | ![tmm17_AA20](figs/tmm17_AA20.svg) | ![scannet_AA20](figs/scannet_AA20.svg) |

Here, the x-axis represents the angular error of the detected vanishing points and the y-axis represents the percentage of results whose error is less than that value. Our conic convolutional networks outperform all the baseline methods and previous state-of-the-art vanishing point detection approaches, while naive CNN implementations might underperform those traditional methods, especially in the high-accuracy regions.
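
As a rough illustration of how such a curve can be computed from per-image angular errors, here is a minimal sketch. It is not the code in `eval.py`: the function names are hypothetical, and summarizing the curve by its normalized area up to a maximum angle is an assumption made for this example.

```python
from bisect import bisect_left


def accuracy_curve(errors_deg, thresholds_deg):
    """Fraction of predictions whose angular error is strictly below each threshold."""
    errs = sorted(errors_deg)
    # bisect_left returns the number of sorted errors that are < t
    return [bisect_left(errs, t) / len(errs) for t in thresholds_deg]


def angular_accuracy(errors_deg, max_deg):
    """Area under the accuracy curve on [0, max_deg], divided by max_deg.

    The curve is a step function, so each error e contributes
    max(0, max_deg - e) to the area before normalization.
    """
    area = sum(max(0.0, max_deg - e) for e in errors_deg) / len(errors_deg)
    return area / max_deg


errors = [0.5, 1.0, 3.0, 8.0]  # made-up angular errors in degrees
curve = accuracy_curve(errors, [1, 2, 6])
aa6 = angular_accuracy(errors, 6)
```

A higher value means that more predictions fall within small angular errors, which is the "high-accuracy region" on the left of each plot.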

### Code Structure

Below is a quick overview of the function of each file.

```bash
########################### Data ###########################
data/                           # default folder for placing the data
    su3/                        # folder for SU3 dataset
    tmm17/                      # folder for TMM17 dataset
    scannet-vp/                 # folder for ScanNet dataset
logs/                           # default folder for storing the output during training
########################### Code ###########################
config/                         # neural network hyper-parameters and configurations
    su3.yaml                    # default parameters for SU3 dataset
    tmm17.yaml                  # default parameters for TMM17 dataset
    scannet.yaml                # default parameters for scannet dataset
dataset/                        # all scripts related to data generation
    su3.py                      # script for pre-processing the SU3 dataset to npz
misc/                           # misc scripts that are not important
    find-radius.py              # script for generating figure grids
neurvps/                        # neurvps module so you can "import neurvps" in other scripts
    models/                     # neural network architectures
        cpp/                    # CUDA kernel for deformable convolution
        deformable.py           # python wrapper for deformable convolution layers
        conic.py                # conic convolution layers
        hourglass_pose.py       # backbone network
        vanishing_net.py        # main network
    datasets.py                 # reading the training data
    trainer.py                  # trainer
    config.py                   # global variables for configuration
    utils.py                    # misc functions
train.py                        # script for training the neural network
eval.py                         # script for evaluating a dataset from a checkpoint
```

## Reproducing Results

### Installation

For ease of reproducibility, we suggest installing [miniconda](https://docs.conda.io/en/latest/miniconda.html) (or [anaconda](https://www.anaconda.com/distribution/) if you prefer) before executing the following commands.

```bash
git clone https://github.com/zhou13/neurvps
cd neurvps
conda create -y -n neurvps
source activate neurvps
# Replace cudatoolkit=10.1 with your CUDA version: https://pytorch.org/get-started/
conda install -y pytorch cudatoolkit=10.1 -c pytorch
conda install -y tensorboardx -c conda-forge
conda install -y pyyaml docopt matplotlib scikit-image opencv tqdm
mkdir data logs
```

### Downloading the Processed Datasets
Make sure `curl` is installed on your system and execute
```bash
cd data
../misc/gdrive-download.sh 1yRwLv28ozRvjsf9wGwAqzya1xFZ5wYET su3.tar.xz
../misc/gdrive-download.sh 1rpQNbZQEUff2j2rxr3mBl6xohGFl6sLv tmm17.tar.xz
../misc/gdrive-download.sh 1y_O9PxZhJ_Ml297FgoWMBLvjC1BvTs9A scannet.tar.xz
tar xf su3.tar.xz
tar xf tmm17.tar.xz
tar xf scannet.tar.xz
rm *.tar.xz
cd ..
```

If `gdrive-download.sh` does not work for you, you can download the pre-processed datasets
manually from our [Google
Drive](https://drive.google.com/drive/folders/1xBcHj584zGxhMboZNJHWlAe_XIbHfC34) and proceed
accordingly.
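
Whichever way the archives are obtained, it can help to confirm that the extracted folders are where the configs expect them. The following is a small sketch of such a check. The helper `missing_datasets` is hypothetical and not part of this repository; the folder names are taken from the code structure overview above.

```python
from pathlib import Path

# Folder names as listed under data/ in the code structure overview.
EXPECTED_DATASETS = ("su3", "tmm17", "scannet-vp")


def missing_datasets(data_root="data"):
    """Return the expected dataset folders that do not exist under data_root."""
    root = Path(data_root)
    return [name for name in EXPECTED_DATASETS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_datasets()
    if missing:
        print("Missing dataset folders:", ", ".join(missing))
    else:
        print("All dataset folders found.")
```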


### Training
Execute the following commands to train the neural networks from scratch on 2 GPUs (GPU 0 and GPU 1, specified by `-d 0,1`) with the default parameters:
```bash
python ./train.py -d 0,1 --identifier su3 config/su3.yaml
python ./train.py -d 0,1 --identifier tmm17 config/tmm17.yaml
python ./train.py -d 0,1 --identifier scannet config/scannet.yaml
```

The checkpoints and logs will be written to `logs/` accordingly. It has been reported that it is possible to achieve higher performance with 4-GPU training, though the training process is more volatile.

### Pre-trained Models

You can download our reference pre-trained models from [Google
Drive](https://drive.google.com/drive/folders/1srniSE2JD6ptAwc_QRnpl7uQnB5jLNIZ). Those pretrained
models should be able to reproduce the numbers in our paper.

### Evaluation

Execute the following commands to compute and plot the angular accuracy (AA) curves with trained network checkpoints:

```bash
python eval.py -d 0 logs/YOUR_LOG/config.yaml logs/YOUR_LOG/checkpoint_best.pth.tar
```

### Citing NeurVPS

If you find NeurVPS useful in your research, please consider citing:

```
@inproceedings{zhou2019end,
author={Zhou, Yichao and Qi, Haozhi and Huang, Jingwei and Ma, Yi},
title={NeurVPS: Neural Vanishing Point Scanning via Conic Convolution},
booktitle={NeurIPS},
year={2019}
}
```
44 changes: 44 additions & 0 deletions config/scannet.yaml
@@ -0,0 +1,44 @@
io:
  logdir: logs/
  datadir: data/scannet-vp
  dataset: ScanNet
  resume_from:
  num_workers: 6
  tensorboard_port: 0
  validation_interval: 24000
  validation_debug: 240
  focal_length: 2.408333333333333
  num_vpts: 3
  augmentation_level: 2

model:
  batch_size: 8
  im2col_step: 21
  backbone: stacked_hourglass
  depth: 4
  num_stacks: 1
  num_blocks: 1

  fc_channel: 1024

  # reg2classfication
  smp_pos: 1
  smp_neg: 1
  smp_rnd: 3
  smp_multiplier: 2
  multires:
    - 0.0200483803479500
    - 0.0774278195486317
    - 0.2995648108645650
  output_stride: 4
  upsample_scale: 1

  conic_6x: True

optim:
  name: Adam
  lr: 1.0e-4
  amsgrad: True
  weight_decay: 1.0e-5
  max_epoch: 36
  lr_decay_epoch: 3
46 changes: 46 additions & 0 deletions config/su3.yaml
@@ -0,0 +1,46 @@
io:
  logdir: logs/
  datadir: data/su3/
  dataset: Wireframe
  resume_from:
  num_workers: 6
  tensorboard_port: 0
  validation_interval: 24000
  validation_debug: 120
  focal_length: 2.1875
  num_vpts: 3
  augmentation_level: 2

model:
  batch_size: 6
  im2col_step: 11
  backbone: stacked_hourglass
  depth: 4
  num_stacks: 1
  num_blocks: 1

  fc_channel: 1024

  # reg2classfication
  smp_pos: 1
  smp_neg: 1
  smp_rnd: 3
  smp_multiplier: 2
  multires:
    - 0.0013457768043554
    - 0.0051941870036646
    - 0.0200483803479500
    - 0.0774278195486317
    - 0.2995648108645650
  output_stride: 4
  upsample_scale: 1

  conic_6x: False

optim:
  name: Adam
  lr: 1.0e-4
  amsgrad: True
  weight_decay: 1.0e-5
  max_epoch: 36
  lr_decay_epoch: 24
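
How `lr_decay_epoch` is consumed is not visible in this part of the diff. One plausible reading, stated here purely as an assumption, is a single step decay: the learning rate stays at `lr` until `lr_decay_epoch` and is then multiplied by a fixed factor until `max_epoch`. The sketch below illustrates that reading with the SU3 values; the function name and the `decay_factor` of 0.1 are hypothetical.

```python
def learning_rate(epoch, base_lr=1.0e-4, lr_decay_epoch=24, decay_factor=0.1):
    """Step schedule: base_lr before lr_decay_epoch, base_lr * decay_factor after.

    decay_factor=0.1 is an assumed value, not something read from the config.
    """
    return base_lr * decay_factor if epoch >= lr_decay_epoch else base_lr


# Learning rate for each of the 36 epochs in config/su3.yaml under this assumption.
schedule = [learning_rate(epoch) for epoch in range(36)]
```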
45 changes: 45 additions & 0 deletions config/tmm17.yaml
@@ -0,0 +1,45 @@
io:
  logdir: logs/
  datadir: data/tmm17/
  dataset: TMM17
  resume_from:
  num_workers: 4
  tensorboard_port: 0
  validation_interval: 8000
  validation_debug: 160
  focal_length: 1.0
  num_vpts: 1
  augmentation_level: 2

model:
  batch_size: 8
  im2col_step: 11
  backbone: stacked_hourglass
  depth: 4
  num_stacks: 1
  num_blocks: 1

  fc_channel: 1024

  # reg2classfication
  smp_pos: 1
  smp_neg: 1
  smp_rnd: 3
  smp_multiplier: 2
  multires:
    - 0.0051941870036646
    - 0.0200483803479500
    - 0.0774278195486317
    - 0.2995648108645650
  output_stride: 4
  upsample_scale: 1

  conic_6x: False

optim:
  name: Adam
  lr: 3.0e-4
  amsgrad: True
  weight_decay: 3.0e-4
  max_epoch: 50
  lr_decay_epoch: 30
60 changes: 60 additions & 0 deletions dataset/su3.py
@@ -0,0 +1,60 @@
#!/usr/bin/env python
"""Preprocess the SU3 dataset for NeurVPS
Usage:
    dataset/su3.py <dir>
    dataset/su3.py (-h | --help )
Arguments:
    <dir>                      Target directory
Options:
    -h --help                  Show this screen.
"""
import os
import sys
import json
from glob import glob

import numpy as np
import numpy.linalg as LA
import matplotlib.pyplot as plt
from docopt import docopt
from skimage import io

try:
    sys.path.append(".")
    sys.path.append("..")
    from neurvps.utils import parmap
except Exception:
    raise


def handle(iname):
    prefix = iname.replace(".png", "")
    with open(f"{prefix}_camera.json") as f:
        js = json.load(f)
        RT = np.array(js["modelview_matrix"])

    vpts = []
    # plt.imshow(io.imread(iname))
    for axis in [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]:
        vp = RT @ axis
        vp = np.array([vp[0], vp[1], -vp[2]])
        vp /= LA.norm(vp)
        vpts.append(vp)
        # plt.scatter(
        #     vpt[0] / vpt[2] * 2.1875 * 256 + 256,
        #     -vpt[1] / vpt[2] * 2.1875 * 256 + 256
        # )
    # plt.show()
    np.savez_compressed(f"{prefix}_label.npz", vpts=np.array(vpts))


def main():
    args = docopt(__doc__)
    filelist = sorted(glob(args["<dir>"] + "/*/????_0.png"))
    parmap(handle, filelist)


if __name__ == "__main__":
    main()
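
The commented-out `plt.scatter` call in `handle` hints at how a stored vanishing direction maps back to pixel coordinates. The sketch below spells that out in plain Python, without NumPy so that it stays self-contained. The function names are hypothetical, and the constants 2.1875 and 256 are taken from that comment; treating 256 as half the image size is an assumption.

```python
import math


def vanishing_points(modelview):
    """Unit direction of each world axis in the camera frame, z negated as in handle().

    modelview is a 4x4 nested list. Multiplying it by a unit axis vector
    such as [1, 0, 0, 0] simply selects the corresponding column.
    """
    vpts = []
    for col in range(3):
        x, y, z = modelview[0][col], modelview[1][col], -modelview[2][col]
        norm = math.sqrt(x * x + y * y + z * z)
        vpts.append((x / norm, y / norm, z / norm))
    return vpts


def to_pixel(vp, focal_length=2.1875, half_size=256):
    """Project a direction to pixel coordinates, or None if it lies at infinity."""
    x, y, z = vp
    if abs(z) < 1e-12:
        return None  # direction parallel to the image plane: no finite pixel
    u = x / z * focal_length * half_size + half_size
    v = -y / z * focal_length * half_size + half_size
    return (u, v)


# Example: camera rotated 45 degrees about the y-axis (made-up matrix).
c = s = math.sqrt(0.5)
RT = [[c, 0, s, 0], [0, 1, 0, 0], [-s, 0, c, 0], [0, 0, 0, 1]]
pixels = [to_pixel(vp) for vp in vanishing_points(RT)]
```

In this example the y-axis direction is parallel to the image plane, so its vanishing point has no finite pixel location and `to_pixel` returns `None` for it.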