Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support Semi-supervised Oriented Object Detection: SOOD (CVPR 2023) #1003

Open
wants to merge 5 commits into
base: dev-1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions configs/_base_/datasets/semi_dotav15_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Import the mmpretrain transform module so that the `mmpretrain.*`
# transforms referenced in this config can be resolved; failed imports are
# not tolerated.
custom_imports = dict(
    allow_failed_imports=False,
    imports=['mmpretrain.datasets.transforms'],
)

# dataset settings
dataset_type = 'DOTAv15Dataset'
data_root = 'data/split_ss_dota1_5/'
backend_args = None

# Names of the branches that `mmdet.MultiBranch` distributes samples into.
branch_field = ['sup', 'unsup_teacher', 'unsup_student']

# Augmentation pipeline for labeled data. Its output is packed into the
# `sup` branch and sent to the student model for supervised training.
sup_pipeline = [
    dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
    # Annotations are loaded as quadrilateral boxes and then converted to
    # rotated boxes.
    dict(type='mmdet.LoadAnnotations', box_type='qbox', with_bbox=True),
    dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
    dict(type='mmdet.Resize', keep_ratio=True, scale=(1024, 1024)),
    dict(
        type='mmdet.RandomFlip',
        direction=['horizontal', 'vertical', 'diagonal'],
        prob=0.75),
    # Disabled in this config, kept for reference:
    # dict(type='mmdet.FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
    dict(
        type='mmdet.Pad',
        pad_val=dict(img=(114, 114, 114)),
        size_divisor=32),
    dict(
        type='mmdet.MultiBranch',
        branch_field=branch_field,
        sup=dict(type='mmdet.PackDetInputs')),
]

# Weak augmentation for unlabeled data: the image is only padded before being
# packed. The result is sent to the teacher model, which predicts the pseudo
# instances.
weak_pipeline = [
    dict(
        type='mmdet.Pad',
        pad_val=dict(img=(114, 114, 114)),
        size_divisor=32),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=(
            'img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor',
            'flip', 'flip_direction', 'homography_matrix',
        )),
]

# Strong augmentation for unlabeled data. The result is sent to the student
# model for unsupervised training.
strong_pipeline = [
    # Color jitter, wrapped in RandomApply with prob=0.8.
    dict(
        type='RandomApply',
        prob=0.8,
        transforms=dict(
            type='mmpretrain.ColorJitter',
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.1)),
    dict(type='mmpretrain.RandomGrayscale', keep_channels=True, prob=0.2),
    dict(
        type='mmpretrain.GaussianBlur',
        prob=0.5,
        radius=None,
        magnitude_level=1.9,
        magnitude_range=[0.1, 2.0],
        magnitude_std='inf',
        total_level=1.9),
    # Same padding and packing as used for the weakly augmented view.
    dict(
        type='mmdet.Pad',
        pad_val=dict(img=(114, 114, 114)),
        size_divisor=32),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=(
            'img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor',
            'flip', 'flip_direction', 'homography_matrix',
        )),
]

# Pipeline that turns one unlabeled image into two views. After loading,
# resizing and random flipping, `mmdet.MultiBranch` passes the sample to
# `weak_pipeline` (branch `unsup_teacher`) and to `strong_pipeline`
# (branch `unsup_student`).
unsup_pipeline = [
    dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
    dict(type='mmdet.LoadEmptyAnnotations'),
    dict(type='mmdet.Resize', keep_ratio=True, scale=(1024, 1024)),
    dict(
        type='mmdet.RandomFlip',
        direction=['horizontal', 'vertical', 'diagonal'],
        prob=0.75),
    dict(
        type='mmdet.MultiBranch',
        branch_field=branch_field,
        unsup_teacher=weak_pipeline,
        unsup_student=strong_pipeline),
]

# Validation pipeline. The image is resized first; the quadrilateral
# ('qbox') annotations are loaded afterwards and converted to 'rbox'.
val_pipeline = [
    dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
    dict(type='mmdet.Resize', keep_ratio=True, scale=(1024, 1024)),
    dict(type='mmdet.LoadAnnotations', box_type='qbox', with_bbox=True),
    dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=(
            'img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor',
        )),
]

# Test pipeline: load, resize and pack only; no annotations are loaded.
test_pipeline = [
    dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
    dict(type='mmdet.Resize', keep_ratio=True, scale=(1024, 1024)),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=(
            'img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor',
        )),
]

