
Code Release
JonasSchult committed Jun 15, 2020
1 parent 8d449c2 commit 62a7088
Showing 128 changed files with 12,987 additions and 30 deletions.
140 changes: 110 additions & 30 deletions README.md
@@ -1,41 +1,121 @@
# DualConvMesh-Net: Joint Geodesic and Euclidean Convolutions on 3D Meshes
Created by [Jonas Schult*](https://www.vision.rwth-aachen.de/person/schult), [Francis Engelmann*](https://www.vision.rwth-aachen.de/person/14/), [Theodora Kontogianni](https://www.vision.rwth-aachen.de/person/15/) and [Bastian Leibe](https://www.vision.rwth-aachen.de/person/1/) from RWTH Aachen University.

![prediction example](doc/teaser.png)

## Introduction
This work is based on our paper
[DualConvMesh-Net: Joint Geodesic and Euclidean Convolutions on 3D Meshes](https://arxiv.org/abs/2004.01002),
which appeared at the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2020.

You can also check our [project page](https://visualcomputinginstitute.github.io/dcm-net/) for further details.

We propose DualConvMesh-Nets (DCM-Net), a family of deep hierarchical convolutional networks over 3D geometric data that combines two types of convolutions. The first type, *geodesic convolutions*, defines the kernel weights over mesh surfaces or graphs. That is, the convolutional kernel weights are mapped to the local surface of a given mesh. The second type, *Euclidean convolutions*, is independent of any underlying mesh structure. The convolutional kernel is applied on a neighborhood obtained from a local affinity representation based on the Euclidean distance between 3D points. Intuitively, geodesic convolutions can easily separate objects that are spatially close but have disconnected surfaces, while Euclidean convolutions can represent interactions between nearby objects better, as they are oblivious to object surfaces. To realize a multi-resolution architecture, we borrow well-established mesh simplification methods from the geometry processing domain and adapt them to define mesh-preserving pooling and unpooling operations. We experimentally show that combining both types of convolutions in our architecture leads to significant performance gains for 3D semantic segmentation, and we report competitive results on three scene segmentation benchmarks.
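
As a rough illustration of the two neighborhood types (a minimal sketch, not code from this repository; `knn_graph` from `torch_cluster` stands in for the Euclidean support, and the mesh edge list for the geodesic one):

    import torch
    from torch_cluster import knn_graph

    # Toy vertex positions (N x 3).
    pos = torch.rand(100, 3)

    # Euclidean support: neighbors by plain 3D distance, oblivious to surfaces.
    euclidean_edge_index = knn_graph(pos, k=16)  # shape [2, num_edges]

    # Geodesic support: the connectivity of the (simplified) mesh itself;
    # here a stand-in edge list, in practice taken from the mesh hierarchy.
    geodesic_edge_index = torch.tensor([[0, 1, 1, 2],
                                        [1, 0, 2, 1]])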

*In this repository, we release code for training and testing DualConvMesh-Nets on arbitrary datasets.*

## Citation
If you find our work useful in your research, please consider citing us:

    @inproceedings{Schult20CVPR,
      author    = {Jonas Schult* and
                   Francis Engelmann* and
                   Theodora Kontogianni and
                   Bastian Leibe},
      title     = {{DualConvMesh-Net: Joint Geodesic and Euclidean Convolutions on 3D Meshes}},
      booktitle = {{IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}},
      year      = {2020}
    }


## Installation
Our code requires **CUDA 10.0** to run correctly. Please make sure that your `$PATH`, `$CPATH` and `$LD_LIBRARY_PATH` environment variables point to the right CUDA version.

    conda deactivate
    conda create -y -n dualmesh python=3.7
    conda activate dualmesh

    conda install -y -c open3d-admin open3d=0.6.0.0
    conda install -y pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=10.0 -c pytorch
    conda install -y -c conda-forge tensorboardx=1.7
    conda install -y -c conda-forge tqdm=4.31.1
    conda install -y -c omnia termcolor=1.1.0

    # Execute the pip installs one after the other
    pip install --no-cache-dir torch-scatter==1.3.1
    pip install --no-cache-dir torch_cluster==1.4.3
    pip install --no-cache-dir torch_sparse==0.4.0
    pip install "pillow<7"  # works around a torchvision bug

### Batching of hierarchical meshes (PyTorch Geometric Fork)
We created a fork of [PyTorch Geometric](https://github.com/JonasSchult/pytorch_geometric_fork) in order to support hierarchical mesh structures interlinked with pooling trace maps.
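The core idea can be sketched as follows (a hypothetical example, not the fork's API): a pooling trace map is an index tensor assigning every fine-level vertex to the coarse-level vertex it is merged into, so pooling reduces to a scatter operation:

    import torch
    from torch_scatter import scatter_mean

    # Hypothetical sizes: 6 fine-level vertices with 4-dim features
    # are merged into 3 coarse-level vertices.
    x_fine = torch.rand(6, 4)
    trace = torch.tensor([0, 0, 1, 1, 2, 2])  # trace map: fine -> coarse vertex

    # Mean-pooling along the trace map (scatter_max would yield max-pooling).
    x_coarse = scatter_mean(x_fine, trace, dim=0)  # shape [3, 4]

Unpooling can then be realized in the opposite direction by indexing the coarse features with the same trace map (`x_coarse[trace]`).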

    git clone https://github.com/JonasSchult/pytorch_geometric_fork.git
    cd pytorch_geometric_fork
    pip install --no-cache-dir .

### Mesh Simplification Preprocessing (VCGlib)
We adapted [VCGlib](https://github.com/JonasSchult/vcglib) to generate pooling trace maps for vertex clustering and quadric error metrics.

    git clone https://github.com/JonasSchult/vcglib.git

    # QUADRIC ERROR METRICS
    cd vcglib/apps/tridecimator/
    qmake
    make

    # VERTEX CLUSTERING
    cd ../sample/trimesh_clustering
    qmake
    make

Add `vcglib/apps/tridecimator` and `vcglib/apps/sample/trimesh_clustering` to your `PATH` environment variable!
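For example (adjust `/path/to` to the location of your vcglib clone):

    export PATH="/path/to/vcglib/apps/tridecimator:/path/to/vcglib/apps/sample/trimesh_clustering:$PATH"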

## Preparation

### Prepare the dataset
Please refer to https://github.com/ScanNet/ScanNet and https://github.com/niessner/Matterport to get access to the ScanNet and Matterport datasets. Our method relies on the `.ply` as well as the `.labels.ply` files.
We train on crops and evaluate on full rooms.
After inserting the paths to the dataset and deciding on the parameters, execute the scripts in `utils/preprocess/scripts/{scannet, matterport}/rooms` and *subsequently* in `utils/preprocess/scripts/{scannet, matterport}/crops` to generate mesh hierarchies on rooms and crop areas for training.
Please note that the scripts are developed for a SLURM batch system. If your lab does not use SLURM, please consider adapting the scripts for your purposes.
More information about the parameters is provided in the corresponding scripts in `utils/preprocess`.

### Symbolic Links pointing to the dataset
Create symlinks to the dataset such that our framework can find it.
For example:

    ln -s /path/to/scannet/rooms/ data/scannet/scannet_qem_rooms

Alternatively, you can also directly set the paths in the corresponding experiment files.

### Model Checkpoints
We provide the [model checkpoints](https://omnomnom.vision.rwth-aachen.de/data/dcm_net_checkpoints/dcm_net_checkpoints.zip) on our server.

## Training
An example training script is given in `example_scripts/train_scannet.sh`.

Start a new training:

    python train_wrapper.py \
        -c PATH_TO_EXPERIMENTS_FILE.json

Resume a training:

    python train_wrapper.py \
        -c PATH_TO_EXPERIMENTS_FILE.json \
        -r PATH_TO_CHECKPOINT.pth

## Inference
An example inference script is given in `example_scripts/inference_scannet.sh`.

Reproduce the scores of our paper:

    python run.py \
        -c experiments/EXPERIMENT_NAME.json \
        -r paper_checkpoints/EXPERIMENT_NAME.pth \
        -e

## Visualization
An example visualization script is given in `example_scripts/visualize_scannet.sh`.
We show qualitative results on the ScanNet validation set.
Please note that a symlink to the ScanNet mesh folder has to be placed in `data/scannet/scans`.
The visualization tool is based on [open3D](http://www.open3d.org/) and handles the following key events:
* h = RGB
* j = prediction
* k = ground truth
* f = color-coded positive/negative predictions
* l = local lighting on/off
* s = smoothing mesh on/off
* b = back-face culling on/off
* d = save current meshes as .ply in the `visualizations/` folder (useful if you plan to make some decent renderings with Blender later on :) )
* q = quit and show next room

Use your mouse to navigate the mesh.

## ToDo's
- Preprocessing code for S3DIS data set

## Acknowledgements
This project is based on the [PyTorch-Template](https://github.com/victoresque/pytorch-template) by [@victoresque](https://github.com/victoresque).
3 changes: 3 additions & 0 deletions base/__init__.py
@@ -0,0 +1,3 @@
from .base_model import *
from .base_trainer import *
from .base_dataset import *
189 changes: 189 additions & 0 deletions base/base_dataset.py
@@ -0,0 +1,189 @@
import torch
import logging
import gc
from abc import ABCMeta, abstractmethod
from typing import Callable, List
from torch.utils.data import Dataset
from torch_geometric.data import Data
import random


class BaseDataSet(Dataset):

    def __init__(self, root_dir: str,
                 start_level: int,
                 end_level: int,
                 get_coords: bool = False,
                 benchmark: bool = False,
                 is_train: bool = True,
                 debug_mode: bool = False,
                 transform: List[Callable] = None,
                 original_meshes_dir: str = None,
                 sample_checker: List[Callable] = None,
                 include_edges: bool = True):
        """Base class for data set objects.

        Arguments:
            root_dir {str} -- Root path to folder containing train/test samples
            start_level {int} -- First hierarchy level of mesh
            end_level {int} -- Final hierarchy level of mesh

        Keyword Arguments:
            get_coords {bool} -- load coordinates of vertices as well (default: {False})
            benchmark {bool} -- benchmark mode (e.g. ScanNet test set) (default: {False})
            is_train {bool} -- train or validation mode? (default: {True})
            debug_mode {bool} -- also load vertex positions and colors of later levels (default: {False})
            transform {List[Callable]} -- list of callable objects to transform data samples (data augmentation) (default: {None})
            original_meshes_dir {str} -- root path of folder containing the original data set meshes (visualization and evaluation) (default: {None})
            sample_checker {List[Callable]} -- list of callable objects to reject samples of low quality (e.g. too many unlabeled vertices) (default: {None})
            include_edges {bool} -- if False, only use point cloud information (default: {True})
        """

        self._root_dir = root_dir
        self._transform = transform
        self._start_level = start_level
        self._end_level = end_level
        self._debug_mode = debug_mode
        self._is_train = is_train
        self._original_meshes_dir = original_meshes_dir
        self._sample_checker = sample_checker
        self._benchmark = benchmark
        self._include_edges = include_edges

        self._get_coords = get_coords

        if self._sample_checker is None:
            self._sample_checker = [lambda x: True]

        self.index2filenames = self._load(self._is_train, self._benchmark)

        self.logger = logging.getLogger(self.__class__.__name__)

    def _load(self, is_train: bool, benchmark: bool) -> List[str]:
        """Subclasses implement function which returns file names of train/val samples"""
        raise NotImplementedError("")

    def __getitem__(self, index: int) -> Data:
        """Load one sample and assemble its mesh hierarchy.

        Arguments:
            index {int} -- index into the list of sample file names

        Raises:
            RuntimeError: if a sample checker rejects the sample

        Returns:
            Data -- graph of the finest level with the edges and trace maps
                of the coarser hierarchy levels attached as attributes
        """
        sample = None

        try:
            name = self.index2filenames[index]

            file_path = f"{self._root_dir}/{name}"

            saved_tensors = torch.load(file_path)

            coords = saved_tensors['vertices'][:self._end_level]

            if not self._benchmark:
                labels = saved_tensors['labels']
            else:
                labels = None

            edges = saved_tensors['edges'][:self._end_level]

            if self._is_train:
                traces = saved_tensors['traces'][:self._end_level - 1]
            else:
                trace_0 = saved_tensors['traces'][0]
                traces = saved_tensors['traces'][1:self._end_level]

            # graph of the finest hierarchy level
            sample = Data(x=coords[0][:, 3:],
                          pos=coords[0][:, :3],
                          edge_index=edges[0].t().contiguous()
                          if self._include_edges else None,
                          y=labels)
            sample.name = name

            nested_meshes = []

            for level in range(1, len(edges)):
                data = Data(edge_index=edges[level].t().contiguous()
                            if self._include_edges else None)

                data.trace_index = traces[level - 1]

                if self._debug_mode:
                    data.x = coords[level][:, 3:]
                    data.pos = coords[level][:, :3]

                if self._get_coords:
                    data.pos = coords[level][:, :3]

                nested_meshes.append(data)

            # attach coarser levels as attributes of the finest-level graph
            sample.num_vertices = []
            for level, nested_mesh in enumerate(nested_meshes):
                setattr(
                    sample, f"hierarchy_edge_index_{level+1}", nested_mesh.edge_index)
                setattr(
                    sample, f"hierarchy_trace_index_{level+1}", nested_mesh.trace_index)

                sample.num_vertices.append(
                    int(sample[f"hierarchy_trace_index_{level+1}"].max() + 1))

                if self._get_coords:
                    setattr(sample, f"pos_{level + 1}", nested_mesh.pos)

            if not self._is_train:
                sample.original_index_traces = trace_0

            if self._debug_mode:
                sample.all_levels = [sample]
                sample.all_levels.extend(nested_meshes)

            if self._transform:
                sample = self._transform(sample)

            for checker in self._sample_checker:
                if not checker(sample):
                    raise RuntimeError(
                        f"{checker.__class__.__name__} rejected the sample")

            return sample
        except Exception as e:
            if sample is not None:
                del sample
                gc.collect()
            self.logger.warning(
                f"Warning: Training example {index} could not be processed")
            self.logger.warning(f"{str(e)}")
            # fall back to a randomly chosen sample
            index = random.randrange(len(self))
            return self.__getitem__(index)

    def __len__(self):
        return len(self.index2filenames)

    @property
    @abstractmethod
    def color_map(self):
        """Dataset has to declare to which color each class is mapped for visualization purposes"""
        raise NotImplementedError("")

    @property
    def num_classes(self):
        return len(self.color_map)

    @property
    @abstractmethod
    def ignore_classes(self) -> int:
        """There exist unlabeled vertices. Specify their ID to ignore them in the calculation of loss and metrics."""
        raise NotImplementedError("")

    # color map for the positive/negative prediction visualization (gray, green, red)
    pos_neg_map = torch.FloatTensor(
        [
            [200, 200, 200],
            [0, 255, 0],
            [255, 0, 0]])
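
For illustration only (this subclass is hypothetical and not part of the commit), a concrete dataset mainly has to supply the sample file names and the class/color metadata:

    import torch
    from base import BaseDataSet

    class ToyDataSet(BaseDataSet):
        """Hypothetical example subclass; file names and colors are made up."""

        def _load(self, is_train, benchmark):
            # File names of the preprocessed mesh hierarchy samples.
            return ["crop_000.pt", "crop_001.pt"] if is_train else ["room_000.pt"]

        @property
        def color_map(self):
            # One RGB color per class, used by the visualization tool.
            return torch.FloatTensor([[0, 0, 255], [255, 0, 0]])

        @property
        def ignore_classes(self):
            return -1  # label ID of unlabeled vertices

    # usage: dataset = ToyDataSet(root_dir="data/toy", start_level=0, end_level=4)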
38 changes: 38 additions & 0 deletions base/base_model.py
@@ -0,0 +1,38 @@
import logging
import torch.nn as nn
import numpy as np


class BaseModel(nn.Module):
    """
    Base class for all models
    """

    def __init__(self):
        super(BaseModel, self).__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self, *input):
        """
        Forward pass logic

        :return: Model output
        """
        raise NotImplementedError

    def summary(self):
        """
        Model summary
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        self.logger.info('Trainable parameters: {}'.format(params))
        self.logger.info(self)

    def __str__(self):
        """
        Model prints with number of trainable parameters
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return super(BaseModel, self).__str__() + '\nTrainable parameters: {}'.format(params)
        # print(super(BaseModel, self))