Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements Feature Pyramid Network #75

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions robosat/fpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Feature Pyramid Network (FPN) on top of ResNet. Comes with task-specific heads on top of it.

See:
- https://arxiv.org/abs/1612.03144 - Feature Pyramid Networks for Object Detection
- http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf - A Unified Architecture for Instance
and Semantic Segmentation

"""

import torch
import torch.nn as nn

from torchvision.models import resnet50


class FPN(nn.Module):
"""Feature Pyramid Network (FPN): top-down architecture with lateral connections.
Can be used as feature extractor for object detection or segmentation.
"""

def __init__(self, num_filters=256, pretrained=True):
"""Creates an `FPN` instance for feature extraction.

Args:
num_filters: the number of filters in each output pyramid level
pretrained: use ImageNet pre-trained backbone feature extractor
"""

super().__init__()

self.resnet = resnet50(pretrained=pretrained)

# Access resnet directly in forward pass; do not store refs here due to
# https://github.com/pytorch/pytorch/issues/8392

self.lateral4 = Conv1x1(2048, num_filters)
self.lateral3 = Conv1x1(1024, num_filters)
self.lateral2 = Conv1x1(512, num_filters)
self.lateral1 = Conv1x1(256, num_filters)

self.smooth4 = Conv3x3(num_filters, num_filters)
self.smooth3 = Conv3x3(num_filters, num_filters)
self.smooth2 = Conv3x3(num_filters, num_filters)
self.smooth1 = Conv3x3(num_filters, num_filters)

def forward(self, x):
# Bottom-up pathway, from ResNet

size = x.size()
assert size[-1] % 32 == 0 and size[-2] % 32 == 0, "image resolution has to be divisible by 32 for resnet"

enc0 = self.resnet.conv1(x)
enc0 = self.resnet.bn1(enc0)
enc0 = self.resnet.relu(enc0)
enc0 = self.resnet.maxpool(enc0)

enc1 = self.resnet.layer1(enc0)
enc2 = self.resnet.layer2(enc1)
enc3 = self.resnet.layer3(enc2)
enc4 = self.resnet.layer4(enc3)

# Lateral connections

lateral4 = self.lateral4(enc4)
lateral3 = self.lateral3(enc3)
lateral2 = self.lateral2(enc2)
lateral1 = self.lateral1(enc1)

# Top-down pathway

map4 = lateral4
map3 = lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest")
map2 = lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest")
map1 = lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest")

# Reduce aliasing effect of upsampling

map4 = self.smooth4(map4)
map3 = self.smooth3(map3)
map2 = self.smooth2(map2)
map1 = self.smooth1(map1)

return map1, map2, map3, map4


class FPNSegmentation(nn.Module):
"""Semantic segmentation model on top of a Feature Pyramid Network (FPN).
"""

def __init__(self, num_classes, num_filters=128, num_filters_fpn=256, pretrained=True):
"""Creates an `FPNSegmentation` instance for feature extraction.

Args:
num_classes: number of classes to predict
num_filters: the number of filters in each segmentation head pyramid level
num_filters_fpn: the number of filters in each FPN output pyramid level
pretrained: use ImageNet pre-trained backbone feature extractor
"""

super().__init__()

# Feature Pyramid Network (FPN) with four feature maps of resolutions
# 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.

self.fpn = FPN(num_filters=num_filters_fpn, pretrained=pretrained)

# The segmentation heads on top of the FPN

self.head1 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
self.head2 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
self.head3 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
self.head4 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))

self.final = nn.Conv2d(4 * num_filters, num_classes, kernel_size=3, padding=1)

def forward(self, x):
map1, map2, map3, map4 = self.fpn(x)

map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
map1 = self.head1(map1)

final = self.final(torch.cat([map4, map3, map2, map1], dim=1))

return nn.functional.upsample(final, scale_factor=4, mode="bilinear", align_corners=False)


class Conv1x1(nn.Module):
def __init__(self, num_in, num_out):
super().__init__()
self.block = nn.Conv2d(num_in, num_out, kernel_size=1, bias=False)

def forward(self, x):
return self.block(x)


class Conv3x3(nn.Module):
def __init__(self, num_in, num_out):
super().__init__()
self.block = nn.Conv2d(num_in, num_out, kernel_size=3, padding=1, bias=False)

def forward(self, x):
return self.block(x)
4 changes: 2 additions & 2 deletions robosat/tools/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.autograd

from robosat.config import load_config
from robosat.unet import UNet
from robosat.fpn import FPNSegmentation


def add_parser(subparser):
Expand All @@ -25,7 +25,7 @@ def main(args):
dataset = load_config(args.dataset)

num_classes = len(dataset["common"]["classes"])
net = UNet(num_classes)
net = FPNSegmentation(num_classes)

def map_location(storage, _):
return storage.cpu()
Expand Down
5 changes: 3 additions & 2 deletions robosat/tools/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from PIL import Image

from robosat.datasets import BufferedSlippyMapDirectory
from robosat.unet import UNet
from robosat.fpn import FPNSegmentation
from robosat.config import load_config
from robosat.colors import continuous_palette_for_color
from robosat.transforms import ConvertImageMode, ImageToTensor
Expand Down Expand Up @@ -59,8 +59,9 @@ def map_location(storage, _):
# https://github.com/pytorch/pytorch/issues/7178
chkpt = torch.load(args.checkpoint, map_location=map_location)

net = UNet(num_classes).to(device)
net = FPNSegmentation(num_classes)
net = nn.DataParallel(net)
net = net.to(device)

if cuda:
torch.backends.cudnn.benchmark = True
Expand Down
5 changes: 3 additions & 2 deletions robosat/tools/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flask import Flask, send_file, render_template, abort

from robosat.tiles import fetch_image
from robosat.unet import UNet
from robosat.fpn import FPNSegmentation
from robosat.config import load_config
from robosat.colors import make_palette
from robosat.transforms import ConvertImageMode, ImageToTensor
Expand Down Expand Up @@ -180,8 +180,9 @@ def map_location(storage, _):

num_classes = len(self.dataset["common"]["classes"])

net = UNet(num_classes).to(self.device)
net = FPNSegmentation(num_classes)
net = nn.DataParallel(net)
net = net.to(self.device)

if self.cuda:
torch.backends.cudnn.benchmark = True
Expand Down
8 changes: 4 additions & 4 deletions robosat/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from PIL import Image

import torch
import torch.nn as nn
import torch.backends.cudnn
from torch.nn import DataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.transforms import Resize, CenterCrop, Normalize
Expand All @@ -27,7 +27,7 @@
from robosat.datasets import SlippyMapTilesConcatenation
from robosat.metrics import Metrics
from robosat.losses import CrossEntropyLoss2d, mIoULoss2d, FocalLoss2d, LovaszLoss2d
from robosat.unet import UNet
from robosat.fpn import FPNSegmentation
from robosat.utils import plot
from robosat.config import load_config
from robosat.log import Log
Expand Down Expand Up @@ -68,8 +68,8 @@ def main(args):
os.makedirs(model["common"]["checkpoint"], exist_ok=True)

num_classes = len(dataset["common"]["classes"])
net = UNet(num_classes)
net = DataParallel(net)
net = FPNSegmentation(num_classes)
net = nn.DataParallel(net)
net = net.to(device)

if model["common"]["cuda"]:
Expand Down
1 change: 1 addition & 0 deletions robosat/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):

Args:
num_classes: number of classes to predict.
num_filters: the number of filters for the decoder block
pretrained: use ImageNet pre-trained backbone feature extractor
"""

Expand Down