Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert to onnx #54

Open
ghost opened this issue Jul 21, 2023 · 1 comment
Open

Convert to onnx #54

ghost opened this issue Jul 21, 2023 · 1 comment

Comments

@ghost
Copy link

ghost commented Jul 21, 2023

Hello, how do I convert the pretrained model to ONNX?

@99991
Copy link

99991 commented Sep 5, 2024

Here is an example that exports FBA-Net to ONNX and loads it back with ONNX Runtime. Before running, download FBA.pth and place it in the same directory. I have only tested this with a single image, so there is no guarantee that it works with images of different sizes.

It would also be possible to replace the call to OpenCV's distance transform function (cv2.distanceTransform) with pure PyTorch to get rid of the OpenCV dependency, but that is a bit of work, so I have not done it for now.

import numpy as np
import os, cv2, math, urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyGroupNorm(nn.Module):
    """GroupNorm written with elementary ops (mean, var, sqrt, reshape).

    Behaves like ``nn.GroupNorm(num_groups, num_channels, eps)`` with
    ``affine=True`` and exposes the same ``weight``/``bias`` parameters, so a
    checkpoint trained with ``nn.GroupNorm`` loads into it unchanged. Intended
    for ONNX runtimes that lack a GroupNorm operator.

    Accepts input of shape (N, C, *), like ``nn.GroupNorm``.

    Raises:
        ValueError: if ``num_channels`` is not divisible by ``num_groups``.
    """

    def __init__(self, num_groups, num_channels, eps=1e-5):
        super().__init__()
        if num_channels % num_groups != 0:
            raise ValueError("num_channels must be divisible by num_groups")
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        # Identity affine transform at init (weight=1, bias=0), matching
        # nn.GroupNorm; a zero weight would make the layer output only `bias`.
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        shape = x.shape
        n, c = shape[0], shape[1]
        # Flatten each group's channels and spatial positions into one axis.
        # reshape (not view) so non-contiguous inputs also work.
        grouped = x.reshape(n, self.num_groups, -1)
        mean = grouped.mean(dim=-1, keepdim=True)
        # Biased variance (correction=0), as nn.GroupNorm uses.
        var = grouped.var(dim=-1, keepdim=True, correction=0)
        normed = (grouped - mean) / torch.sqrt(var + self.eps)
        # Per-channel affine transform.
        out = normed.reshape(n, c, -1) * self.weight.view(1, -1, 1) + self.bias.view(1, -1, 1)
        return out.reshape(shape)

class MyAdaptiveAvgPool2d(nn.Module):
    """Adaptive average pooling implemented with explicit slicing.

    Output cell (y, x) is the mean over input rows
    ``floor(y*h/out_h) : ceil((y+1)*h/out_h)`` and the analogous columns,
    the same bin boundaries ``nn.AdaptiveAvgPool2d`` uses. Presumably this
    exists because PyTorch's ONNX exporter cannot export adaptive pooling
    when the output size does not divide the input size.

    Args:
        output_size: an int (square output) or an ``(out_h, out_w)`` pair.

    NOTE(review): the bin boundaries are computed from Python ints, so an
    exported graph has them baked in for the traced input size.
    """

    def __init__(self, output_size):
        super().__init__()
        self.output_size = output_size

    def forward(self, batch):
        if isinstance(self.output_size, int):
            out_h = out_w = self.output_size
        else:
            out_h, out_w = self.output_size
        n, c, h, w = batch.shape
        # new_zeros keeps the input's dtype and device; torch.zeros(...) would
        # silently produce the default float dtype.
        output = batch.new_zeros((n, c, out_h, out_w))
        for y in range(out_h):
            y0 = math.floor(y * h / out_h)
            y1 = math.ceil((y + 1) * h / out_h)
            for x in range(out_w):
                x0 = math.floor(x * w / out_w)
                x1 = math.ceil((x + 1) * w / out_w)
                output[:, :, y, x] = batch[:, :, y0:y1, x0:x1].mean(dim=(2, 3))
        return output

def norm(dim):
    """Return the normalization layer used throughout the network.

    This is ``nn.GroupNorm`` with 32 groups over ``dim`` channels. If your
    ONNX runtime does not support GroupNorm, return ``MyGroupNorm(32, dim)``
    here instead: it has the same ``weight``/``bias`` parameters, so the
    checkpoint still loads.
    """
    groups = 32
    return nn.GroupNorm(num_groups=groups, num_channels=dim)

