diff --git a/python/tvm/relay/op/contrib/forge/forge_passes.py b/python/tvm/relay/op/contrib/forge/forge_passes.py index 45f811aa5..630d81fb4 100644 --- a/python/tvm/relay/op/contrib/forge/forge_passes.py +++ b/python/tvm/relay/op/contrib/forge/forge_passes.py @@ -1883,6 +1883,9 @@ def __init__(self): def callback(self, pre, post, node_map): act = node_map[self.act][0] axis = post.attrs.axis + # Skip removal of squeeze which contain dynamic shapes + if any([isinstance(dim, tvm.tir.expr.Any) for dim in pre.checked_type.shape]): + return post input_shape = [int(dim) for dim in pre.args[0].checked_type.shape] adjusted_axes = [(ax - len(input_shape)) if ax >= 0 else ax for ax in axis] assert all(ax < 0 for ax in adjusted_axes), "Invalid squeeze dimension: all axes must be negative."