Skip to content

Commit

Permalink
fix: ONNX exportability compatibity test and fix (#275)
Browse files Browse the repository at this point in the history
fix: Fix for ONNX export of Maxblurpool layer and performance
optimization by registering kernel as a buffer so that it doesn't need
to be copied to the GPU over and over again.
### Description
- **What**: Converting the pretrained models to ONNX format gives error
in the Maxpool layer used in the N2V2 architecture.This is mainly
because the convolution kernel is dynamically expanded to a size
matching the number of channels in the input in the Maxblurpool layer.
But the number of channels should be constant within the model.
- **Why**: Users can convert the pytorch models to ONNX for inference in
thier platforms
- **How**:  
-- instead of using the symbolic variable x.size(1), explicitly cast it
to an integer and make it a constant.
-- make the kernel as a buffer to avoid the copying to GPU overhead.
-- add tests for ONNX exportability
### Changes Made

- **Added**:  
-- onnx as a test dependency in pyproject.toml
--  'test_lightning_module_onnx_exportability.py'
- **Modified**: Maxblurpool module in  'layers.py'

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

Co-authored-by: Joran Deschamps <[email protected]>
  • Loading branch information
nimiiit and jdeschamps authored Nov 26, 2024
1 parent 0e0bc28 commit f0fcc89
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 4 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ dev = [
"pre-commit",
"pytest",
"pytest-cov",
"onnx",
"sybil", # doctesting
]

Expand Down
10 changes: 6 additions & 4 deletions src/careamics/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,8 @@ def __init__(
self.stride = stride
self.max_pool_size = max_pool_size
self.ceil_mode = ceil_mode
self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
self.register_buffer("kernel", kernel, persistent=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the function.
Expand All @@ -474,19 +475,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.Tensor
Output tensor.
"""
self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype)
kernel = self.kernel.to(dtype=x.dtype)
num_channels = int(x.size(1))
if self.dim == 2:
return _max_blur_pool_by_kernel2d(
x,
self.kernel.repeat((x.size(1), 1, 1, 1)),
kernel.repeat((num_channels, 1, 1, 1)),
self.stride,
self.max_pool_size,
self.ceil_mode,
)
else:
return _max_blur_pool_by_kernel3d(
x,
self.kernel.repeat((x.size(1), 1, 1, 1, 1)),
kernel.repeat((num_channels, 1, 1, 1, 1)),
self.stride,
self.max_pool_size,
self.ceil_mode,
Expand Down
59 changes: 59 additions & 0 deletions tests/lightning/test_lightning_module_onnx_exportability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest
import torch
from onnx import checker

from careamics.config import FCNAlgorithmConfig
from careamics.lightning.lightning_module import FCNModule


@pytest.mark.parametrize(
"algorithm, architecture, conv_dim, n2v2, loss, shape",
[
("n2n", "UNet", 2, False, "mae", (16, 16)), # n2n 2D model
("n2n", "UNet", 3, False, "mae", (8, 16, 16)), # n2n 3D model
("n2v", "UNet", 2, False, "n2v", (16, 16)), # n2v 2D model
("n2v", "UNet", 3, False, "n2v", (8, 16, 16)), # n2v 3D model
("n2v", "UNet", 2, True, "n2v", (16, 16)), # n2v2 2D model
("n2v", "UNet", 3, True, "n2v", (8, 16, 16)), # n2v2 3D model
],
)
def test_onnx_export(tmp_path, algorithm, architecture, conv_dim, n2v2, loss, shape):
"""Test model exportability to ONNX."""

algo_config = {
"algorithm": algorithm,
"model": {
"architecture": architecture,
"conv_dims": conv_dim,
"in_channels": 1,
"num_classes": 1,
"depth": 3,
"n2v2": n2v2,
},
"loss": loss,
}
algo_config = FCNAlgorithmConfig(**algo_config)

# instantiate CAREamicsKiln
model = FCNModule(algo_config)
# set model to evaluation mode to avoid batch dimension error
model.model.eval()
# create a sample input of BC(Z)XY
x = torch.rand((1, 1, *shape))

# create dynamic axes from the shape of the x
dynamic_axes = {"input": {}, "output": {}}
for i in range(len(x.shape)):
dynamic_axes["input"][i] = f"dim_{i}"
dynamic_axes["output"][i] = f"dim_{i}"

torch.onnx.export(
model,
x,
f"{tmp_path}/test_model.onnx",
input_names=["input"], # the model's input names
output_names=["output"], # the model's output names
dynamic_axes=dynamic_axes, # variable length axes,
)

checker.check_model(f"{tmp_path}/test_model.onnx")

0 comments on commit f0fcc89

Please sign in to comment.