class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` with weight standardization baked into its weights.

    Each output channel's filter is standardized to zero mean and unit
    (biased) variance via ``F.batch_norm`` over a ``(1, out_channels, -1)``
    view of the weight. Instead of standardizing inside every forward pass,
    the result is written back into ``weight.data``:

    * whenever the mode is switched with ``train()`` / ``eval()``;
    * additionally on every forward call while in training mode.

    In eval mode ``forward`` is therefore a plain convolution, which keeps the
    exported ONNX graph free of the standardization ops. The mode must be
    switched *after* ``load_state_dict`` so the loaded weights get
    standardized.

    NOTE(review): because the write-back goes through ``.data``, gradients do
    not flow through the standardization; this class suits inference/export
    rather than training.
    NOTE: every mode switch re-standardizes already-standardized weights,
    scaling them by about 1/sqrt(1 + 1e-5); negligible, but not exactly
    idempotent.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(in_channels, out_channels, **kwargs)

    def normalize_weight(self):
        """Standardize ``self.weight`` per output channel, in place."""
        # no_grad: this is a parameter rewrite, not part of any autograd graph,
        # so don't build (and immediately discard) one.
        with torch.no_grad():
            weight = F.batch_norm(
                self.weight.view(1, self.out_channels, -1),
                None,
                None,
                training=True,
                momentum=0.0,
            ).reshape_as(self.weight)
        self.weight.data = weight

    def forward(self, x):
        if self.training:
            self.normalize_weight()

        return super().forward(x)

    def train(self, mode: bool = True):
        """Set the training mode, then re-standardize the weights.

        Returns ``self``, like ``nn.Module.train``, so ``conv.eval()`` and
        chained calls keep working.
        """
        super().train(mode=mode)
        self.normalize_weight()
        return self

def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1, bias=False):
    """Build a weight-standardized 3x3 ``Conv2d``; bias is off by default."""
    conv_options = {
        "kernel_size": 3,
        "stride": stride,
        "dilation": dilation,
        "padding": padding,
        "bias": bias,
    }
    return Conv2d(in_planes, out_planes, **conv_options)


def conv1x1(in_planes, out_planes, stride=1, bias=False):
    """Build a weight-standardized pointwise (1x1) ``Conv2d``; bias is off by default."""
    conv_options = {"kernel_size": 1, "stride": stride, "bias": bias}
    return Conv2d(in_planes, out_planes, **conv_options)


def dt(a):
    """Euclidean distance transform of the non-zero region of ``a``.

    ``a`` is expected to hold values in [0, 1] (callers pass
    ``1 - trimap_channel``). It is scaled to uint8 and passed to
    ``cv2.distanceTransform``, which gives each non-zero pixel its L2 distance
    to the nearest zero pixel. ``maskSize=0`` is ``cv2.DIST_MASK_PRECISE``.

    Values are clipped to [0, 255] before the uint8 cast. A Lanczos-resized
    trimap can overshoot 1 slightly, making ``a`` negative. Casting a negative
    float to uint8 is undefined in NumPy; in practice it typically wraps to a
    large non-zero value, which would wrongly mark a known pixel as unknown.
    """
    scaled = np.clip(a * 255, 0, 255).astype(np.uint8)
    return cv2.distanceTransform(scaled, cv2.DIST_L2, 0)


def trimap_transform(trimap, L=320):
    """Encode a two-channel trimap as Gaussian "click" maps at three scales.

    ``trimap[:, :, k]`` is indexed for k = 0 and 1. In ``test()`` channel 0 is
    the definite-background mask and channel 1 the definite-foreground mask.

    For each channel k, let d be the distance (via ``dt``) from each pixel to
    the nearest pixel where that channel equals 1. The function emits
    ``exp(-d**2 / (2 * (s * L)**2))`` for s in 0.02, 0.08 and 0.16, in that
    order. The six maps (channel-major) are stacked into one numpy array of
    shape (6, H, W).
    """
    scales = (0.02, 0.08, 0.16)
    clicks = []
    for k in range(2):
        dt_mask = -dt(1 - trimap[:, :, k]) ** 2
        clicks.extend(np.exp(dt_mask / (2 * ((s * L) ** 2))) for s in scales)
    return np.array(clicks)


def normalise_image(image):
    """Standardize an image batch with the ImageNet per-channel mean/std.

    The constants are broadcast as shape (1, 3, 1, 1), so channels are
    presumably on axis 1 of an NCHW batch.

    Warning: the values are for RGB order, but OpenCV loads images as BGR;
    callers must convert first.
    """
    channel_mean = torch.tensor([0.485, 0.456, 0.406], device=image.device).view(1, 3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225], device=image.device).view(1, 3, 1, 1)
    centred = image - channel_mean
    return centred / channel_std

def pyramid_pooling_module(scale):
    """Build one pyramid-pooling branch.

    The branch pools to a ``scale`` x ``scale`` grid, then applies a 1x1 conv
    (2048 -> 256, with bias), ``norm(256)`` and LeakyReLU.
    """
    stages = [
        MyAdaptiveAvgPool2d(scale),
        conv1x1(2048, 256, bias=True),
        norm(256),
        nn.LeakyReLU(),
    ]
    return nn.Sequential(*stages)


def resize(x, **kwargs):
    """Bilinearly resample ``x`` (``align_corners=False``).

    Extra keyword arguments (e.g. ``size=`` or ``scale_factor=``) are
    forwarded to ``interpolate``.
    """
    interpolate = nn.functional.interpolate
    return interpolate(x, mode="bilinear", align_corners=False, **kwargs)


class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 -> 3x3 -> 1x1 conv, each followed by ``norm``.

    The output has ``planes * expansion`` channels. When ``downsample`` is
    given, it is applied to the block input before the residual addition.
    Attribute names (conv1/bn1/.../downsample) define the state_dict keys.
    """

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        padding=1,
        dilation=1,
        expansion=4,
        downsample=None,
    ):
        super().__init__()
        width_out = planes * expansion
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = norm(planes)
        # Only the 3x3 conv carries stride/dilation.
        self.conv2 = conv3x3(planes, planes, stride, padding, dilation)
        self.bn2 = norm(planes)
        self.conv3 = conv1x1(planes, width_out)
        self.bn3 = norm(width_out)
        self.downsample = downsample
        # Kept for reference; forward() does not read it.
        self.stride = stride

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(h + shortcut)


class ResnetDilated(nn.Module):
    """Dilated ResNet encoder with 3/4/6/3 bottleneck blocks per stage.

    Overall stride is 8: conv1, maxpool and the first block of layer2 each
    use stride 2, while layer3/layer4 use dilation instead of stride.
    ``forward`` returns every intermediate feature map (see below).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = Conv2d(
            in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 64, blocks=3, stride=1, first_dilation=1, dilation=1)
        self.layer2 = self._make_layer(256, 128, blocks=4, stride=2, first_dilation=1, dilation=1)
        self.layer3 = self._make_layer(512, 256, blocks=6, stride=1, first_dilation=1, dilation=2)
        self.layer4 = self._make_layer(1024, 512, blocks=3, stride=1, first_dilation=2, dilation=4)

    @staticmethod
    def _make_layer(inplanes, planes, blocks, stride, first_dilation, dilation):
        """Build one residual stage.

        The stage is a projecting first block followed by ``blocks - 1``
        identity blocks. Every 3x3 conv uses padding equal to its dilation,
        and each block expands the channel count by 4.
        """
        outplanes = planes * 4
        projection = nn.Sequential(conv1x1(inplanes, outplanes, stride=stride), norm(outplanes))
        stage = [
            Bottleneck(
                inplanes,
                planes,
                stride=stride,
                padding=first_dilation,
                dilation=first_dilation,
                downsample=projection,
            )
        ]
        for _ in range(blocks - 1):
            stage.append(Bottleneck(outplanes, planes, stride=1, padding=dilation, dilation=dilation))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the list ``[input, stem, layer1, layer2, layer3, layer4]`` of feature maps."""
        features = [x]
        x = F.relu(self.bn1(self.conv1(x)))
        features.append(x)
        x = self.maxpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            features.append(x)
        return features


class fba_decoder(nn.Module):
    """FBA decoder with an alpha/foreground/background head.

    The decoder applies pyramid pooling, three upsampling stages with skip
    connections, and a final head. The head's outputs are refined jointly
    ("FBA fusion"). ``forward`` returns a 7-channel tensor ``cat(alpha, fg, bg)``
    (1 + 3 + 3 channels), each clamped to [0, 1].
    """

    def __init__(self):
        super().__init__()
        # One pyramid-pooling branch per pooled grid size.
        self.ppm = nn.ModuleList([pyramid_pooling_module(scale=s) for s in (1, 2, 3, 6)])
        ppm_channels = 2048 + len(self.ppm) * 256
        # Module order inside each Sequential defines state_dict keys
        # (e.g. "conv_up1.0.weight") and must match the checkpoint.
        self.conv_up1 = nn.Sequential(
            conv3x3(ppm_channels, 256, bias=True),
            norm(256),
            nn.LeakyReLU(),
            conv3x3(256, 256, bias=True),
            norm(256),
            nn.LeakyReLU(),
        )
        self.conv_up2 = nn.Sequential(conv3x3(256 + 256, 256, bias=True), norm(256), nn.LeakyReLU())
        self.conv_up3 = nn.Sequential(conv3x3(256 + 64, 64, bias=True), norm(64), nn.LeakyReLU())
        self.conv_up4 = nn.Sequential(
            nn.Conv2d(64 + 3 + 3 + 2, 32, 3, 1, 1, bias=True),
            nn.LeakyReLU(),
            nn.Conv2d(32, 16, 3, 1, 1, bias=True),
            nn.LeakyReLU(),
            nn.Conv2d(16, 7, 1, bias=True),
        )

    def forward(self, conv_out, img, two_chan_trimap):
        # conv_out is presumably the encoder's feature list
        # [input, stem, layer1, layer2, layer3, layer4].
        conv5 = conv_out[-1]
        target_size = conv5.shape[2:]
        pooled = [resize(branch(conv5), size=target_size) for branch in self.ppm]
        x = self.conv_up1(torch.cat([conv5, *pooled], 1))

        x = resize(x, scale_factor=2)
        x = self.conv_up2(torch.cat((x, conv_out[-4]), 1))

        x = resize(x, scale_factor=2)
        x = self.conv_up3(torch.cat((x, conv_out[-5]), 1))

        x = resize(x, scale_factor=2)
        # First three channels of the encoder input, plus the raw image and trimap.
        x = self.conv_up4(torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1))

        alpha = torch.clamp(x[:, 0][:, None], 0, 1)
        fg = torch.sigmoid(x[:, 1:4])
        bg = torch.sigmoid(x[:, 4:7])

        # FBA fusion: refine fg using the raw bg, then refine bg using the
        # already-refined fg (order matters).
        fg = alpha * img + (1 - alpha**2) * fg - alpha * (1 - alpha) * bg
        bg = (1 - alpha) * img + (2 * alpha - alpha**2) * bg - alpha * (1 - alpha) * fg
        fg = torch.clamp(fg, 0, 1)
        bg = torch.clamp(bg, 0, 1)

        # Re-estimate alpha from the fused colours, regularized towards the
        # previous alpha with weight la.
        la = 0.1
        fg_minus_bg = fg - bg
        numerator = alpha * la + torch.sum((img - bg) * fg_minus_bg, 1, keepdim=True)
        denominator = torch.sum(fg_minus_bg * fg_minus_bg, 1, keepdim=True) + la
        alpha = torch.clamp(numerator / denominator, 0, 1)
        return torch.cat((alpha, fg, bg), 1)


class MattingModule(nn.Module):
    """Full FBA matting network: ``ResnetDilated`` encoder plus ``fba_decoder``.

    The encoder sees 11 channels: the normalized image, ``trimap_transformed``
    and ``two_chan_trimap``. In ``test()`` these have 3, 6 and 2 channels.
    """

    def __init__(self):
        super().__init__()
        self.encoder = ResnetDilated(in_channels=11)
        self.decoder = fba_decoder()

    def forward(self, image, two_chan_trimap, trimap_transformed):
        stacked = torch.cat((normalise_image(image), trimap_transformed, two_chan_trimap), 1)
        features = self.encoder(stacked)
        # The decoder also gets the un-normalized image for the colour fusion.
        return self.decoder(features, image, two_chan_trimap)

def _download_examples():
    """Download the "troll" example image, trimap and reference alpha if missing.

    Files are stored under ``examples/`` relative to the current directory,
    mirroring their path in the FBA_Matting repository.
    """
    urls = """
    https://raw.githubusercontent.com/MarcoForte/FBA_Matting/master/examples/images/troll.png
    https://raw.githubusercontent.com/MarcoForte/FBA_Matting/master/examples/predictions/troll_alpha.png
    https://raw.githubusercontent.com/MarcoForte/FBA_Matting/master/examples/trimaps/troll.png
    """
    for url in urls.strip().split():
        _, filename = url.split("/master/")
        if os.path.isfile(filename):
            continue
        print("Downloading", url)
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        # Timeout so a stalled connection fails instead of hanging forever.
        with urllib.request.urlopen(url, timeout=60) as r:
            data = r.read()
        with open(filename, "wb") as f:
            f.write(data)


def _run_onnx(model, args, path="model.onnx"):
    """Export ``model`` to ONNX at ``path`` and run it with onnxruntime on ``args``.

    Returns the first ONNX output as a numpy array.
    """
    import onnxruntime

    input_names = ["image", "two_chan_trimap", "trimap_transformed"]
    torch.onnx.export(
        model,
        args,
        path,
        verbose=True,
        input_names=input_names,
        output_names=["output"],
    )

    # Since onnxruntime 1.9, builds with several execution providers (e.g. GPU
    # builds) refuse to create a session without an explicit providers list.
    available = onnxruntime.get_available_providers()
    providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available]
    sess = onnxruntime.InferenceSession(path, providers=providers)

    # Feed inputs by the names fixed at export time instead of relying on the
    # positional order of sess.get_inputs().
    input_feed = {name: arg.detach().cpu().numpy() for name, arg in zip(input_names, args)}
    return sess.run(None, input_feed)[0]


def test(use_onnx=True):
    """End-to-end check of FBA matting on the "troll" example.

    Downloads the example image, trimap and reference alpha matte if they are
    not already present. Then runs the model either through an ONNX export
    plus onnxruntime (``use_onnx=True``, the default) or directly in PyTorch,
    and checks the predicted alpha against the reference.

    If ``FBA.pth`` is missing, prints download instructions and returns.

    Raises:
        AssertionError: if the alpha MSE against the reference is not below 1e-6.
    """
    _download_examples()

    if not os.path.isfile("FBA.pth"):
        print("Download the model file from https://github.com/MarcoForte/FBA_Matting?tab=readme-ov-file#models and save it as FBA.pth in the current directory")
        return

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = MattingModule()
    # The checkpoint is only used as a state_dict, so there is no need to let
    # torch.load unpickle arbitrary objects from a downloaded file.
    state_dict = torch.load("FBA.pth", map_location=device, weights_only=True)
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    # Switch to eval *after* loading: Conv2d.train() standardizes the loaded weights.
    model.train(False)

    # OpenCV loads BGR; flip to RGB to match normalise_image's constants.
    image_np = cv2.imread("examples/images/troll.png")[:, :, ::-1] / 255.0
    trimap_np = cv2.imread("examples/trimaps/troll.png", cv2.IMREAD_GRAYSCALE) / 255.0
    # Two-channel trimap: channel 0 = definite background, channel 1 = definite foreground.
    trimap_np = np.stack([trimap_np == 0, trimap_np == 1], axis=2).astype(np.float32)
    h, w = trimap_np.shape[:2]

    # Resize the working resolution up to a multiple of 8 (the encoder's
    # overall stride), presumably so the decoder's three x2 upsamplings line
    # up with the skip connections.
    h8 = int(np.ceil(h / 8) * 8)
    w8 = int(np.ceil(w / 8) * 8)
    image_scale_np = cv2.resize(image_np, (w8, h8), interpolation=cv2.INTER_LANCZOS4)
    trimap_scale_np = cv2.resize(trimap_np, (w8, h8), interpolation=cv2.INTER_LANCZOS4)

    with torch.no_grad():
        image = torch.from_numpy(image_scale_np).permute(2, 0, 1)[None, :, :, :].float().to(device)
        trimap = torch.from_numpy(trimap_scale_np).permute(2, 0, 1)[None, :, :, :].float().to(device)
        trimap_transformed = torch.from_numpy(trimap_transform(trimap_scale_np))[None, :, :, :].float().to(device)
        args = (image, trimap, trimap_transformed)

        if use_onnx:
            output = _run_onnx(model, args)
        else:
            output = model(*args).cpu().numpy()

    # First batch item, (C, H, W) -> (H, W, C).
    output = output[0].transpose(1, 2, 0)

    # NOTE: cv2.resize's third positional parameter is `dst`, not
    # `interpolation`, so passing cv2.INTER_LANCZOS4 positionally (as this
    # script originally did) silently used the default INTER_LINEAR. Keep that
    # interpolation explicitly, since the reference comparison below was run
    # with it; switching to Lanczos would be a deliberate behaviour change.
    output = cv2.resize(output, (w, h), interpolation=cv2.INTER_LINEAR)

    alpha = output[:, :, 0]
    fg = output[:, :, 1:4]
    bg = output[:, :, 4:7]

    # Known trimap regions override the prediction.
    alpha[trimap_np[:, :, 0] == 1] = 0
    alpha[trimap_np[:, :, 1] == 1] = 1
    fg[alpha == 1] = image_np[alpha == 1]
    bg[alpha == 0] = image_np[alpha == 0]

    alpha_expected = cv2.imread("examples/predictions/troll_alpha.png", cv2.IMREAD_GRAYSCALE) / 255.0

    mse = np.mean(np.square(alpha - alpha_expected))

    print(f"MSE: {mse:.20f}")

    assert mse < 1e-6, "Error too large. I blame the developers of some dependency."

    print("Test passed :)")

if __name__ == "__main__":
    # Running the file as a script performs the end-to-end check in test().
    test()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant