From d85728f8ed9cc7373349e23e04ea0970abcd0b86 Mon Sep 17 00:00:00 2001 From: kkannan Date: Tue, 3 Dec 2024 19:06:31 +0000 Subject: [PATCH] Add sum(idx is not None for idx in indices) == 1 check to fix incorrect mapping of multi-indexing in TVM logic --- python/tvm/relay/frontend/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 2b398497f..d8f657ad5 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2832,7 +2832,7 @@ def index(self, inputs, input_types): return res - elif len(_infer_shape(data)) > 2 : + elif len(_infer_shape(data)) > 2 and sum(idx is not None for idx in indices) == 1: axis = None index_expr = None for i, idx in enumerate(indices):