Skip to content

Commit

Permalink
Add support for 4D data shapes with 3D indices in index operation
Browse files Browse the repository at this point in the history
  • Loading branch information
kamalrajkannan78 committed Nov 26, 2024
1 parent 932d1d4 commit 079b38b
Showing 1 changed file with 44 additions and 28 deletions.
72 changes: 44 additions & 28 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2747,34 +2747,50 @@ def index(self, inputs, input_types):
# Cover case when first row is selected as whole in Python style using ':' (e.g. x[:, mask])
# while second is selected using a boolean mask
if indices[0] == None:
# Remove first None argument (represents ':')
indices.pop(0)

assert len(_infer_shape(data)) == 2 and len(indices) == 1, "Currently supportes only 2D tensors with single mask"

indices = indices[0]
if len(_analysis.free_vars(indices)) == 0:
if isinstance(indices, _expr.Constant):
# Extract direct data as numpy array
indices = indices.data.numpy()
else:
# Infer values to get constant data (discards irelevant math which
# will eventually be optimized out)
indices = _infer_value(indices, {}).numpy()
indices = np.transpose(np.argwhere(indices))
indices = indices[0] if len(indices.shape) == 2 and indices.shape[0] == 1 else indices
indices = _expr.const(indices)

# Iterate over each row and extract the selected columns
res = []
for dim in range(_infer_shape(data)[0]):
partial_res = _op.take(data, _expr.const(dim), 0)
partial_res = _op.adv_index([partial_res, indices])
partial_res = _op.expand_dims(partial_res, 0)
res.append(partial_res)
res = _op.concatenate(res, 0)

return res
if len(_infer_shape(data)) == 2 and len(indices) == 1:
# Remove first None argument (represents ':')
indices.pop(0)

indices = indices[0]
if len(_analysis.free_vars(indices)) == 0:
if isinstance(indices, _expr.Constant):
# Extract direct data as numpy array
indices = indices.data.numpy()
else:
# Infer values to get constant data (discards irelevant math which
# will eventually be optimized out)
indices = _infer_value(indices, {}).numpy()
indices = np.transpose(np.argwhere(indices))
indices = indices[0] if len(indices.shape) == 2 and indices.shape[0] == 1 else indices
indices = _expr.const(indices)

# Iterate over each row and extract the selected columns
res = []
for dim in range(_infer_shape(data)[0]):
partial_res = _op.take(data, _expr.const(dim), 0)
partial_res = _op.adv_index([partial_res, indices])
partial_res = _op.expand_dims(partial_res, 0)
res.append(partial_res)
res = _op.concatenate(res, 0)

return res

elif len(_infer_shape(data)) == 4 and len(indices) == 3 :
axis = None
index_expr = None
for i, idx in enumerate(indices):
if idx is not None:
axis = i
index_expr = idx
break

res = _op.take(data, index_expr, axis=axis)
return res

assert not ((len(_infer_shape(data)) == 2 and len(indices) == 1) or
(len(_infer_shape(data)) == 4 and len(indices) == 3)), \
f"Indexing for data shape {data_shape} and indices length {len(indices)} is not supported yet"


# Cover cases when first dim (row) is selected based on boolean mask as indices (e.g. x[mask]),
# and indices aren't dependant on input activation
Expand Down

0 comments on commit 079b38b

Please sign in to comment.