From 38d2522943e666d0b86a0f78b06b2b69c06b59cf Mon Sep 17 00:00:00 2001 From: Kamalraj Kannan <157608228+kamalrajkannan78@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:57:39 +0530 Subject: [PATCH] Add sum(idx is not None for idx in indices) == 1 check to fix incorrect mapping of multi-indexing in TVM logic (#51) --- python/tvm/relay/frontend/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 2b398497f..d8f657ad5 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2832,7 +2832,7 @@ def index(self, inputs, input_types): return res - elif len(_infer_shape(data)) > 2 : + elif len(_infer_shape(data)) > 2 and sum(idx is not None for idx in indices) == 1: axis = None index_expr = None for i, idx in enumerate(indices):