Skip to content

Commit

Permalink
Update the way modify_onnx optimisation runs are conducted in the ONN…
Browse files Browse the repository at this point in the history
…XModifier class
  • Loading branch information
ptoupas committed Jan 14, 2025
1 parent e8bc974 commit e2a7ed7
Showing 1 changed file with 62 additions and 54 deletions.
116 changes: 62 additions & 54 deletions modelconverter/utils/onnx_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import onnx
Expand Down Expand Up @@ -1054,6 +1054,33 @@ def fuse_split_concat_to_conv(self) -> None:

self.optimize_onnx()

def revert_changes(self):
"""Reverts ONNX model to previous state."""
self.onnx_model = self.prev_onnx_model
self.onnx_gs = self.prev_onnx_gs

def apply_optimization_step(
self, step_name: str, optimization_func: Callable
):
"""Applies a single optimization step to the ONNX model.
@param step_name: Name of the optimization step
@type step_name: str
@param optimization_func: Optimization function to apply
@type optimization_func: Callable
"""
logger.debug(f"Attempting: {step_name}...")
try:
optimization_func()
if not self.compare_outputs(from_modelproto=True):
logger.warning(f"Failed: {step_name}, reverting changes...")
self.revert_changes()
except Exception as e:
logger.warning(
f"Failed: {step_name} with error: {e}, reverting changes..."
)
self.revert_changes()

def modify_onnx(self) -> bool:
"""Modify the ONNX model by applying a series of optimizations.
Expand All @@ -1066,65 +1093,46 @@ def modify_onnx(self) -> bool:
)
return False

try:
logger.debug("Substituting Div -> Mul nodes...")
self.substitute_node_by_type(source_node="Div", target_node="Mul")
if not self.compare_outputs(from_modelproto=True):
logger.warning(
"Failed to substitute Div -> Mul nodes, reverting changes..."
)
self.onnx_model = self.prev_onnx_model
self.onnx_gs = self.prev_onnx_gs

logger.debug("Substituting Sub -> Add nodes...")
self.substitute_node_by_type(source_node="Sub", target_node="Add")
if not self.compare_outputs(from_modelproto=True):
logger.warning(
"Failed to substitute Sub -> Add nodes, reverting changes..."
)
self.onnx_model = self.prev_onnx_model
self.onnx_gs = self.prev_onnx_gs

logger.debug(
"Fusing Add and Mul nodes to BatchNormalization nodes and then into Conv nodes..."
)
self.fuse_add_mul_to_bn()
if not self.compare_outputs(from_modelproto=True):
logger.warning(
"Failed to fuse Add and Mul nodes to BatchNormalization nodes, reverting changes..."
)
self.onnx_model = self.prev_onnx_model
self.onnx_gs = self.prev_onnx_gs

logger.debug("Fusing Add and Mul nodes to Conv nodes...")
self.fuse_comb_add_mul_to_conv()
if not self.compare_outputs(from_modelproto=True):
logger.warning(
"Failed to fuse Add and Mul nodes (combined) to Conv nodes, reverting changes..."
)
self.onnx_model = self.prev_onnx_model
self.onnx_gs = self.prev_onnx_gs
self.fuse_single_add_mul_to_conv()
if not self.compare_outputs(from_modelproto=True):
logger.warning(
"Failed to fuse Add and Mul nodes (single) to Conv nodes, reverting changes..."
)
self.onnx_model = self.prev_onnx_model
self.onnx_gs = self.prev_onnx_gs
optimization_steps = [
(
"Substitute Div -> Mul nodes",
lambda: self.substitute_node_by_type(
source_node="Div", target_node="Mul"
),
),
(
"Substitute Sub -> Add nodes",
lambda: self.substitute_node_by_type(
source_node="Sub", target_node="Add"
),
),
(
"Fuse Add and Mul nodes to BatchNormalization nodes",
self.fuse_add_mul_to_bn,
),
(
"Fuse Add and Mul nodes to Conv nodes (combined)",
self.fuse_comb_add_mul_to_conv,
),
(
"Fuse Add and Mul nodes to Conv nodes (single)",
self.fuse_single_add_mul_to_conv,
),
(
"Fuse Split and Concat nodes to Conv nodes",
self.fuse_split_concat_to_conv,
),
]

logger.debug("Fusing Split and Concat nodes to Conv nodes...")
self.fuse_split_concat_to_conv()
if not self.compare_outputs(from_modelproto=True):
logger.warning(
"Failed to fuse Split and Concat nodes to Conv nodes, reverting changes..."
)
self.onnx_model = self.prev_onnx_model
self.onnx_gs = self.prev_onnx_gs
for step_name, optimization_func in optimization_steps:
self.apply_optimization_step(step_name, optimization_func)

try:
self.export_onnx()
except Exception as e:
logger.error(f"Failed to modify the ONNX model: {e}")
return False

return True

def compare_outputs(self, from_modelproto: bool = False) -> bool:
Expand Down

0 comments on commit e2a7ed7

Please sign in to comment.