forked from Dayan-Guan/USRN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
da98fdd
commit 24e1ef2
Showing
7 changed files
with
1,397 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,66 @@ | ||
# USRN | ||
Code for <Unbiased Subclass Regularization for Semi-Supervised Semantic Segmentation> in CVPR 2022 | ||
# [CVPR 2022] Unbiased Subclass Regularization for Semi-Supervised Semantic Segmentation | ||
|
||
### Updates | ||
- *03/2022*: Codes has been released! Two 2080Ti or single V100-32G are used for PASCAL VOC, Four 2080Ti or two V100-32G are used for Cityscapes. | ||
|
||
|
||
### Paper | ||
data:image/s3,"s3://crabby-images/45d8d/45d8d1ca91d23c10781f9f782be48b0e419a6f0c" alt="" | ||
|
||
[Unbiased Subclass Regularization for Semi-Supervised Semantic Segmentation](https://dayan-guan.github.io/pub/USRN.pdf) | ||
|
||
[Dayan Guan](https://scholar.google.com/citations?user=9jp9QAsAAAAJ&hl=en), [Jiaxing Huang](https://scholar.google.com/citations?user=czirNcwAAAAJ&hl=en&oi=ao), [Xiao Aoran](https://scholar.google.com/citations?user=yGKsEpAAAAAJ&hl=en), [Shijian Lu](https://scholar.google.com/citations?user=uYmK-A0AAAAJ&hl=en) | ||
School of Computer Science and Engineering, Nanyang Technological University, Singapore | ||
|
||
### Abstract | ||
Semi-supervised semantic segmentation learns from small amounts of labelled images and large amounts of unlabelled images, which has witnessed impressive progress with the recent advance of deep neural networks. However, it often suffers from severe class-bias problem while exploring the unlabelled images, largely due to the clear pixel-wise class imbalance in the labelled images. This paper presents an unbiased subclass regularization network (USRN) that alleviates the class imbalance issue by learning class-unbiased segmentation from balanced subclass distributions. We build the balanced subclass distributions by clustering pixels of each original class into multiple subclasses of similar sizes, which provide class-balanced pseudo supervision to regularize the class-biased segmentation. In addition, we design an entropy-based gate mechanism to coordinate learning between the original classes and the clustered subclasses which facilitates subclass regularization effectively by suppressing unconfident subclass predictions. Extensive experiments over multiple public benchmarks show that USRN achieves superior performance as compared with the state-of-the-art. | ||
|
||
### Preparation | ||
1. Environment: | ||
```bash | ||
sh init.sh | ||
``` | ||
|
||
2. dataset: | ||
* [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) with [SegmentationClassAug](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0): | ||
```bash | ||
USRN/dataset/voc/VOCdevkit/VOC2012 % PASCAL VOC 2012 dataset root | ||
USRN/dataset/voc/VOCdevkit/VOC2012/JPEGImages % Images | ||
USRN/dataset/voc/VOCdevkit/VOC2012/SegmentationClass % Annotations | ||
USRN/dataset/voc/VOCdevkit/VOC2012/SegmentationClassAug % Extra annotations | ||
``` | ||
|
||
* [Cityscapes](https://www.cityscapes-dataset.com/) | ||
```bash | ||
USRN/dataset/cityscapes/ % cityscapes dataset root | ||
USRN/dataset/cityscapes/leftImg8bit_sequence % leftImg8bit_trainvaltest | ||
USRN/dataset/cityscapes/images % cp ../leftImg8bit_sequence/train/*/* ./images/train/ | ||
USRN/dataset/cityscapes/gtFine % gtFine_trainvaltest | ||
USRN/dataset/cityscapes/segmentation % cp ../gtFine_trainvaltest/train/*/* ./segmentation/train/ | ||
``` | ||
|
||
3. Pre-trained models: | ||
Download [pre-trained models](https://github.com/Dayan-Guan/USRN/releases/tag/Latest) and put in ```USRN/pretrained``` | ||
|
||
### Evaluation and Visualization using Pretrained Models | ||
* Baseline (1/32 split of PASCAL VOC): | ||
```bash | ||
python3 main.py --test True --resume pretrained/best_model_voc_1over32_baseline.pth --config configs/voc_1over32_baseline.json | ||
``` | ||
|
||
* USRN (1/32 split of PASCAL VOC): | ||
```bash | ||
python3 main.py --test True --resume pretrained/best_model_voc_1over32_usrn.pth --config configs/voc_1over32_usrn.json | ||
``` | ||
|
||
### Training and Testing | ||
* USRN (1/32 split of PASCAL VOC): | ||
```bash | ||
sh main.sh voc_1over32 | ||
``` | ||
|
||
## Acknowledgements | ||
This codebase is heavily borrowed from [CAC](https://github.com/dvlab-research/Context-Aware-Consistency). | ||
|
||
## Contact | ||
If you have any questions, please contact: [email protected] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
### conda environment | ||
conda create --name USRN --file requirements.txt | ||
|
||
### Balanced K-means | ||
git clone https://github.com/zhu-he/regularized-k-means.git | ||
cd regularized-k-means | ||
mkdir build | ||
cd build | ||
cmake .. -DCMAKE_BUILD_TYPE=Release | ||
cmake --build . | ||
cd .. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
import random | ||
import numpy as np | ||
import os | ||
import json | ||
import argparse | ||
import torch | ||
import dataloaders | ||
import models | ||
import math | ||
from utils import Logger | ||
from trainer import Test, Save_Features, Trainer_Baseline, Trainer_USRN | ||
import torch.nn.functional as F | ||
from utils.losses import abCE_loss, CE_loss, consistency_weight | ||
|
||
import torch.multiprocessing as mp | ||
import torch.distributed as dist | ||
|
||
# import warnings | ||
# warnings.filterwarnings("ignore") | ||
|
||
|
||
def get_instance(module, name, config, *args): | ||
# GET THE CORRESPONDING CLASS / FCT | ||
return getattr(module, config[name]['type'])(*args, **config[name]['args']) | ||
|
||
|
||
def main(gpu, ngpus_per_node, config, resume, test, save_feature): | ||
if gpu == 0: | ||
train_logger = Logger() | ||
else: | ||
train_logger = None | ||
|
||
config['rank'] = gpu + ngpus_per_node * config['n_node'] | ||
|
||
torch.cuda.set_device(gpu) | ||
assert config['train_supervised']['batch_size'] % config['n_gpu'] == 0 | ||
assert config['train_unsupervised']['batch_size'] % config['n_gpu'] == 0 | ||
assert config['val_loader']['batch_size'] % config['n_gpu'] == 0 | ||
config['train_supervised']['batch_size'] = int(config['train_supervised']['batch_size'] / config['n_gpu']) | ||
config['train_unsupervised']['batch_size'] = int(config['train_unsupervised']['batch_size'] / config['n_gpu']) | ||
config['val_loader']['batch_size'] = int(config['val_loader']['batch_size'] / config['n_gpu']) | ||
config['train_supervised']['num_workers'] = int(config['train_supervised']['num_workers'] / config['n_gpu']) | ||
config['train_unsupervised']['num_workers'] = int(config['train_unsupervised']['num_workers'] / config['n_gpu']) | ||
config['val_loader']['num_workers'] = int(config['val_loader']['num_workers'] / config['n_gpu']) | ||
dist.init_process_group(backend='nccl', init_method=config['dist_url'], world_size=config['world_size'], rank=config['rank']) | ||
|
||
seed = config['random_seed'] | ||
np.random.seed(seed) | ||
random.seed(seed) | ||
torch.manual_seed(seed) | ||
torch.cuda.manual_seed(seed) | ||
torch.backends.cudnn.deterministic = True | ||
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. | ||
|
||
# DATA LOADERS | ||
config['train_supervised']['n_labeled_examples'] = config['n_labeled_examples'] | ||
config['train_unsupervised']['n_labeled_examples'] = config['n_labeled_examples'] | ||
config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables'] | ||
config['train_supervised']['data_dir'] = config['data_dir'] | ||
config['train_unsupervised']['data_dir'] = config['data_dir'] | ||
config['val_loader']['data_dir'] = config['data_dir'] | ||
config['train_supervised']['datalist'] = config['datalist'] | ||
config['train_unsupervised']['datalist'] = config['datalist'] | ||
config['val_loader']['datalist'] = config['datalist'] | ||
|
||
iter_per_epoch = int(config['n_labeled_examples'] / config['train_supervised']['batch_size']) | ||
config['trainer']['iter_per_epoch'] = iter_per_epoch | ||
number_epochs = config['trainer']['epochs'] | ||
number_early_stop = config['trainer']['early_stop'] | ||
config['trainer']['epochs'] = int(config['num_images_all'] / config['n_labeled_examples']) * number_epochs | ||
config['trainer']['early_stop'] = int(config['num_images_all'] / config['n_labeled_examples']) * number_early_stop | ||
|
||
if test: | ||
if config['dataset'] == 'voc': | ||
sup_dataloader = dataloaders.VOC | ||
elif config['dataset'] == 'cityscapes': | ||
sup_dataloader = dataloaders.City | ||
else: | ||
if config['dataset'] == 'voc': | ||
sup_dataloader = dataloaders.VOC | ||
unsup_dataloader = dataloaders.PairVoc_StrongWeak | ||
sup_dataloader_SubCls = dataloaders.VOC_SubCls | ||
elif config['dataset'] == 'cityscapes': | ||
sup_dataloader = dataloaders.City | ||
unsup_dataloader = dataloaders.PairCity_StrongWeak | ||
sup_dataloader_SubCls = dataloaders.City_SubCls | ||
|
||
val_loader = sup_dataloader(config['val_loader']) | ||
|
||
config['model']['n_labeled_examples'] = config['n_labeled_examples'] | ||
config['model']['MEAN'] = val_loader.MEAN | ||
config['model']['STD'] = val_loader.STD | ||
|
||
if test: | ||
sup_loss = CE_loss | ||
model = models.Test(num_classes=val_loader.dataset.num_classes, conf=config['model'], | ||
sup_loss=sup_loss, ignore_index=val_loader.dataset.ignore_index) | ||
if gpu == 0: | ||
print(f'\n{model}\n') | ||
# TRAINING | ||
trainer = Test( | ||
model=model, | ||
resume=resume, | ||
config=config, | ||
val_loader=val_loader, | ||
iter_per_epoch=iter_per_epoch, | ||
train_logger=train_logger, | ||
gpu=gpu, | ||
test=test) | ||
elif save_feature: | ||
sup_loss = CE_loss | ||
model = models.Save_Features(num_classes=val_loader.dataset.num_classes, conf=config['model'], | ||
sup_loss=sup_loss, ignore_index=val_loader.dataset.ignore_index) | ||
if gpu == 0: | ||
print(f'\n{model}\n') | ||
# TRAINING | ||
trainer = Save_Features( | ||
model=model, | ||
resume=resume, | ||
config=config, | ||
val_loader=val_loader, | ||
iter_per_epoch=iter_per_epoch, | ||
train_logger=train_logger, | ||
gpu=gpu, | ||
test=test) | ||
elif config['name'] == 'USRN': | ||
config['train_supervised']['label_subcls'] = config['label_subcls'] | ||
supervised_loader = sup_dataloader_SubCls(config['train_supervised']) | ||
unsupervised_loader = unsup_dataloader(config['train_unsupervised']) | ||
sup_loss = CE_loss | ||
model = models.USRN(num_classes=val_loader.dataset.num_classes, conf=config['model'], | ||
sup_loss=sup_loss, ignore_index=val_loader.dataset.ignore_index) | ||
if gpu == 0: | ||
print(f'\n{model}\n') | ||
# TRAINING | ||
trainer = Trainer_USRN( | ||
model=model, | ||
resume=resume, | ||
config=config, | ||
supervised_loader=supervised_loader, | ||
unsupervised_loader=unsupervised_loader, | ||
val_loader=val_loader, | ||
iter_per_epoch=iter_per_epoch, | ||
train_logger=train_logger, | ||
gpu=gpu, | ||
test=test) | ||
elif config['name'] == 'Baseline': | ||
supervised_loader = sup_dataloader(config['train_supervised']) | ||
unsupervised_loader = unsup_dataloader(config['train_unsupervised']) | ||
sup_loss = CE_loss | ||
model = models.Baseline(num_classes=val_loader.dataset.num_classes, conf=config['model'], | ||
sup_loss=sup_loss, ignore_index=val_loader.dataset.ignore_index) | ||
if gpu == 0: | ||
print(f'\n{model}\n') | ||
# TRAINING | ||
trainer = Trainer_Baseline( | ||
model=model, | ||
resume=resume, | ||
config=config, | ||
supervised_loader=supervised_loader, | ||
unsupervised_loader=unsupervised_loader, | ||
val_loader=val_loader, | ||
iter_per_epoch=iter_per_epoch, | ||
train_logger=train_logger, | ||
gpu=gpu, | ||
test=test) | ||
trainer.train() | ||
|
||
|
||
def find_free_port(): | ||
import socket | ||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | ||
# Binding to port 0 will cause the OS to find an available port for us | ||
sock.bind(("", 0)) | ||
port = sock.getsockname()[1] | ||
sock.close() | ||
# NOTE: there is still a chance the port could be taken by other processes. | ||
return port | ||
|
||
|
||
if __name__ == '__main__': | ||
# PARSE THE ARGS | ||
parser = argparse.ArgumentParser(description='PyTorch Training') | ||
parser.add_argument('-c', '--config', default='configs/config.json', type=str, | ||
help='Path to the config file') | ||
parser.add_argument('-r', '--resume', default=None, type=str, | ||
help='Path to the .pth model checkpoint to resume training') | ||
parser.add_argument('-t', '--test', default=False, type=bool, | ||
help='whether to test') | ||
parser.add_argument('-sf', '--save_feature', default=False, type=bool, | ||
help='whether to test') | ||
args = parser.parse_args() | ||
|
||
config = json.load(open(args.config)) | ||
torch.backends.cudnn.benchmark = True | ||
port = find_free_port() | ||
config['dist_url'] = f"tcp://127.0.0.1:{port}" | ||
config['n_node'] = 0 # only support 1 node | ||
config['world_size'] = config['n_gpu'] | ||
mp.spawn(main, nprocs=config['n_gpu'], args=(config['n_gpu'], config, args.resume, args.test, args.save_feature)) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#!/bin/bash | ||
#data="voc_1over32" | ||
data=$1 | ||
|
||
## Train a supervised model with labelled images | ||
python3 main.py --config configs/${data}_baseline.json | ||
|
||
## Generate class-balanced subclass clusters | ||
python3 main.py --save_feature True --resume saved/${data}_baseline/best_model.pth --config configs/${data}_baseline.json | ||
python3 clustering.py --config configs/${data}_baseline.json --clustering_algorithm balanced_kmeans | ||
|
||
## Train a semi-supervised model with both labelled and unlabelled images | ||
python3 main.py --config configs/${data}_usrn.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# This file may be used to create an environment using: | ||
# $ conda create --name <env> --file <this file> | ||
# platform: linux-64 | ||
_libgcc_mutex=0.1=main | ||
_openmp_mutex=4.5=1_gnu | ||
absl-py=0.13.0=pypi_0 | ||
backcall=0.2.0=pypi_0 | ||
ca-certificates=2021.10.26=h06a4308_2 | ||
cachetools=4.2.2=pypi_0 | ||
certifi=2021.10.8=py38h06a4308_0 | ||
charset-normalizer=2.0.4=pypi_0 | ||
cycler=0.10.0=pypi_0 | ||
cython=0.29.24=pypi_0 | ||
decorator=5.1.0=pypi_0 | ||
dominate=2.6.0=pypi_0 | ||
future=0.18.2=pypi_0 | ||
google-auth=1.34.0=pypi_0 | ||
google-auth-oauthlib=0.4.5=pypi_0 | ||
grpcio=1.39.0=pypi_0 | ||
idna=3.2=pypi_0 | ||
imageio=2.9.0=pypi_0 | ||
ipdb=0.13.9=pypi_0 | ||
ipython=7.29.0=pypi_0 | ||
jedi=0.18.0=pypi_0 | ||
joblib=1.0.1=pypi_0 | ||
kiwisolver=1.3.1=pypi_0 | ||
ld_impl_linux-64=2.35.1=h7274673_9 | ||
libffi=3.3=he6710b0_2 | ||
libgcc-ng=9.3.0=h5101ec6_17 | ||
libgomp=9.3.0=h5101ec6_17 | ||
libstdcxx-ng=9.3.0=hd4cf53a_17 | ||
markdown=3.3.4=pypi_0 | ||
matplotlib=3.4.3=pypi_0 | ||
matplotlib-inline=0.1.3=pypi_0 | ||
ncurses=6.2=he6710b0_1 | ||
networkx=2.6.2=pypi_0 | ||
ninja=1.10.2=hff7bd54_1 | ||
nose=1.3.7=pypi_0 | ||
numpy=1.21.2=pypi_0 | ||
oauthlib=3.1.1=pypi_0 | ||
opencv-python=4.5.3.56=pypi_0 | ||
openssl=1.1.1l=h7f8727e_0 | ||
parso=0.8.2=pypi_0 | ||
pexpect=4.8.0=pypi_0 | ||
pickleshare=0.7.5=pypi_0 | ||
pillow=8.3.1=pypi_0 | ||
pip=21.0.1=py38h06a4308_0 | ||
portalocker=2.3.2=pypi_0 | ||
prompt-toolkit=3.0.22=pypi_0 | ||
protobuf=3.17.3=pypi_0 | ||
ptyprocess=0.7.0=pypi_0 | ||
pyasn1=0.4.8=pypi_0 | ||
pyasn1-modules=0.2.8=pypi_0 | ||
pydensecrf=1.0rc2=pypi_0 | ||
pygments=2.10.0=pypi_0 | ||
pyparsing=2.4.7=pypi_0 | ||
python=3.8.11=h12debd9_0_cpython | ||
python-dateutil=2.8.2=pypi_0 | ||
pywavelets=1.1.1=pypi_0 | ||
readline=8.1=h27cfd23_0 | ||
requests=2.26.0=pypi_0 | ||
requests-oauthlib=1.3.0=pypi_0 | ||
rsa=4.7.2=pypi_0 | ||
scikit-image=0.18.2=pypi_0 | ||
scikit-learn=0.24.0=pypi_0 | ||
scipy=1.7.1=pypi_0 | ||
setuptools=52.0.0=py38h06a4308_0 | ||
six=1.16.0=pypi_0 | ||
sklearn=0.0=pypi_0 | ||
sqlite=3.36.0=hc218d9a_0 | ||
tensorboard=2.6.0=pypi_0 | ||
tensorboard-data-server=0.6.1=pypi_0 | ||
tensorboard-plugin-wit=1.8.0=pypi_0 | ||
threadpoolctl=2.2.0=pypi_0 | ||
tifffile=2021.8.8=pypi_0 | ||
tk=8.6.10=hbc83047_0 | ||
toml=0.10.2=pypi_0 | ||
torch=1.6.0=pypi_0 | ||
torch-encoding=1.2.2b20211111=pypi_0 | ||
torchvision=0.7.0=pypi_0 | ||
tqdm=4.62.1=pypi_0 | ||
traitlets=5.1.1=pypi_0 | ||
urllib3=1.26.6=pypi_0 | ||
wcwidth=0.2.5=pypi_0 | ||
werkzeug=2.0.1=pypi_0 | ||
wheel=0.37.0=pyhd3eb1b0_0 | ||
xz=5.2.5=h7b6447c_0 | ||
zlib=1.2.11=h7b6447c_3 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.