# Batch size and worker count consumed by `train_dataloader`.
batch_size = 3
num_workers = 6

# Semi-supervised setting on DOTA-v1.5 used by this config:
# the labeled and unlabeled training subsets live in separate folders under
# `data_root`, named `train_{percent}_labeled/` and
# `train_{percent}_unlabeled/`, each with its own `images/` folder and
# annotation folder.
# The 10% split (`train_10_labeled` / `train_10_unlabeled`) is used here;
# to train on another split, override `ann_file` and `data_prefix` of both
# datasets accordingly.
labeled_dataset = dict(
    type=dataset_type,
    data_root=data_root,
    ann_file='train_10_labeled/annfiles',
    data_prefix=dict(img_path='train_10_labeled/images/'),
    # Labeled images without any ground truth are filtered out.
    filter_cfg=dict(filter_empty_gt=True),
    pipeline=sup_pipeline)

# NOTE(review): the unlabeled annotations are read from `empty_annfiles/`,
# while the directory tree in configs/sood/README.md shows `annfiles` --
# confirm which folder the split tool actually generates.
unlabeled_dataset = dict(
    type=dataset_type,
    data_root=data_root,
    ann_file='train_10_unlabeled/empty_annfiles/',
    data_prefix=dict(img_path='train_10_unlabeled/images/'),
    # Unlabeled images have no ground truth, so empty-GT filtering is off.
    filter_cfg=dict(filter_empty_gt=False),
    pipeline=unsup_pipeline)

# Training loader over the concatenation of the labeled and the unlabeled
# dataset (in that order).
# NOTE(review): `source_ratio=[2, 1]` presumably draws two labeled samples
# per unlabeled sample in each batch -- confirm against
# mmdet.MultiSourceSampler.
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=True,
    sampler=dict(
        type='mmdet.MultiSourceSampler',
        source_ratio=[2, 1],
        batch_size=batch_size),
    dataset=dict(
        type='ConcatDataset',
        datasets=[labeled_dataset, unlabeled_dataset]))

# Validation loader: one image per batch, no shuffling, nothing dropped.
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    drop_last=False,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        ann_file='val/annfiles/',
        data_prefix=dict(img_path='val/images/'),
        pipeline=val_pipeline))

# Testing reuses the validation loader object as-is.
test_dataloader = val_dataloader

val_evaluator = dict(type='DOTAMetric', metric='mAP')

# Testing reuses the validation evaluator object as-is.
test_evaluator = val_evaluator
96 changes: 96 additions & 0 deletions configs/sood/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# SOOD

