Skip to content

Commit

Permalink
fix: ONNX exportability compatibity test and fix
Browse files Browse the repository at this point in the history
dynamic axes in all dimensions

remove from state_dict

pre-commit  ruff fix
  • Loading branch information
nimiiit committed Nov 14, 2024
1 parent 76cd253 commit 7a414ee
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 4 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ dev = [
"pre-commit",
"pytest",
"pytest-cov",
"onnx",
"sybil", # doctesting
]

Expand Down
10 changes: 6 additions & 4 deletions src/careamics/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,8 @@ def __init__(
self.stride = stride
self.max_pool_size = max_pool_size
self.ceil_mode = ceil_mode
self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
self.register_buffer("kernel", kernel, persistent=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the function.
Expand All @@ -474,19 +475,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.Tensor
Output tensor.
"""
self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype)
kernel = self.kernel.to(dtype=x.dtype)
num_channels = int(x.size(1))
if self.dim == 2:
return _max_blur_pool_by_kernel2d(
x,
self.kernel.repeat((x.size(1), 1, 1, 1)),
kernel.repeat((num_channels, 1, 1, 1)),
self.stride,
self.max_pool_size,
self.ceil_mode,
)
else:
return _max_blur_pool_by_kernel3d(
x,
self.kernel.repeat((x.size(1), 1, 1, 1, 1)),
kernel.repeat((num_channels, 1, 1, 1, 1)),
self.stride,
self.max_pool_size,
self.ceil_mode,
Expand Down
59 changes: 59 additions & 0 deletions tests/lightning/test_lightning_module_onnx_exportability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest
import torch
from onnx import checker

from careamics.config import FCNAlgorithmConfig
from careamics.lightning.lightning_module import FCNModule


@pytest.mark.parametrize(
"algorithm, architecture, conv_dim, n2v2, loss, shape",
[
("n2n", "UNet", 2, False, "mae", (16, 16)), # n2n 2D model
("n2n", "UNet", 3, False, "mae", (8, 16, 16)), # n2n 3D model
("n2v", "UNet", 2, False, "n2v", (16, 16)), # n2v 2D model
("n2v", "UNet", 3, False, "n2v", (8, 16, 16)), # n2v 3D model
("n2v", "UNet", 2, True, "n2v", (16, 16)), # n2v2 2D model
("n2v", "UNet", 3, True, "n2v", (8, 16, 16)), # n2v2 3D model
],
)
def test_onnx_export(tmp_path, algorithm, architecture, conv_dim, n2v2, loss, shape):
"""Test model exportability to ONNX."""

algo_config = {
"algorithm": algorithm,
"model": {
"architecture": architecture,
"conv_dims": conv_dim,
"in_channels": 1,
"num_classes": 1,
"depth": 3,
"n2v2": n2v2,
},
"loss": loss,
}
algo_config = FCNAlgorithmConfig(**algo_config)

# instantiate CAREamicsKiln
model = FCNModule(algo_config)
# set model to evaluation mode to avoid batch dimension error
model.model.eval()
# create a sample input of BC(Z)XY
x = torch.rand((1, 1, *shape))

# create dynamic axes from the shape of the x
dynamic_axes = {"input": {}, "output": {}}
for i in range(len(x.shape)):
dynamic_axes["input"][i] = f"dim_{i}"
dynamic_axes["output"][i] = f"dim_{i}"

torch.onnx.export(
model,
x,
f"{tmp_path}/test_model.onnx",
input_names=["input"], # the model's input names
output_names=["output"], # the model's output names
dynamic_axes=dynamic_axes, # variable length axes,
)

checker.check_model(f"{tmp_path}/test_model.onnx")

0 comments on commit 7a414ee

Please sign in to comment.