Skip to content

Commit

Permalink
Add sum(idx is not None for idx in indices) == 1 check to fix incorre…
Browse files Browse the repository at this point in the history
…ct mapping of multi-indexing in TVM logic (#51)
  • Loading branch information
kamalrajkannan78 authored Dec 5, 2024
1 parent 24e1c49 commit 38d2522
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2832,7 +2832,7 @@ def index(self, inputs, input_types):

return res

elif len(_infer_shape(data)) > 2 :
elif len(_infer_shape(data)) > 2 and sum(idx is not None for idx in indices) == 1:
axis = None
index_expr = None
for i, idx in enumerate(indices):
Expand Down

0 comments on commit 38d2522

Please sign in to comment.