> [SOOD: Towards Semi-Supervised Oriented Object Detection](https://arxiv.org/abs/2304.04515)

<!-- [ALGORITHM] -->

## Abstract

Semi-Supervised Object Detection (SSOD), aiming to explore unlabeled data for boosting object detectors, has become an active task in recent years. However, existing SSOD approaches mainly focus on horizontal objects, leaving multi-oriented objects that are common in aerial images unexplored. This paper proposes a novel Semi-supervised Oriented Object Detection model, termed SOOD, built upon the mainstream pseudo-labeling framework. Towards oriented objects in aerial scenes, we design two loss functions to provide better supervision. Focusing on the orientations of objects, the first loss regularizes the consistency between each pseudo-label-prediction pair (includes a prediction and its corresponding pseudo label) with adaptive weights based on their orientation gap. Focusing on the layout of an image, the second loss regularizes the similarity and explicitly builds the many-to-many relation between the sets of pseudo-labels and predictions. Such a global consistency constraint can further boost semi-supervised learning. Our experiments show that when trained with the two proposed losses, SOOD surpasses the state-of-the-art SSOD methods under various settings on the DOTA-v1.5 benchmark.

## Requirements

- `mmpretrain>=1.0.0`: please refer to the [mmpretrain documentation](https://mmpretrain.readthedocs.io/en/latest/get_started.html) for installation.

## Data Preparation

Please refer to the [DOTA data preparation guide](../../tools/data/dota/README.md) to prepare the original data. After that, the data folder should be organized as follows:

```
├── data
│ ├── split_ss_dota1_5
│ │ ├── train
│ │ │ ├── images
│ │ │ ├── annfiles
│ │ ├── val
│ │ │ ├── images
│ │ │ ├── annfiles
│ │ ├── test
│ │ │ ├── images
│ │ │ ├── annfiles
```

For the partially labeled setting, we split the DOTA-v1.5 train set using the author-released [split data lists](../../tools/misc/split_dota1.5_lists) and the [split tool](../../tools/misc/split_dota1.5_via_lists.py):

```shell
python tools/misc/split_dota1.5_via_lists.py
```

For the fully labeled setting, we use the DOTA-v1.5 train set as the labeled set and the DOTA-v1.5 test set as the unlabeled set.

After that, the data folder should be organized as follows:

```
├── data
│ ├── split_ss_dota1_5
│ │ ├── train
│ │ │ ├── images
│ │ │ ├── annfiles
│ │ ├── train_10_labeled
│ │ │ ├── images
│ │ │ ├── annfiles
│ │ ├── train_10_unlabeled
│ │ │ ├── images
│ │ │ ├── annfiles
│ │ ├── train_20_labeled
│ │ │ ├── images
│ │ │ ├── annfiles
│ │ ├── train_20_unlabeled
│ │ │ ├── images
│ │ │ ├── annfiles
│ │ ├── train_30_labeled
│ │ │ ├── images
│ │ │ ├── annfiles
│ │ ├── train_30_unlabeled
│ │ │ ├── images
│ │ │ ├── annfiles
│ │ ├── val
│ │ │ ├── images
│ │ │ ├── annfiles
│ │ ├── test
│ │ │ ├── images
│ │ │ ├── annfiles
```

## Results

DOTA-v1.5

| Backbone | Setting | mAP50 | Mem (GB) | Config |
| :----------------------: | :-----: | :---: | :------: | :-------------------------------------------------------------: |
| ResNet50 (1024,1024,200) | 10% | | 8.45 | [config](./sood_fcos_r50_fpn_2xb3-180000k_semi-0.1-dotav1.5.py) |
| ResNet50 (1024,1024,200) | 20% | | | [config](./sood_fcos_r50_fpn_2xb3-180000k_semi-0.2-dotav1.5.py) |
| ResNet50 (1024,1024,200) | 30% | | | [config](./sood_fcos_r50_fpn_2xb3-180000k_semi-0.3-dotav1.5.py) |

## Citation

```
@inproceedings{hua2023sood,
title={SOOD: Towards Semi-Supervised Oriented Object Detection},
author={Hua, Wei and Liang, Dingkang and Li, Jingyu and Liu, Xiaolong and Zou, Zhikang and Ye, Xiaoqing and Bai, Xiang},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={15558--15567},
year={2023}
}
```
65 changes: 65 additions & 0 deletions configs/sood/rotated-fcos-le90_r50_fpn_dotav15.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""copy from rotated fcos."""
# Angle definition passed to the bbox coder of the head below.
angle_version = 'le90'

# model settings
model = dict(
    type='mmdet.FCOS',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        bgr_to_rgb=True,
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        pad_size_divisor=32,
        boxtype2tensor=False),
    # ResNet-50 initialised from the torchvision checkpoint.
    backbone=dict(
        type='mmdet.ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_eval=True,
        norm_cfg=dict(type='BN', requires_grad=True),
        style='pytorch',
        init_cfg=dict(
            type='Pretrained', checkpoint='torchvision://resnet50')),
    # FPN with start_level=1 and five output levels; the extra levels are
    # added on the FPN outputs.
    neck=dict(
        type='mmdet.FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5,
        start_level=1,
        add_extra_convs='on_output',
        relu_before_extra_convs=True),
    # Rotated FCOS head with 16 classes; no separate angle loss is used
    # (`loss_angle=None`).
    bbox_head=dict(
        type='RotatedFCOSHead',
        num_classes=16,
        in_channels=256,
        feat_channels=256,
        stacked_convs=4,
        strides=[8, 16, 32, 64, 128],
        center_sampling=True,
        center_sample_radius=1.5,
        norm_on_bbox=True,
        centerness_on_reg=True,
        use_hbbox_loss=False,
        scale_angle=True,
        bbox_coder=dict(
            type='DistanceAnglePointCoder', angle_version=angle_version),
        loss_cls=dict(
            type='mmdet.FocalLoss',
            use_sigmoid=True,
            alpha=0.25,
            gamma=2.0,
            loss_weight=1.0),
        loss_bbox=dict(type='RotatedIoULoss', loss_weight=1.0),
        loss_angle=None,
        loss_centerness=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0)),
    # training and testing settings
    train_cfg=None,
    test_cfg=dict(
        nms_pre=2000,
        max_per_img=2000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms_rotated', iou_threshold=0.1)))
Loading