-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8d449c2
commit 62a7088
Showing
128 changed files
with
12,987 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,121 @@ | ||
# Dual Convolution Mesh Network (DCM Net) | ||
# DualConvMesh-Net: Joint Geodesic and Euclidean Convolutions on 3D Meshes | ||
Created by [Jonas Schult*](https://www.vision.rwth-aachen.de/person/schult), [Francis Engelmann*](https://www.vision.rwth-aachen.de/person/14/), [Theodora Kontogianni](https://www.vision.rwth-aachen.de/person/15/) and [Bastian Leibe](https://www.vision.rwth-aachen.de/person/1/) from RWTH Aachen University. | ||
|
||
data:image/s3,"s3://crabby-images/0d6bb/0d6bb371302f3148013284575e770cf77dc07cec" alt="prediction example" | ||
|
||
## Coming soon... | ||
Please *stay tuned*; we are currently working hard to get the code out quickly. | ||
## Introduction | ||
This work is based on our paper | ||
[DualConvMesh-Net: Joint Geodesic and Euclidean Convolutions on 3D Meshes](https://arxiv.org/abs/2004.01002), | ||
which appeared at the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2020. | ||
|
||
You can also check our [project page](https://visualcomputinginstitute.github.io/dcm-net/) for further details. | ||
|
||
We propose DualConvMesh-Nets (DCM-Net) a family of deep hierarchical convolutional networks over 3D geometric data that combines two types of convolutions. The first type, *geodesic convolutions*, defines the kernel weights over mesh surfaces or graphs. That is, the convolutional kernel weights are mapped to the local surface of a given mesh. The second type, *Euclidean convolutions*, is independent of any underlying mesh structure. The convolutional kernel is applied on a neighborhood obtained from a local affinity representation based on the Euclidean distance between 3D points. Intuitively, geodesic convolutions can easily separate objects that are spatially close but have disconnected surfaces, while Euclidean convolutions can represent interactions between nearby objects better, as they are oblivious to object surfaces. To realize a multi-resolution architecture, we borrow well-established mesh simplification methods from the geometry processing domain and adapt them to define mesh-preserving pooling and unpooling operations. We experimentally show that combining both types of convolutions in our architecture leads to significant performance gains for 3D semantic segmentation, and we report competitive results on three scene segmentation benchmarks. | ||
|
||
*In this repository, we release code for training and testing DualConvMesh-Nets on arbitrary datasets.* | ||
|
||
## Citation | ||
If you find our work useful in your research, please consider citing us: | ||
|
||
@inproceedings{Schult20CVPR, | ||
author = {Jonas Schult* and | ||
Francis Engelmann* and | ||
Theodora Kontogianni and | ||
Bastian Leibe}, | ||
title = {{DualConvMesh-Net: Joint Geodesic and Euclidean Convolutions on 3D Meshes}}, | ||
booktitle = {{IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}}, | ||
year = {2020} | ||
} | ||
|
||
|
||
## Installation | ||
Our code requires **CUDA 10.0** for running correctly. Please make sure that your `$PATH`, `$CPATH` and `$LD_LIBRARBY_PATH` environment variables point to the right CUDA version. | ||
|
||
conda deactivate | ||
conda create -y -n dualmesh python=3.7 | ||
conda activate dualmesh | ||
|
||
conda install -y -c open3d-admin open3d=0.6.0.0 | ||
conda install -y pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=10.0 -c pytorch | ||
conda install -y -c conda-forge tensorboardx=1.7 | ||
conda install -y -c conda-forge tqdm=4.31.1 | ||
conda install -y -c omnia termcolor=1.1.0 | ||
|
||
# Execute pip installs one after each other | ||
pip install --no-cache-dir torch-scatter==1.3.1 | ||
pip install --no-cache-dir torch_cluster==1.4.3 | ||
pip install --no-cache-dir torch_sparse==0.4.0 | ||
pip install "pillow<7" # torchvision bug | ||
|
||
### Batching of hierarchical meshes (PyTorch Geometric Fork) | ||
We created a fork of [PyTorch Geometric](https://github.com/JonasSchult/pytorch_geometric_fork) in order to support hierarchical mesh structures interlinked with pooling trace maps. | ||
|
||
git clone https://github.com/JonasSchult/pytorch_geometric_fork.git | ||
cd pytorch-geometric-fork | ||
pip install --no-cache-dir . | ||
|
||
### Mesh Simplification Preprocessing (VCGlib) | ||
We adapted [VCGlib](https://github.com/JonasSchult/vcglib) to generate pooling trace maps for vertex clustering and quadric error metrics. | ||
|
||
git clone https://github.com/JonasSchult/vcglib.git | ||
|
||
# QUADRIC ERROR METRICS | ||
cd vcglib/apps/tridecimator/ | ||
qmake | ||
make | ||
|
||
# VERTEX CLUSTERING | ||
cd ../sample/trimesh_clustering | ||
qmake | ||
make | ||
|
||
Add `vcglib/apps/tridecimator` and `vcglib/apps/sample/trimesh_clustering` to your environment path variable! | ||
|
||
## Preparation | ||
|
||
### Prepare the dataset | ||
Please refer to https://github.com/ScanNet/ScanNet and https://github.com/niessner/Matterport to get access to the ScanNet and Matterport dataset. Our method relies on the .ply as well as the .labels.ply files. | ||
We train on crops and we evaluate on full rooms. | ||
After inserting the paths to the dataset and deciding on the parameters, execute the scripts in `utils/preprocess/scripts/{scannet, matterport}/rooms` and *subsequently* in `utils/preprocess/scripts/scannet, matterport}/crops` to generate mesh hierarchies on rooms and crop areas for training. | ||
Please note that the scripts are developed for a SLURM batch system. If your lab does not use SLURM, please consider adapting the scripts for your purposes. | ||
More information about the parameters are provided in the corresponding scripts in `utils/preprocess`. | ||
|
||
### Symbolic Links pointing to the dataset | ||
Create symlinks to the dataset such that our framework can find it. | ||
For example: | ||
|
||
ln -s /path/to/scannet/rooms/ data/scannet/scannet_qem_rooms | ||
|
||
## Requirements | ||
All of our dependencies can be installed with conda or pip. | ||
* Python 3.7 | ||
* Open3D | ||
* PyTorch 1.1 Cuda 10.0 | ||
* TensorboardX (Tensorflow and Tensorboard are unfortunately also needed to install this) | ||
* Our fork of PyTorch Geometric (with its accompanying libraries as torch_scatter, torch_cluster, torch_sparse) | ||
* tqdm | ||
Alternatively, you can also directly set the paths in the corresponding experiment files. | ||
|
||
Since we adapted PyTorch Geometric to enable graph level support, you need to install our fork as follows: | ||
|
||
cd pytorch_geometric | ||
python setup.py install | ||
### Model Checkpoints | ||
We provide the [model checkpoints](https://omnomnom.vision.rwth-aachen.de/data/dcm_net_checkpoints/dcm_net_checkpoints.zip) on our server. | ||
|
||
## Preprocessing | ||
Please refer to https://github.com/ScanNet/ScanNet to get access to the ScanNet dataset. Our method relies on the .ply as well as the .labels.ply files. | ||
## Training | ||
An example training script is given in `example_scripts/train_scannet.sh` | ||
|
||
### Start a new training: | ||
## Inference | ||
An example inference script is given in `example_scripts/inference_scannet.sh` | ||
|
||
python train_wrapper.py \ | ||
-c PATH_TO_EXPERIMENTS_FILE.json | ||
|
||
### Resume a training: | ||
## Visualization | ||
An example visualization script is given in `example_scripts/visualize_scannet.sh`. | ||
We show qualitative results on the ScanNet validation set. | ||
Please note that a symlink to the ScanNet mesh folder has to be in placed in `data/scannet/scans`. | ||
The visualization tool is based on [open3D](http://www.open3d.org/) and handles the following key events: | ||
* h = RGB | ||
* j = prediction | ||
* k = ground truth | ||
* f = color-coded positive/negative predictions | ||
* l = local lighting on/off | ||
* s = smoothing mesh on/off | ||
* b = back-face culling on/off | ||
* d = save current meshes as .ply in `visualizations/` folder (useful, if you plan to make some decent rendering with Blender, later on :) ) | ||
* q = quit and show next room | ||
|
||
python train_wrapper.py \ | ||
-c PATH_TO_EXPERIMENTS_FILE.json \ | ||
-r PATH_TO_CHECKPOINT.pth | ||
Use your mouse to navigate in the mesh. | ||
|
||
### Reproduce the scores of our paper: | ||
## ToDo's | ||
- Preprocessing code for S3DIS data set | ||
|
||
python run.py \ | ||
-c experiments/EXPERIMENT_NAME.json \ | ||
-r paper_checkpoints/EXPERIMENT_NAME.pth \ | ||
-e | ||
## Acknowledgements | ||
This project is based on the [PyTorch-Template](https://github.com/victoresque/pytorch-template) by [@victoresque](https://github.com/victoresque). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .base_model import * | ||
from .base_trainer import * | ||
from .base_dataset import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
import torch | ||
import logging | ||
import gc | ||
from abc import ABCMeta, abstractmethod | ||
from typing import Callable, List | ||
from torch.utils.data import Dataset | ||
from torch_geometric.data import Data | ||
import random | ||
|
||
class BaseDataSet(Dataset): | ||
|
||
def __init__(self, root_dir: str, | ||
start_level: int, | ||
end_level: int, | ||
get_coords: bool = False, | ||
benchmark: bool = False, | ||
is_train: bool = True, | ||
debug_mode: bool = False, | ||
transform: List[Callable] = None, | ||
original_meshes_dir: str = None, | ||
sample_checker: List[Callable] = None, | ||
include_edges: bool = True): | ||
"""Base class for data set objects. | ||
Arguments: | ||
root_dir {str} -- Root path to folder containing train/test samples | ||
start_level {int} -- First hierarchy level of mesh | ||
end_level {int} -- Final level of mesh | ||
Keyword Arguments: | ||
get_coords {bool} -- load coordinates of vertices as well (default: {False}) | ||
benchmark {bool} -- benchmark mode (e.g. ScanNet test set) (default: {False}) | ||
is_train {bool} -- train or validation mode? (default: {True}) | ||
debug_mode {bool} -- also load vertex positions and colors of later levels (default: {False}) | ||
transform {List[Callable]} -- list of callable objects to transform data samples (data augmentation) (default: {None}) | ||
original_meshes_dir {str} -- root path of folder containing the original data set meshes (visualization and evaluation) (default: {None}) | ||
sample_checker {List[Callable]} -- list of callable objects to reject samples of low quality (e.g. number of unlabeled vertices) (default: {None}) | ||
include_edges {bool} -- only use point cloud information (default: {True}) | ||
""" | ||
|
||
self._root_dir = root_dir | ||
self._transform = transform | ||
self._start_level = start_level | ||
self._end_level = end_level | ||
self._debug_mode = debug_mode | ||
self._is_train = is_train | ||
self._original_meshes_dir = original_meshes_dir | ||
self._sample_checker = sample_checker | ||
self._benchmark = benchmark | ||
self._include_edges = include_edges | ||
|
||
self._get_coords = get_coords | ||
|
||
if self._sample_checker is None: | ||
self._sample_checker = [lambda x: True] | ||
|
||
self.index2filenames = self._load(self._is_train, self._benchmark) | ||
|
||
self.logger = logging.getLogger(self.__class__.__name__) | ||
|
||
def _load(self, is_train: bool, benchmark: bool) -> List[str]: | ||
"""Subclasses implement function which returns file names of train/val samples | ||
""" | ||
raise NotImplementedError("") | ||
|
||
def __getitem__(self, index: int) -> List[Data]: | ||
"""[summary] | ||
Arguments: | ||
index {int} -- [description] | ||
Raises: | ||
RuntimeError: [description] | ||
Returns: | ||
List[Data] -- [description] | ||
""" | ||
sample = None | ||
|
||
try: | ||
name = self.index2filenames[index] | ||
|
||
file_path = f"{self._root_dir}/{name}" | ||
|
||
saved_tensors = torch.load(file_path) | ||
|
||
coords = saved_tensors['vertices'][:self._end_level] | ||
|
||
if not self._benchmark: | ||
labels = saved_tensors['labels'] | ||
else: | ||
labels = None | ||
|
||
edges = saved_tensors['edges'][:self._end_level] | ||
|
||
if self._is_train: | ||
traces = saved_tensors['traces'][:self._end_level-1] | ||
else: | ||
trace_0 = saved_tensors['traces'][0] | ||
traces = saved_tensors['traces'][1:self._end_level] | ||
|
||
sample = Data(x=coords[0][:, 3:], | ||
pos=coords[0][:, :3], | ||
edge_index=edges[0].t().contiguous( | ||
) if self._include_edges else None, | ||
y=labels) | ||
sample.name = name | ||
|
||
nested_meshes = [] | ||
|
||
for level in range(1, len(edges)): | ||
data = Data(edge_index=edges[level].t( | ||
).contiguous() if self._include_edges else None) | ||
|
||
data.trace_index = traces[level-1] | ||
|
||
if self._debug_mode: | ||
data.x = coords[level][:, 3:] | ||
data.pos = coords[level][:, :3] | ||
|
||
if self._get_coords: | ||
data.pos = coords[level][:, :3] | ||
|
||
nested_meshes.append(data) | ||
|
||
sample.num_vertices = [] | ||
for level, nested_mesh in enumerate(nested_meshes): | ||
setattr( | ||
sample, f"hierarchy_edge_index_{level+1}", nested_mesh.edge_index) | ||
setattr( | ||
sample, f"hierarchy_trace_index_{level+1}", nested_mesh.trace_index) | ||
|
||
sample.num_vertices.append( | ||
int(sample[f"hierarchy_trace_index_{level+1}"].max() + 1)) | ||
|
||
if self._get_coords: | ||
setattr(sample, f"pos_{level + 1}", nested_mesh.pos) | ||
|
||
if not self._is_train: | ||
sample.original_index_traces = trace_0 | ||
|
||
if self._debug_mode: | ||
sample.all_levels = [sample] | ||
sample.all_levels.extend(nested_meshes) | ||
|
||
if self._transform: | ||
sample = self._transform(sample) | ||
|
||
for checker in self._sample_checker: | ||
if not checker(sample): | ||
raise RuntimeError( | ||
f"{checker.__class__.__name__} rejected the sample") | ||
|
||
return sample | ||
except Exception as e: | ||
if sample is not None: | ||
del sample | ||
gc.collect() | ||
self.logger.warning( | ||
f"Warning: Training example {index} could not be processed") | ||
self.logger.warning(f"{str(e)}") | ||
index = random.randrange(len(self)) | ||
return self.__getitem__(index) | ||
|
||
def __len__(self): | ||
return len(self.index2filenames) | ||
|
||
@property | ||
@abstractmethod | ||
def color_map(self): | ||
"""Dataset has to declare to which color each class is mapped for visualization purposes | ||
""" | ||
raise NotImplementedError("") | ||
|
||
@property | ||
def num_classes(self): | ||
return len(self.color_map) | ||
|
||
@property | ||
@abstractmethod | ||
def ignore_classes(self) -> int: | ||
"""There exist unlimited vertices. Specify ID to ignore them in calculation of loss and metrics.""" | ||
raise NotImplementedError("") | ||
|
||
pos_neg_map = torch.FloatTensor( | ||
[ | ||
[200, 200, 200], | ||
[0, 255, 0], | ||
[255, 0, 0]]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import logging | ||
import torch.nn as nn | ||
import numpy as np | ||
|
||
|
||
class BaseModel(nn.Module): | ||
""" | ||
Base class for all models | ||
""" | ||
def __init__(self): | ||
super(BaseModel, self).__init__() | ||
self.logger = logging.getLogger(self.__class__.__name__) | ||
|
||
def forward(self, *input): | ||
""" | ||
Forward pass logic | ||
:return: Model output | ||
""" | ||
raise NotImplementedError | ||
|
||
def summary(self): | ||
""" | ||
Model summary | ||
""" | ||
model_parameters = filter(lambda p: p.requires_grad, self.parameters()) | ||
params = sum([np.prod(p.size()) for p in model_parameters]) | ||
self.logger.info('Trainable parameters: {}'.format(params)) | ||
self.logger.info(self) | ||
|
||
def __str__(self): | ||
""" | ||
Model prints with number of trainable parameters | ||
""" | ||
model_parameters = filter(lambda p: p.requires_grad, self.parameters()) | ||
params = sum([np.prod(p.size()) for p in model_parameters]) | ||
return super(BaseModel, self).__str__() + '\nTrainable parameters: {}'.format(params) | ||
# print(super(BaseModel, self)) |
Oops, something went wrong.