Convert to onnx #54
Comments
Here is an example which exports and loads FBA-Net using ONNX. Before running, download the model file from https://github.com/MarcoForte/FBA_Matting?tab=readme-ov-file#models and save it as FBA.pth in the current directory. It would also be possible to replace the call to OpenCV's distance field function with pure PyTorch to get rid of the OpenCV dependency, but that is a bit of work, so I did not do it for now.

import numpy as np
import os, cv2, math, urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F
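# Drop-in replacement for nn.GroupNorm built from basic tensor ops, for ONNX
# backends that do not support the GroupNorm operator. Parameter names match
# nn.GroupNorm, so the pretrained state dict loads unchanged.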
class MyGroupNorm(nn.Module):
def __init__(self, num_groups, num_channels, eps=1e-5):
super().__init__()
self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels))  # ones, matching nn.GroupNorm's affine init
self.bias = nn.Parameter(torch.zeros(num_channels))
def forward(self, x):
n, c, h, w = x.shape
g = self.num_groups
x = x.view(n, g, -1)
var = x.var(dim=-1, keepdim=True, correction=0)
x = (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(var + self.eps)
x = x.view(n, c, -1) * self.weight.view(1, -1, 1) + self.bias.view(1, -1, 1)
return x.view(n, c, h, w)
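# Replacement for nn.AdaptiveAvgPool2d: adaptive pooling does not export
# cleanly to ONNX for arbitrary output sizes, so each output bin is averaged
# explicitly. The loops run over small fixed sizes (1/2/3/6 below), so the
# unrolled graph stays small.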
class MyAdaptiveAvgPool2d(nn.Module):
def __init__(self, output_size):
super().__init__()
self.output_size = output_size
def forward(self, batch):
size = self.output_size
n, c, h, w = batch.shape
        output = torch.zeros((n, c, size, size), device=batch.device, dtype=batch.dtype)
for y in range(size):
for x in range(size):
x0 = math.floor(x * w / size)
y0 = math.floor(y * h / size)
x1 = math.ceil((x + 1) * w / size)
y1 = math.ceil((y + 1) * h / size)
output[:, :, y, x] = batch[:, :, y0:y1, x0:x1].mean(dim=(2, 3))
return output
def norm(dim):
return nn.GroupNorm(32, dim)
# use this if your ONNX implementation does not support GroupNorm
#return MyGroupNorm(32, dim)
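# Conv2d with weight standardization: each output channel's kernel is
# normalized to zero mean and unit variance (implemented via F.batch_norm over
# the flattened weights). Normalization is re-applied on every train/eval
# switch, so the weights are already standardized when the model is exported.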
class Conv2d(nn.Conv2d):
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__(in_channels, out_channels, **kwargs)
def normalize_weight(self):
weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1),
None,
None,
training=True,
momentum=0.0,
).reshape_as(self.weight)
self.weight.data = weight
def forward(self, x):
if self.training:
self.normalize_weight()
return super().forward(x)
    def train(self, mode: bool = True):
        super().train(mode=mode)
        self.normalize_weight()
        return self  # nn.Module.train returns self; keep that contract
def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1, bias=False):
return Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
dilation=dilation,
padding=padding,
bias=bias,
)
def conv1x1(in_planes, out_planes, stride=1, bias=False):
return Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
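# L2 distance transform of a binary mask, computed with OpenCV. This is the
# OpenCV dependency mentioned above: trimap_transform uses it to encode the
# distance to the trimap's foreground/background regions as Gaussians at three
# scales, giving 6 extra input channels in total.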
def dt(a):
return cv2.distanceTransform((a * 255).astype(np.uint8), cv2.DIST_L2, 0)
def trimap_transform(trimap, L=320):
clicks = []
for k in range(2):
dt_mask = -dt(1 - trimap[:, :, k]) ** 2
clicks.append(np.exp(dt_mask / (2 * ((0.02 * L) ** 2))))
clicks.append(np.exp(dt_mask / (2 * ((0.08 * L) ** 2))))
clicks.append(np.exp(dt_mask / (2 * ((0.16 * L) ** 2))))
clicks = np.array(clicks)
return clicks
def normalise_image(image):
# Warning: Values are for RGB, but OpenCV loads images as BGR
mean = torch.tensor([0.485, 0.456, 0.406], device=image.device).reshape(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], device=image.device).reshape(1, 3, 1, 1)
return (image - mean) / std
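# One branch of the pyramid pooling module (as in PSPNet): pool the
# 2048-channel encoder output down to scale x scale, project to 256 channels,
# then the decoder resizes it back up and concatenates all branches.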
def pyramid_pooling_module(scale):
return nn.Sequential(
MyAdaptiveAvgPool2d(scale),
conv1x1(2048, 256, bias=True),
norm(256),
nn.LeakyReLU(),
)
def resize(x, **kwargs):
return nn.functional.interpolate(x, mode="bilinear", align_corners=False, **kwargs)
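# Standard ResNet bottleneck block, with GroupNorm in place of BatchNorm and
# optional dilation for the dilated backbone below.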
class Bottleneck(nn.Module):
def __init__(
self,
inplanes,
planes,
stride=1,
padding=1,
dilation=1,
expansion=4,
downsample=None,
):
super().__init__()
self.conv1 = conv1x1(inplanes, planes)
self.bn1 = norm(planes)
self.conv2 = conv3x3(planes, planes, stride, padding, dilation)
self.bn2 = norm(planes)
self.conv3 = conv1x1(planes, planes * expansion)
self.bn3 = norm(planes * expansion)
self.downsample = downsample
self.stride = stride
def forward(self, x):
out = x
out = self.conv1(out)
out = self.bn1(out)
out = F.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = F.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
x = self.downsample(x)
out += x
out = F.relu(out)
return out
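# ResNet-50 encoder with dilation instead of striding in layer3/layer4, so the
# output stride is 8 rather than 32. in_channels=11 = 3 (RGB) + 6 (trimap
# distance encoding) + 2 (two-channel trimap). Intermediate activations are
# collected for the decoder's skip connections.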
class ResnetDilated(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv1 = Conv2d(
in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = norm(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = nn.Sequential(
Bottleneck(64, 64, stride=1, padding=1, dilation=1, downsample=nn.Sequential(conv1x1(64, 256), norm(256))),
Bottleneck(256, 64, stride=1, padding=1, dilation=1),
Bottleneck(256, 64, stride=1, padding=1, dilation=1),
)
self.layer2 = nn.Sequential(
            Bottleneck(256, 128, stride=2, padding=1, dilation=1, downsample=nn.Sequential(conv1x1(256, 512, stride=2), norm(512))),
Bottleneck(512, 128, stride=1, padding=1, dilation=1),
Bottleneck(512, 128, stride=1, padding=1, dilation=1),
Bottleneck(512, 128, stride=1, padding=1, dilation=1),
)
self.layer3 = nn.Sequential(
            Bottleneck(512, 256, stride=1, padding=1, dilation=1, downsample=nn.Sequential(conv1x1(512, 1024), norm(1024))),
Bottleneck(1024, 256, stride=1, padding=2, dilation=2),
Bottleneck(1024, 256, stride=1, padding=2, dilation=2),
Bottleneck(1024, 256, stride=1, padding=2, dilation=2),
Bottleneck(1024, 256, stride=1, padding=2, dilation=2),
Bottleneck(1024, 256, stride=1, padding=2, dilation=2),
)
self.layer4 = nn.Sequential(
Bottleneck(1024, 512, stride=1, padding=2, dilation=2, downsample=nn.Sequential(conv1x1(1024, 2048), norm(2048))),
Bottleneck(2048, 512, stride=1, padding=4, dilation=4),
Bottleneck(2048, 512, stride=1, padding=4, dilation=4),
)
def forward(self, x):
conv_out = [x]
x = F.relu(self.bn1(self.conv1(x)))
conv_out.append(x)
x = self.maxpool(x)
x = self.layer1(x)
conv_out.append(x)
x = self.layer2(x)
conv_out.append(x)
x = self.layer3(x)
conv_out.append(x)
x = self.layer4(x)
conv_out.append(x)
return conv_out
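# Decoder: pyramid pooling on the deepest features, then three upsampling
# stages with skip connections, ending in a 7-channel prediction:
# 1 alpha + 3 foreground + 3 background.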
class fba_decoder(nn.Module):
def __init__(self):
super().__init__()
self.ppm = nn.ModuleList(
[
pyramid_pooling_module(scale=1),
pyramid_pooling_module(scale=2),
pyramid_pooling_module(scale=3),
pyramid_pooling_module(scale=6),
]
)
self.conv_up1 = nn.Sequential(
conv3x3(2048 + len(self.ppm) * 256, 256, bias=True),
norm(256),
nn.LeakyReLU(),
conv3x3(256, 256, bias=True),
norm(256),
nn.LeakyReLU(),
)
self.conv_up2 = nn.Sequential(
conv3x3(256 + 256, 256, bias=True), norm(256), nn.LeakyReLU()
)
self.conv_up3 = nn.Sequential(
conv3x3(256 + 64, 64, bias=True), norm(64), nn.LeakyReLU()
)
self.conv_up4 = nn.Sequential(
nn.Conv2d(64 + 3 + 3 + 2, 32, 3, 1, 1, bias=True),
nn.LeakyReLU(),
nn.Conv2d(32, 16, 3, 1, 1, bias=True),
nn.LeakyReLU(),
nn.Conv2d(16, 7, 1, bias=True),
)
def forward(self, conv_out, img, two_chan_trimap):
conv5 = conv_out[-1]
ppm_out = [conv5]
for ppm in self.ppm:
small_conv5 = ppm(conv5)
large_conv5 = resize(small_conv5, size=conv5.shape[2:])
ppm_out.append(large_conv5)
x = torch.cat(ppm_out, 1)
x = self.conv_up1(x)
x = resize(x, scale_factor=2)
x = torch.cat((x, conv_out[-4]), 1)
x = self.conv_up2(x)
x = resize(x, scale_factor=2)
x = torch.cat((x, conv_out[-5]), 1)
x = self.conv_up3(x)
x = resize(x, scale_factor=2)
x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)
x = self.conv_up4(x)
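        # FBA fusion (note: F and B here shadow torch.nn.functional, which is
        # not used below). The raw F/B predictions are refined using the
        # compositing relation img = alpha*F + (1-alpha)*B, and alpha is then
        # re-estimated in closed form as a regularized least-squares solution:
        # alpha = (la*alpha + <img-B, F-B>) / (<F-B, F-B> + la)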
alpha = torch.clamp(x[:, 0][:, None], 0, 1)
F = torch.sigmoid(x[:, 1:4])
B = torch.sigmoid(x[:, 4:7])
F = alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B
B = (1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F
F = torch.clamp(F, 0, 1)
B = torch.clamp(B, 0, 1)
la = 0.1
alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (
torch.sum((F - B) * (F - B), 1, keepdim=True) + la
)
alpha = torch.clamp(alpha, 0, 1)
return torch.cat((alpha, F, B), 1)
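# Full model: the encoder sees the normalized image plus 8 trimap channels
# (11 in total); the decoder additionally gets the unnormalized image for the
# fusion step.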
class MattingModule(nn.Module):
def __init__(self):
super().__init__()
self.encoder = ResnetDilated(in_channels=11)
self.decoder = fba_decoder()
def forward(self, image, two_chan_trimap, trimap_transformed):
image_n = normalise_image(image)
resnet_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1)
conv_out = self.encoder(resnet_input)
return self.decoder(conv_out, image, two_chan_trimap)
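# Self-test: downloads the example image, trimap, and reference alpha from the
# FBA_Matting repo, runs the model (through ONNX Runtime by default; flip the
# "if 0" below to use plain PyTorch), and compares the predicted alpha against
# the reference prediction.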
def test():
urls = """
https://raw.githubusercontent.com/MarcoForte/FBA_Matting/master/examples/images/troll.png
https://raw.githubusercontent.com/MarcoForte/FBA_Matting/master/examples/predictions/troll_alpha.png
https://raw.githubusercontent.com/MarcoForte/FBA_Matting/master/examples/trimaps/troll.png
"""
for url in urls.strip().split():
_, filename = url.split("/master/")
if not os.path.isfile(filename):
print("Downloading", url)
os.makedirs(os.path.dirname(filename), exist_ok=True)
with urllib.request.urlopen(url) as r:
data = r.read()
with open(filename, "wb") as f:
f.write(data)
if not os.path.isfile("FBA.pth"):
print("Download the model file from https://github.com/MarcoForte/FBA_Matting?tab=readme-ov-file#models and save it as FBA.pth in the current directory")
return
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MattingModule()
model.load_state_dict(torch.load("FBA.pth", map_location=device), strict=True)
model.to(device)
model.train(False)
image_np = cv2.imread("examples/images/troll.png")[:, :, ::-1] / 255.0
trimap_np = cv2.imread("examples/trimaps/troll.png", cv2.IMREAD_GRAYSCALE) / 255.0
trimap_np = np.stack([trimap_np == 0, trimap_np == 1], axis=2).astype(np.float32)
h, w = trimap_np.shape[:2]
h8 = int(np.ceil(h / 8) * 8)
w8 = int(np.ceil(w / 8) * 8)
image_scale_np = cv2.resize(image_np, (w8, h8), interpolation=cv2.INTER_LANCZOS4)
trimap_scale_np = cv2.resize(trimap_np, (w8, h8), interpolation=cv2.INTER_LANCZOS4)
with torch.no_grad():
image = torch.from_numpy(image_scale_np).permute(2, 0, 1)[None, :, :, :].float().to(device)
trimap = torch.from_numpy(trimap_scale_np).permute(2, 0, 1)[None, :, :, :].float().to(device)
trimap_transformed = torch.from_numpy(trimap_transform(trimap_scale_np))[None, :, :, :].float().to(device)
if 0:
# using PyTorch
output = model(image, trimap, trimap_transformed)
output = output[0].cpu().numpy().transpose(1, 2, 0)
else:
# using onnx
args = (image, trimap, trimap_transformed)
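            # Note: this bakes the example's input resolution into the graph.
            # To accept other sizes, one could pass dynamic_axes to
            # torch.onnx.export, but every op in the model would then have to
            # support dynamic shapes, which I have not verified.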
torch.onnx.export(model, args, "model.onnx", verbose=True)
import onnxruntime
sess = onnxruntime.InferenceSession("model.onnx")
input_feed = {inp.name: arg.detach().cpu().numpy()
for inp, arg in zip(sess.get_inputs(), args)}
output = sess.run(None, input_feed)[0]
output = output[0].transpose(1, 2, 0)
    output = cv2.resize(output, (w, h), interpolation=cv2.INTER_LANCZOS4)  # third positional arg of cv2.resize is dst, so interpolation must be passed by keyword
alpha = output[:, :, 0]
fg = output[:, :, 1:4]
bg = output[:, :, 4:7]
alpha[trimap_np[:, :, 0] == 1] = 0
alpha[trimap_np[:, :, 1] == 1] = 1
fg[alpha == 1] = image_np[alpha == 1]
bg[alpha == 0] = image_np[alpha == 0]
alpha_expected = cv2.imread("examples/predictions/troll_alpha.png", cv2.IMREAD_GRAYSCALE) / 255.0
mse = np.mean(np.square(alpha - alpha_expected))
print(f"MSE: {mse:.20f}")
assert mse < 1e-6, f"Error too large. I blame the developers of some dependency."
print("Test passed :)")
if __name__ == "__main__":
    test()
Hello, how can I convert the pretrained model to ONNX?