Skip to content

Commit

Permalink
Merge branch 'main' into akannan/remove_repeat_decompose
Browse files Browse the repository at this point in the history
  • Loading branch information
ashokkumarkannan1 authored Nov 26, 2024
2 parents c6aa189 + 932d1d4 commit 8735691
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2689,7 +2689,18 @@ def rsub(self, inputs, input_types):
def embedding(self, inputs, input_types):
weight = inputs[0]
indices = inputs[1]
return _op.embedding(weight, indices.astype("int32"), axis=0)
# Check the type of indices
indicies_dtype = _infer_type(indices).checked_type.dtype
if indicies_dtype != "int32" and indicies_dtype != "int64":
# we want to cast indices to int32 if they are not already.
# However, If indices are int64, there is no need to cast them to int32
# because forge doesn't support int64 and indices will be cast to int32
# during lowering from torch to forge. Therefore adding cast of int64 to int32
# will result in cast op that casts int32 to int32. This is not only redundant, but also
# exposes a few bugs in tt-mlir https://github.com/tenstorrent/tt-mlir/issues/1215
logger.warning("Casting input indices of embedding op from {} to int32", indicies_dtype)
indices = tvm.relay.cast(indices, "int32")
return _op.embedding(weight, indices, axis=0)

def embedding_bag(self, inputs, input_types):

Expand Down

0 comments on commit 8735691

Please sign in to comment.