diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 2b398497f..902a83df6 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2832,7 +2832,7 @@ def index(self, inputs, input_types): return res - elif len(_infer_shape(data)) > 2 : + elif len(_infer_shape(data)) > 2 and len(indices) == 1: axis = None index_expr = None for i, idx in enumerate(indices):