Commit 24e1ef2: Add files via upload
Dayan-Guan authored Mar 19, 2022 (1 parent: da98fdd)

Showing 7 changed files with 1,397 additions and 2 deletions.

README.md (66 additions, 2 deletions)
# [CVPR 2022] Unbiased Subclass Regularization for Semi-Supervised Semantic Segmentation

### Updates
- *03/2022*: Code has been released! Two 2080Ti GPUs or a single V100-32G are used for PASCAL VOC; four 2080Ti GPUs or two V100-32G are used for Cityscapes.


### Paper
![](./teaser.jpg)

[Unbiased Subclass Regularization for Semi-Supervised Semantic Segmentation](https://dayan-guan.github.io/pub/USRN.pdf)

[Dayan Guan](https://scholar.google.com/citations?user=9jp9QAsAAAAJ&hl=en), [Jiaxing Huang](https://scholar.google.com/citations?user=czirNcwAAAAJ&hl=en&oi=ao), [Aoran Xiao](https://scholar.google.com/citations?user=yGKsEpAAAAAJ&hl=en), [Shijian Lu](https://scholar.google.com/citations?user=uYmK-A0AAAAJ&hl=en)
School of Computer Science and Engineering, Nanyang Technological University, Singapore

### Abstract
Semi-supervised semantic segmentation learns from small amounts of labelled images and large amounts of unlabelled images, and has witnessed impressive progress with the recent advance of deep neural networks. However, it often suffers from a severe class-bias problem while exploring the unlabelled images, largely due to the clear pixel-wise class imbalance in the labelled images. This paper presents an unbiased subclass regularization network (USRN) that alleviates the class imbalance issue by learning class-unbiased segmentation from balanced subclass distributions. We build the balanced subclass distributions by clustering pixels of each original class into multiple subclasses of similar sizes, which provide class-balanced pseudo supervision to regularize the class-biased segmentation. In addition, we design an entropy-based gate mechanism to coordinate learning between the original classes and the clustered subclasses, which facilitates subclass regularization effectively by suppressing unconfident subclass predictions. Extensive experiments over multiple public benchmarks show that USRN achieves superior performance compared with the state-of-the-art.
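
For intuition, the entropy-based gate can be pictured as a per-pixel mask derived from the entropy of the subclass prediction. Below is a minimal illustrative sketch; the function name, the normalized-entropy formulation, and the fixed threshold are assumptions for exposition, not the exact mechanism used in this repository.

```python
import math
import torch
import torch.nn.functional as F

def entropy_gate(subcls_logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Keep only confident subclass predictions (illustrative sketch).

    subcls_logits: (N, K, H, W) raw logits over K subclasses.
    Returns a (N, H, W) mask: 1 where the normalized prediction entropy is
    below `threshold`, 0 where the subclass prediction is unconfident.
    """
    probs = F.softmax(subcls_logits, dim=1)
    entropy = -(probs * probs.clamp_min(1e-8).log()).sum(dim=1)  # (N, H, W)
    entropy = entropy / math.log(subcls_logits.shape[1])         # normalize to [0, 1]
    return (entropy < threshold).float()
```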

### Preparation
1. Environment:
```bash
sh init.sh
```
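`init.sh` (included later in this commit) creates the conda environment from `requirements.txt` and builds the balanced k-means tool used for subclass clustering.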

2. Datasets:
* [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) with [SegmentationClassAug](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0):
```bash
USRN/dataset/voc/VOCdevkit/VOC2012                       # PASCAL VOC 2012 dataset root
USRN/dataset/voc/VOCdevkit/VOC2012/JPEGImages            # Images
USRN/dataset/voc/VOCdevkit/VOC2012/SegmentationClass     # Annotations
USRN/dataset/voc/VOCdevkit/VOC2012/SegmentationClassAug  # Extra annotations
```

* [Cityscapes](https://www.cityscapes-dataset.com/)
```bash
USRN/dataset/cityscapes/                      # Cityscapes dataset root
USRN/dataset/cityscapes/leftImg8bit_sequence  # from the leftImg8bit_trainvaltest package
USRN/dataset/cityscapes/images                # populate via: cp ../leftImg8bit_sequence/train/*/* ./images/train/
USRN/dataset/cityscapes/gtFine                # from the gtFine_trainvaltest package
USRN/dataset/cityscapes/segmentation          # populate via: cp ../gtFine_trainvaltest/train/*/* ./segmentation/train/
```

3. Pre-trained models:
Download the [pre-trained models](https://github.com/Dayan-Guan/USRN/releases/tag/Latest) and put them in `USRN/pretrained`

### Evaluation and Visualization using Pretrained Models
* Baseline (1/32 split of PASCAL VOC):
```bash
python3 main.py --test True --resume pretrained/best_model_voc_1over32_baseline.pth --config configs/voc_1over32_baseline.json
```

* USRN (1/32 split of PASCAL VOC):
```bash
python3 main.py --test True --resume pretrained/best_model_voc_1over32_usrn.pth --config configs/voc_1over32_usrn.json
```

### Training and Testing
* USRN (1/32 split of PASCAL VOC):
```bash
sh main.sh voc_1over32
```
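
`main.sh` (included in this commit) chains three stages: supervised baseline training on the labelled split, feature saving followed by balanced subclass clustering, and semi-supervised USRN training on both labelled and unlabelled images.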

### Acknowledgements
This codebase borrows heavily from [CAC](https://github.com/dvlab-research/Context-Aware-Consistency).

### Contact
If you have any questions, please contact: [email protected]
init.sh (11 additions)
#!/bin/bash

### conda environment
conda create --name USRN --file requirements.txt

### Balanced K-means
git clone https://github.com/zhu-he/regularized-k-means.git
cd regularized-k-means
mkdir build
cd build
cmake .. -DCMAKE_BUILD_TYPE=Release
cmake --build .
cd ../..  # back to the USRN root
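
Balanced k-means constrains each original class's pixel features to split into near-equal-size subclasses. The sketch below shows one simple capacity-constrained assignment step in Python; it is an illustrative simplification with assumed inputs (`features`, `centers`), not the regularized-k-means implementation built above.

```python
import numpy as np

def balanced_assign(features: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """Assign n points to k clusters of near-equal size (illustrative sketch).

    features: (n, d) array of per-pixel features for one original class.
    centers:  (k, d) array of subclass centers.
    Greedily assigns the most confident points first; a cluster stops
    accepting points once it reaches its capacity of ceil(n / k).
    """
    n, k = len(features), len(centers)
    capacity = -(-n // k)  # ceil(n / k)
    dists = ((features[:, None, :] - centers[None, :, :]) ** 2).sum(-1)  # (n, k)
    labels = np.full(n, -1, dtype=int)
    counts = np.zeros(k, dtype=int)
    for i in np.argsort(dists.min(axis=1)):  # most confident points first
        for c in np.argsort(dists[i]):       # nearest cluster with room
            if counts[c] < capacity:
                labels[i] = c
                counts[c] += 1
                break
    return labels
```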
main.py (202 additions)
import random
import numpy as np
import os
import json
import argparse
import torch
import dataloaders
import models
import math
from utils import Logger
from trainer import Test, Save_Features, Trainer_Baseline, Trainer_USRN
import torch.nn.functional as F
from utils.losses import abCE_loss, CE_loss, consistency_weight

import torch.multiprocessing as mp
import torch.distributed as dist

# import warnings
# warnings.filterwarnings("ignore")


def get_instance(module, name, config, *args):
# GET THE CORRESPONDING CLASS / FCT
return getattr(module, config[name]['type'])(*args, **config[name]['args'])


def main(gpu, ngpus_per_node, config, resume, test, save_feature):
if gpu == 0:
train_logger = Logger()
else:
train_logger = None

config['rank'] = gpu + ngpus_per_node * config['n_node']

torch.cuda.set_device(gpu)
assert config['train_supervised']['batch_size'] % config['n_gpu'] == 0
assert config['train_unsupervised']['batch_size'] % config['n_gpu'] == 0
assert config['val_loader']['batch_size'] % config['n_gpu'] == 0
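# Split the global batch sizes and worker counts evenly across this
# node's GPUs (one process per GPU).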
config['train_supervised']['batch_size'] = int(config['train_supervised']['batch_size'] / config['n_gpu'])
config['train_unsupervised']['batch_size'] = int(config['train_unsupervised']['batch_size'] / config['n_gpu'])
config['val_loader']['batch_size'] = int(config['val_loader']['batch_size'] / config['n_gpu'])
config['train_supervised']['num_workers'] = int(config['train_supervised']['num_workers'] / config['n_gpu'])
config['train_unsupervised']['num_workers'] = int(config['train_unsupervised']['num_workers'] / config['n_gpu'])
config['val_loader']['num_workers'] = int(config['val_loader']['num_workers'] / config['n_gpu'])
dist.init_process_group(backend='nccl', init_method=config['dist_url'], world_size=config['world_size'], rank=config['rank'])

seed = config['random_seed']
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.

# DATA LOADERS
config['train_supervised']['n_labeled_examples'] = config['n_labeled_examples']
config['train_unsupervised']['n_labeled_examples'] = config['n_labeled_examples']
config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables']
config['train_supervised']['data_dir'] = config['data_dir']
config['train_unsupervised']['data_dir'] = config['data_dir']
config['val_loader']['data_dir'] = config['data_dir']
config['train_supervised']['datalist'] = config['datalist']
config['train_unsupervised']['datalist'] = config['datalist']
config['val_loader']['datalist'] = config['datalist']

# One epoch iterates once over the labelled examples.
iter_per_epoch = int(config['n_labeled_examples'] / config['train_supervised']['batch_size'])
config['trainer']['iter_per_epoch'] = iter_per_epoch
# Scale the configured epoch and early-stop counts by the ratio of all
# images to labelled images, so the total number of training iterations
# stays comparable across labelled splits.
number_epochs = config['trainer']['epochs']
number_early_stop = config['trainer']['early_stop']
config['trainer']['epochs'] = int(config['num_images_all'] / config['n_labeled_examples']) * number_epochs
config['trainer']['early_stop'] = int(config['num_images_all'] / config['n_labeled_examples']) * number_early_stop

if test:
if config['dataset'] == 'voc':
sup_dataloader = dataloaders.VOC
elif config['dataset'] == 'cityscapes':
sup_dataloader = dataloaders.City
else:
if config['dataset'] == 'voc':
sup_dataloader = dataloaders.VOC
unsup_dataloader = dataloaders.PairVoc_StrongWeak
sup_dataloader_SubCls = dataloaders.VOC_SubCls
elif config['dataset'] == 'cityscapes':
sup_dataloader = dataloaders.City
unsup_dataloader = dataloaders.PairCity_StrongWeak
sup_dataloader_SubCls = dataloaders.City_SubCls

val_loader = sup_dataloader(config['val_loader'])

config['model']['n_labeled_examples'] = config['n_labeled_examples']
config['model']['MEAN'] = val_loader.MEAN
config['model']['STD'] = val_loader.STD

if test:
sup_loss = CE_loss
model = models.Test(num_classes=val_loader.dataset.num_classes, conf=config['model'],
sup_loss=sup_loss, ignore_index=val_loader.dataset.ignore_index)
if gpu == 0:
print(f'\n{model}\n')
# TESTING
trainer = Test(
model=model,
resume=resume,
config=config,
val_loader=val_loader,
iter_per_epoch=iter_per_epoch,
train_logger=train_logger,
gpu=gpu,
test=test)
elif save_feature:
sup_loss = CE_loss
model = models.Save_Features(num_classes=val_loader.dataset.num_classes, conf=config['model'],
sup_loss=sup_loss, ignore_index=val_loader.dataset.ignore_index)
if gpu == 0:
print(f'\n{model}\n')
# FEATURE SAVING
trainer = Save_Features(
model=model,
resume=resume,
config=config,
val_loader=val_loader,
iter_per_epoch=iter_per_epoch,
train_logger=train_logger,
gpu=gpu,
test=test)
elif config['name'] == 'USRN':
config['train_supervised']['label_subcls'] = config['label_subcls']
supervised_loader = sup_dataloader_SubCls(config['train_supervised'])
unsupervised_loader = unsup_dataloader(config['train_unsupervised'])
sup_loss = CE_loss
model = models.USRN(num_classes=val_loader.dataset.num_classes, conf=config['model'],
sup_loss=sup_loss, ignore_index=val_loader.dataset.ignore_index)
if gpu == 0:
print(f'\n{model}\n')
# TRAINING
trainer = Trainer_USRN(
model=model,
resume=resume,
config=config,
supervised_loader=supervised_loader,
unsupervised_loader=unsupervised_loader,
val_loader=val_loader,
iter_per_epoch=iter_per_epoch,
train_logger=train_logger,
gpu=gpu,
test=test)
elif config['name'] == 'Baseline':
supervised_loader = sup_dataloader(config['train_supervised'])
unsupervised_loader = unsup_dataloader(config['train_unsupervised'])
sup_loss = CE_loss
model = models.Baseline(num_classes=val_loader.dataset.num_classes, conf=config['model'],
sup_loss=sup_loss, ignore_index=val_loader.dataset.ignore_index)
if gpu == 0:
print(f'\n{model}\n')
# TRAINING
trainer = Trainer_Baseline(
model=model,
resume=resume,
config=config,
supervised_loader=supervised_loader,
unsupervised_loader=unsupervised_loader,
val_loader=val_loader,
iter_per_epoch=iter_per_epoch,
train_logger=train_logger,
gpu=gpu,
test=test)
trainer.train()


def find_free_port():
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port


if __name__ == '__main__':
# PARSE THE ARGS
def str2bool(v):
    # argparse's `type=bool` would treat any non-empty string (even 'False')
    # as True, so parse the boolean flag values explicitly.
    return str(v).lower() in ('true', '1', 'yes')

parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('-c', '--config', default='configs/config.json', type=str,
                    help='Path to the config file')
parser.add_argument('-r', '--resume', default=None, type=str,
                    help='Path to the .pth model checkpoint to resume training')
parser.add_argument('-t', '--test', default=False, type=str2bool,
                    help='whether to run evaluation only')
parser.add_argument('-sf', '--save_feature', default=False, type=str2bool,
                    help='whether to save features for subclass clustering')
args = parser.parse_args()

config = json.load(open(args.config))
torch.backends.cudnn.benchmark = True
port = find_free_port()
config['dist_url'] = f"tcp://127.0.0.1:{port}"
config['n_node'] = 0 # only support 1 node
config['world_size'] = config['n_gpu']
mp.spawn(main, nprocs=config['n_gpu'], args=(config['n_gpu'], config, args.resume, args.test, args.save_feature))


main.sh (13 additions)
#!/bin/bash
data=${1:-voc_1over32}  # dataset/split config prefix, e.g. voc_1over32

## Train a supervised model with labelled images
python3 main.py --config configs/${data}_baseline.json

## Generate class-balanced subclass clusters
python3 main.py --save_feature True --resume saved/${data}_baseline/best_model.pth --config configs/${data}_baseline.json
python3 clustering.py --config configs/${data}_baseline.json --clustering_algorithm balanced_kmeans

## Train a semi-supervised model with both labelled and unlabelled images
python3 main.py --config configs/${data}_usrn.json
requirements.txt (88 additions)
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
_openmp_mutex=4.5=1_gnu
absl-py=0.13.0=pypi_0
backcall=0.2.0=pypi_0
ca-certificates=2021.10.26=h06a4308_2
cachetools=4.2.2=pypi_0
certifi=2021.10.8=py38h06a4308_0
charset-normalizer=2.0.4=pypi_0
cycler=0.10.0=pypi_0
cython=0.29.24=pypi_0
decorator=5.1.0=pypi_0
dominate=2.6.0=pypi_0
future=0.18.2=pypi_0
google-auth=1.34.0=pypi_0
google-auth-oauthlib=0.4.5=pypi_0
grpcio=1.39.0=pypi_0
idna=3.2=pypi_0
imageio=2.9.0=pypi_0
ipdb=0.13.9=pypi_0
ipython=7.29.0=pypi_0
jedi=0.18.0=pypi_0
joblib=1.0.1=pypi_0
kiwisolver=1.3.1=pypi_0
ld_impl_linux-64=2.35.1=h7274673_9
libffi=3.3=he6710b0_2
libgcc-ng=9.3.0=h5101ec6_17
libgomp=9.3.0=h5101ec6_17
libstdcxx-ng=9.3.0=hd4cf53a_17
markdown=3.3.4=pypi_0
matplotlib=3.4.3=pypi_0
matplotlib-inline=0.1.3=pypi_0
ncurses=6.2=he6710b0_1
networkx=2.6.2=pypi_0
ninja=1.10.2=hff7bd54_1
nose=1.3.7=pypi_0
numpy=1.21.2=pypi_0
oauthlib=3.1.1=pypi_0
opencv-python=4.5.3.56=pypi_0
openssl=1.1.1l=h7f8727e_0
parso=0.8.2=pypi_0
pexpect=4.8.0=pypi_0
pickleshare=0.7.5=pypi_0
pillow=8.3.1=pypi_0
pip=21.0.1=py38h06a4308_0
portalocker=2.3.2=pypi_0
prompt-toolkit=3.0.22=pypi_0
protobuf=3.17.3=pypi_0
ptyprocess=0.7.0=pypi_0
pyasn1=0.4.8=pypi_0
pyasn1-modules=0.2.8=pypi_0
pydensecrf=1.0rc2=pypi_0
pygments=2.10.0=pypi_0
pyparsing=2.4.7=pypi_0
python=3.8.11=h12debd9_0_cpython
python-dateutil=2.8.2=pypi_0
pywavelets=1.1.1=pypi_0
readline=8.1=h27cfd23_0
requests=2.26.0=pypi_0
requests-oauthlib=1.3.0=pypi_0
rsa=4.7.2=pypi_0
scikit-image=0.18.2=pypi_0
scikit-learn=0.24.0=pypi_0
scipy=1.7.1=pypi_0
setuptools=52.0.0=py38h06a4308_0
six=1.16.0=pypi_0
sklearn=0.0=pypi_0
sqlite=3.36.0=hc218d9a_0
tensorboard=2.6.0=pypi_0
tensorboard-data-server=0.6.1=pypi_0
tensorboard-plugin-wit=1.8.0=pypi_0
threadpoolctl=2.2.0=pypi_0
tifffile=2021.8.8=pypi_0
tk=8.6.10=hbc83047_0
toml=0.10.2=pypi_0
torch=1.6.0=pypi_0
torch-encoding=1.2.2b20211111=pypi_0
torchvision=0.7.0=pypi_0
tqdm=4.62.1=pypi_0
traitlets=5.1.1=pypi_0
urllib3=1.26.6=pypi_0
wcwidth=0.2.5=pypi_0
werkzeug=2.0.1=pypi_0
wheel=0.37.0=pyhd3eb1b0_0
xz=5.2.5=h7b6447c_0
zlib=1.2.11=h7b6447c_3
teaser.jpg (binary file added)
