
ONNX export with dynamic_axes does not work when applying BlurPool #3466

Open
dneup opened this issue Jul 10, 2024 · 1 comment
Labels
bug Something isn't working


dneup commented Jul 10, 2024

When I try to export a ResNet18 model to ONNX with the `export_for_inference` function, passing both the `apply_blurpool` surgery algorithm and `dynamic_axes`, I get the following error:

/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py:27: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if h % 2 == 0:
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py:29: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if w % 2 == 0:
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py:31: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return int(torch.div(h, 2)), int(torch.div(w, 2))
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py:76: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if (filter.shape[0] == 1) and (channels > 1):
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py:81: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if h + 2 * padding[0] < filter_h:
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py:83: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if w + 2 * padding[1] < filter_w:
Traceback (most recent call last):
  File "/home/project/test_model_export.py", line 9, in <module>
    export_for_inference(
  File "/home/project/.venv/lib/python3.10/site-packages/composer/utils/inference.py", line 258, in export_for_inference
    torch.onnx.export(
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1612, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1138, in _model_to_graph
    graph = _optimize_graph(
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1956, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 306, in wrapper
    return fn(g, *args, **kwargs)
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py", line 2519, in _convolution
    raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of convolution for kernel of unknown shape.  [Caused by the value 'maxs defined in (%maxs : Float(*, 64, *, *, strides=[65536, 1024, 32, 1], requires_grad=1, device=cpu), %146 : Long(*, 64, *, *, device=cpu) = onnx::MaxPool[ceil_mode=0, dilations=[1, 1], kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%142), scope: torchvision.models.resnet.ResNet::/composer.algorithms.blurpool.blurpool_layers.BlurMaxPool2d::maxpool # /home/project/.venv/lib/python3.10/site-packages/torch/nn/functional.py:796:0
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::MaxPool'.] 
    (node defined in /home/project/.venv/lib/python3.10/site-packages/torch/nn/functional.py(796): _max_pool2d
/home/project/.venv/lib/python3.10/site-packages/torch/_jit_internal.py(497): fn
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py(150): blurmax_pool2d
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py(201): forward
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1522): _slow_forward
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1541): _call_impl
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1532): _wrapped_call_impl
/home/project/.venv/lib/python3.10/site-packages/torchvision/models/resnet.py(271): _forward_impl
/home/project/.venv/lib/python3.10/site-packages/torchvision/models/resnet.py(285): forward
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1522): _slow_forward
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1541): _call_impl
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1532): _wrapped_call_impl
/home/project/.venv/lib/python3.10/site-packages/torch/jit/_trace.py(129): wrapper
/home/project/.venv/lib/python3.10/site-packages/torch/jit/_trace.py(138): forward
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1541): _call_impl
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1532): _wrapped_call_impl
/home/project/.venv/lib/python3.10/site-packages/torch/jit/_trace.py(1310): _get_trace_graph
/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py(914): _trace_and_get_graph_from_model
/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py(1010): _create_jit_graph
/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py(1134): _model_to_graph
/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py(1612): _export
/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py(516): export
/home/project/.venv/lib/python3.10/site-packages/composer/utils/inference.py(258): export_for_inference
/home/project/test_model_export.py(9): <module>
)

    Inputs:
        #0: 142 defined in (%142 : Float(*, 64, *, *, strides=[65536, 1024, 32, 1], requires_grad=1, device=cpu) = onnx::Relu(%input.4), scope: torchvision.models.resnet.ResNet::/torch.nn.modules.activation.ReLU::relu # /home/project/.venv/lib/python3.10/site-packages/torch/nn/functional.py:1498:0
    )  (type 'Tensor')
    Outputs:
        #0: maxs defined in (%maxs : Float(*, 64, *, *, strides=[65536, 1024, 32, 1], requires_grad=1, device=cpu), %146 : Long(*, 64, *, *, device=cpu) = onnx::MaxPool[ceil_mode=0, dilations=[1, 1], kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%142), scope: torchvision.models.resnet.ResNet::/composer.algorithms.blurpool.blurpool_layers.BlurMaxPool2d::maxpool # /home/project/.venv/lib/python3.10/site-packages/torch/nn/functional.py:796:0
    )  (type 'Tensor')
        #1: 146 defined in (%maxs : Float(*, 64, *, *, strides=[65536, 1024, 32, 1], requires_grad=1, device=cpu), %146 : Long(*, 64, *, *, device=cpu) = onnx::MaxPool[ceil_mode=0, dilations=[1, 1], kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%142), scope: torchvision.models.resnet.ResNet::/composer.algorithms.blurpool.blurpool_layers.BlurMaxPool2d::maxpool # /home/project/.venv/lib/python3.10/site-packages/torch/nn/functional.py:796:0
    )  (type 'Tensor')
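The TracerWarnings above come from Python `if` statements that inspect tensor shapes while the model is being traced (the traceback shows `torch.onnx.export` going through `torch/jit/_trace.py`). A traced graph keeps only the branch taken for the sample input. A minimal torch-only sketch of that effect; `halve_if_even` is a hypothetical function written for illustration, not composer code:

```python
import warnings

import torch


def halve_if_even(x: torch.Tensor) -> torch.Tensor:
    # Python-level branch on a shape value, the same pattern as
    # `if h % 2 == 0:` in blurpool_layers.py
    h = x.shape[-1]
    if h % 2 == 0:
        return x * 0.5
    return x * 2.0


with warnings.catch_warnings():
    # Tracing may emit a TracerWarning at the `if`; silence it for this demo
    warnings.simplefilter("ignore")
    # Even width at trace time, so the `* 0.5` branch is what gets recorded
    traced = torch.jit.trace(halve_if_even, torch.ones(1, 4))

odd = torch.ones(1, 3)
print(halve_if_even(odd))  # eager: branch re-evaluated, every element is 2.0
print(traced(odd))         # traced: `* 0.5` branch was frozen in, every element is 0.5
```

This is why a model traced at one input size can silently compute something different at another size, which is exactly what the warning text means by "the trace might not generalize to other inputs".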

**Environment**

torch=2.3.1
torchvision=0.18.1
composer=0.23.5 
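A small stdlib-only sketch for collecting the version lines above in the same `name=version` format. It assumes composer was installed from PyPI under the `mosaicml` distribution name (the import name `composer` differs from the package name); adjust the mapping if it was installed another way.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Label shown in the report -> installed distribution name (assumed mapping)
DEFAULT_DISTRIBUTIONS = {
    "torch": "torch",
    "torchvision": "torchvision",
    "composer": "mosaicml",  # composer is published on PyPI as "mosaicml"
}


def environment_report(distributions=None):
    """Return {label: version string or None if not installed}, plus the Python version."""
    distributions = DEFAULT_DISTRIBUTIONS if distributions is None else distributions
    report = {"python": platform.python_version()}
    for label, dist in distributions.items():
        try:
            report[label] = version(dist)
        except PackageNotFoundError:
            report[label] = None
    return report


if __name__ == "__main__":
    for name, ver in environment_report().items():
        print(f"{name}={ver}")
```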

**To reproduce**

Code snippet that throws the error:

```python
import torch
from composer.utils import export_for_inference
import torchvision
import composer.functional as cf

model = torchvision.models.resnet18()

export_for_inference(
    model=model,
    save_format="onnx",
    save_path="./model.onnx",
    sample_input=torch.rand(1, 3, 64, 64),
    # Batch, height and width are marked dynamic; exporting without this argument works
    dynamic_axes={"input": {0: "batch_size", 2: "height", 3: "width"}},
    # BlurPool surgery replaces MaxPool2d with BlurMaxPool2d, where the export fails
    surgery_algs=[cf.apply_blurpool],
)
```

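The error seems to originate in the code that applies BlurPool to the `MaxPool2d` layer (`BlurMaxPool2d` in the traceback).
`blur_2d` appears to be called with `channels=-1`, so the channel count is read from the traced input instead of being a fixed Python value, which triggers data-dependent control flow that tracing does not support.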
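One possible direction, sketched under assumptions: the exporter rejects the convolution because its kernel shape is unknown, and the kernel is built from a channel count derived from the traced input. If the channel count were fixed when the layer is built (as a Python `int`), the convolution weight would be a constant-shaped buffer. `StaticBlurMaxPool2d` below is hypothetical, not composer's `BlurMaxPool2d`; it assumes a 3x3 binomial blur filter and the max-pool settings visible in the traceback (`kernel_shape=[3, 3]`, `pads=[1, 1, 1, 1]`, `strides=[1, 1]`), and it has not been verified against `torch.onnx.export` with `dynamic_axes`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StaticBlurMaxPool2d(nn.Module):
    """Hypothetical max-pool + blur layer whose blur kernel shape is fixed at construction.

    `channels` is a plain Python int supplied by the caller, so the depthwise
    conv weight never depends on the traced input shape.
    """

    def __init__(self, channels: int, kernel_size: int = 3, stride: int = 2, padding: int = 1) -> None:
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Assumed 3x3 binomial low-pass filter ([1, 2, 1] outer product, normalized to sum 1)
        taps = torch.tensor([1.0, 2.0, 1.0])
        kernel = torch.outer(taps, taps)
        kernel = kernel / kernel.sum()
        # Repeat once, here, so the conv weight is a constant (channels, 1, 3, 3) buffer
        self.register_buffer("blur_filter", kernel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dense max pooling (stride 1), matching the MaxPool node in the traceback
        maxs = F.max_pool2d(x, kernel_size=self.kernel_size, stride=1, padding=self.padding)
        # Depthwise blur that applies the real stride; padding 1 centers the 3x3 kernel
        return F.conv2d(maxs, self.blur_filter, stride=self.stride, padding=1, groups=self.channels)


layer = StaticBlurMaxPool2d(channels=64)  # 64 channels, as in Float(*, 64, *, *) above
out = layer(torch.rand(2, 64, 32, 32))
print(out.shape)  # torch.Size([2, 64, 16, 16])
```

The sketch also drops composer's small-input guard (`if h + 2 * padding[0] < filter_h: return input`, one of the warned lines above), so a real fix would need to handle that case without a Python branch on traced shapes.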

The issue also persists when:

  • Using an older version of composer
  • Using an older version of torch

The issue disappears when:

  • Exporting without `dynamic_axes`
  • Using `partial(cf.apply_blurpool, replace_maxpools=False)`, which leaves the MaxPool2d layers unchanged

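Also, these TracerWarnings should probably be raised as errors: as the warnings themselves state, the trace might not generalize to other inputs, so the exported model may not behave as expected.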
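Until that happens upstream, the same effect can be approximated on the caller's side. A minimal sketch of a hypothetical `strict_trace` helper (not an existing torch or composer API) that records warnings during `torch.jit.trace` and raises if any `torch.jit.TracerWarning` was emitted; the same `warnings.catch_warnings` pattern could in principle wrap an `export_for_inference` call. Note that the `"always"` filter may also surface TracerWarnings that torch filters out for its own modules by default.

```python
import warnings

import torch


def strict_trace(fn, example_input):
    """Hypothetical helper: trace `fn` and raise if the tracer emitted any TracerWarning."""
    with warnings.catch_warnings(record=True) as caught:
        # "always" records every TracerWarning, including repeats from the same line
        warnings.simplefilter("always", torch.jit.TracerWarning)
        traced = torch.jit.trace(fn, example_input)
    problems = [w for w in caught if issubclass(w.category, torch.jit.TracerWarning)]
    if problems:
        raise RuntimeError(
            f"tracing emitted {len(problems)} TracerWarning(s); first: {problems[0].message}"
        )
    return traced


def data_dependent(x: torch.Tensor) -> torch.Tensor:
    # Python branch on a tensor value: the tracer can only record one side of it
    if x.sum() > 0:
        return x + 1
    return x - 1


def shape_free(x: torch.Tensor) -> torch.Tensor:
    # Pure tensor arithmetic: nothing is converted to a Python value
    return x * 2


traced_ok = strict_trace(shape_free, torch.ones(3))
try:
    strict_trace(data_dependent, torch.ones(3))
except RuntimeError as exc:
    print("rejected:", exc)
```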

dneup added the bug (Something isn't working) label on Jul 10, 2024
mvpatel2000 (Contributor) commented:

@dskhudia any suggestions here? Is it just that it's not possible with dynamic axes?
