Skip to content

Commit

Permalink
Add TVM decomposition for masked_scatter op
Browse files Browse the repository at this point in the history
  • Loading branch information
ashokkumarkannan1 committed Feb 6, 2025
1 parent 9b18def commit 2022ce9
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 24 deletions.
22 changes: 0 additions & 22 deletions forge/forge/op/eval/forge/tm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,28 +1081,6 @@ def decompose(type, attr, dc, inputs):
dc.fuse(result)
return

if type == "adv_index":
dim = attr[0]
in0_shape = inputs[0].shape
in1_shape = inputs[1].shape
if len(in0_shape) == 1 or in0_shape[dim] == 1:
result = dc.op(Nop.create(), [inputs[0]])
dc.fuse(result)
return
if dim == 0 and len(in1_shape) <= 2:
# Consider the case adv_index(X,Y) where
# X: (A, B), Y: (1, C) or (C,) and A != 1
if len(in0_shape) == 2:
# embedding op expects indices tensor as first argument and weight/embedding_table as second argument
# but the adv_index provides the reference tensor as first argument and indices tensor as second argument
# so swaping the operands.
result = dc.op(
"embedding",
(inputs[1], inputs[0]),
)
dc.fuse(result)
return

if type == "pad":
if all([x == 0 for x in attr[0:-2]]):
# Pad size is 0
Expand Down
2 changes: 1 addition & 1 deletion forge/test/mlir/operators/indexing/test_scatter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
),
],
)
@pytest.mark.xfail(reason="NotImplementedError: The following operators are not implemented: ['aten::masked_scatter']")

def test_masked_scatter(input_tensor, mask, source):
class MaskedScatterModule(torch.nn.Module):
def __init__(self, mask, source):
Expand Down
2 changes: 1 addition & 1 deletion third_party/tvm

0 comments on commit 2022ce9

Please sign in to comment.