Skip to content

Commit

Permalink
Enhance robustness of index operation for 3D and 4D tensors across va…
Browse files Browse the repository at this point in the history
…rious dimensions and index values. (#50)
  • Loading branch information
kamalrajkannan78 authored Nov 27, 2024
1 parent 62a94d1 commit 24e1c49
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2746,12 +2746,10 @@ def index(self, inputs, input_types):
indices = inputs[1]
# Cover case when first row is selected as whole in Python style using ':' (e.g. x[:, mask])
# while second is selected using a boolean mask
if indices[0] == None:
if len(_infer_shape(data)) == 2 and len(indices) == 1 and indices[0] == None:
# Remove first None argument (represents ':')
indices.pop(0)

assert len(_infer_shape(data)) == 2 and len(indices) == 1, "Currently supportes only 2D tensors with single mask"

indices = indices[0]
if len(_analysis.free_vars(indices)) == 0:
if isinstance(indices, _expr.Constant):
Expand Down Expand Up @@ -2833,6 +2831,18 @@ def index(self, inputs, input_types):
res = _op.adv_index([data, indices])

return res

elif len(_infer_shape(data)) > 2 :
axis = None
index_expr = None
for i, idx in enumerate(indices):
if idx is not None:
axis = i
index_expr = idx
break

res = _op.take(data, index_expr, axis=axis)
return res

return _op.adv_index([data] + indices)

Expand Down

0 comments on commit 24e1c49

Please sign in to comment.